#!/usr/bin/python3 -i
#
# Copyright (c) 2015-2022 The Khronos Group Inc.
# Copyright (c) 2015-2022 Valve Corporation
# Copyright (c) 2015-2022 LunarG, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Mark Lobodzinski <mark@lunarg.com>
# Author: Nadav Geva <nadav.geva@amd.com>

import os,re,sys,string,json
import xml.etree.ElementTree as etree
from generator import *
from collections import namedtuple
from common_codegen import *

# This is a workaround to use a Python 2.7 and 3.x compatible syntax
from io import open

class BestPracticesOutputGeneratorOptions(GeneratorOptions):
    """Options object for generating best_practices.h / best_practices.cpp.

    The registry-selection arguments (conventions through sortProcedure) are
    passed unchanged to GeneratorOptions.__init__. Every remaining keyword
    argument is stored verbatim as an attribute of the same name.
    """
    def __init__(self,
                 conventions = None,
                 filename = None,
                 directory = '.',
                 genpath = None,
                 apiname = 'vulkan',
                 profile = None,
                 versions = '.*',
                 emitversions = '.*',
                 defaultExtensions = 'vulkan',
                 addExtensions = None,
                 removeExtensions = None,
                 emitExtensions = None,
                 emitSpirv = None,
                 sortProcedure = regSortFeatures,
                 genFuncPointers = True,
                 protectFile = True,
                 protectFeature = False,
                 apicall = 'VKAPI_ATTR ',
                 apientry = 'VKAPI_CALL ',
                 apientryp = 'VKAPI_PTR *',
                 indentFuncProto = True,
                 indentFuncPointer = False,
                 alignFuncParam = 48,
                 expandEnumerants = False):
        # Arguments owned by the base GeneratorOptions class.
        base_options = dict(
            conventions=conventions,
            filename=filename,
            directory=directory,
            genpath=genpath,
            apiname=apiname,
            profile=profile,
            versions=versions,
            emitversions=emitversions,
            defaultExtensions=defaultExtensions,
            addExtensions=addExtensions,
            removeExtensions=removeExtensions,
            emitExtensions=emitExtensions,
            emitSpirv=emitSpirv,
            sortProcedure=sortProcedure,
        )
        GeneratorOptions.__init__(self, **base_options)

        # Calling-convention strings and function-pointer switch.
        self.apicall = apicall
        self.apientry = apientry
        self.apientryp = apientryp
        self.genFuncPointers = genFuncPointers

        # Protection flags.
        self.protectFile = protectFile
        self.protectFeature = protectFeature

        # Declaration layout flags.
        self.indentFuncProto = indentFuncProto
        self.indentFuncPointer = indentFuncPointer
        self.alignFuncParam = alignFuncParam
        self.expandEnumerants = expandEnumerants
#
# BestPracticesOutputGenerator(errFile, warnFile, diagFile)
class BestPracticesOutputGenerator(OutputGenerator):
    def __init__(self,
                 errFile = sys.stderr,
                 warnFile = sys.stderr,
                 diagFile = sys.stdout):
        """Initialize the generator and the command lists that steer genCmd.

        errFile, warnFile and diagFile are forwarded unchanged to
        OutputGenerator.__init__.
        """
        OutputGenerator.__init__(self, errFile, warnFile, diagFile)

        # genCmd emits only a "Skipping ..." comment for these: they are not
        # autogenerated but are still intercepted elsewhere.
        self.no_autogen_list = [
            'vkEnumerateInstanceVersion',
            'vkCreateValidationCacheEXT',
            'vkDestroyValidationCacheEXT',
            'vkMergeValidationCachesEXT',
            'vkGetValidationCacheDataEXT',
        ]

        # For these, the generated PostCallRecord gains a trailing
        # `void* state_data` parameter that shares state between the
        # validate and record steps.
        self.extra_parameter_list = [
            "vkCreateShaderModule",
            "vkCreateGraphicsPipelines",
            "vkCreateComputePipelines",
            "vkAllocateDescriptorSets",
            "vkCreateRayTracingPipelinesNV",
            "vkCreateRayTracingPipelinesKHR",
        ]

        # The generated PostCallRecord for these also calls the hand-written
        # ManualPostCallRecord<Name> with the same arguments.
        self.manual_postcallrecord_list = [
            'vkAllocateDescriptorSets',
            'vkAllocateMemory',
            'vkQueuePresentKHR',
            'vkQueueBindSparse',
            'vkCreateGraphicsPipelines',
            'vkGetPhysicalDeviceSurfaceCapabilitiesKHR',
            'vkGetPhysicalDeviceSurfaceCapabilities2KHR',
            'vkGetPhysicalDeviceSurfaceCapabilities2EXT',
            'vkGetPhysicalDeviceSurfacePresentModesKHR',
            'vkGetPhysicalDeviceSurfaceFormatsKHR',
            'vkGetPhysicalDeviceSurfaceFormats2KHR',
            'vkGetPhysicalDeviceDisplayPlanePropertiesKHR',
            'vkGetSwapchainImagesKHR',
            # AMD tracked
            'vkCreateComputePipelines',
            'vkCmdPipelineBarrier',
            'vkQueueSubmit',
        ]

        # Extension name -> [reason, target, specialuse]; filled by
        # beginFeature and turned into lookup tables by endFile.
        self.extension_info = {}
    #
    # Separate content for validation source and header files
    def otwrite(self, dest, formatstring):
        """Write ``formatstring`` to the current output file if it is a target.

        ``dest`` is 'hdr', 'cpp' or 'both'. Header content is written only
        while generating best_practices.h, and source content only while
        generating best_practices.cpp; anything else is silently dropped.
        """
        out_name = self.genOpts.filename
        to_header = 'best_practices.h' in out_name and dest in ('hdr', 'both')
        to_source = 'best_practices.cpp' in out_name and dest in ('cpp', 'both')
        if to_header or to_source:
            write(formatstring, file=self.outFile)
    #
    # Called at beginning of processing as file is opened
    def beginFile(self, genOpts):
        """Open an output file and write its common preamble.

        Only 'best_practices.h' and 'best_practices.cpp' are accepted. Any
        other file name prints an error and exits with status 1. Both files
        receive the "generated" banner and the license block; the .cpp
        additionally receives its #include lines.
        """
        OutputGenerator.beginFile(self, genOpts)

        if genOpts.filename not in ('best_practices.h', 'best_practices.cpp'):
            print("Error: Output Filenames have changed, update generator source.\n")
            sys.exit(1)

        banner = ('// *** THIS FILE IS GENERATED - DO NOT EDIT ***\n'
                  '// See best_practices_generator.py for modifications\n')
        self.otwrite('both', banner)

        # Leading '' makes the joined text start with a blank line.
        license_lines = [
            '',
            '/***************************************************************************',
            ' *',
            ' * Copyright (c) 2015-2022 The Khronos Group Inc.',
            ' * Copyright (c) 2015-2022 Valve Corporation',
            ' * Copyright (c) 2015-2022 LunarG, Inc.',
            ' *',
            ' * Licensed under the Apache License, Version 2.0 (the "License");',
            ' * you may not use this file except in compliance with the License.',
            ' * You may obtain a copy of the License at',
            ' *',
            ' *     http://www.apache.org/licenses/LICENSE-2.0',
            ' *',
            ' * Unless required by applicable law or agreed to in writing, software',
            ' * distributed under the License is distributed on an "AS IS" BASIS,',
            ' * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.',
            ' * See the License for the specific language governing permissions and',
            ' * limitations under the License.',
            ' *',
            ' * Author: Mark Lobodzinski <mark@lunarg.com>',
            ' * Author: Nadav Geva <nadav.geva@amd.com>',
            ' *',
            ' ****************************************************************************/',
        ]
        self.otwrite('both', '\n'.join(license_lines) + '\n')
        self.newline()
        for include_line in ('#include "chassis.h"', '#include "best_practices_validation.h"'):
            self.otwrite('cpp', include_line)
    #
    # Now that the data is all collected and complete, emit the extension deprecation and special-use tables
    def endFile(self):
        """Emit the extension deprecation and special-use tables, then finish the file.

        Both tables are built from self.extension_info (populated by
        beginFeature) and written to the header only. Entries are iterated in
        sorted extension-name order.
        """
        self.newline()
        # NOTE(review): self.emit and self.featureExtraProtect still hold the
        # values left by the last feature processed. Whether the tables below
        # are written therefore depends on that feature's emit flag — confirm
        # this is intended.
        if self.emit:
            self.newline()
            protect = self.featureExtraProtect
            if protect is not None:
                # NOTE(review): this #ifdef/#endif pair encloses no declarations.
                self.otwrite('both', '#ifdef %s' % protect)
                # Must be a formatted string. The previous code built the tuple
                # ('\n#endif // %s', protect), whose repr leaked into the output.
                self.otwrite('both', '\n#endif // %s' % protect)
            else:
                self.otwrite('both', '\n')

            # Map: extension name -> {kExt<Reason>, "<replacement>"} for
            # promoted/obsoleted/deprecated extensions.
            ext_deprecation_data = 'const layer_data::unordered_map<std::string, DeprecationData>  deprecated_extensions = {\n'
            for ext in sorted(self.extension_info):
                reason, target = self.extension_info[ext][:2]
                if reason is not None:
                    ext_deprecation_data += '    {"%s", {kExt%s, "%s"}},\n' % (ext, reason, target)
            ext_deprecation_data += '};\n'
            self.otwrite('hdr', ext_deprecation_data)

            # Map: extension name -> specialuse list, with a space inserted
            # after each comma of the registry's comma-separated value.
            ext_specialuse_data = 'const layer_data::unordered_map<std::string, std::string> special_use_extensions = {\n'
            for ext in sorted(self.extension_info):
                special_uses = self.extension_info[ext][2]
                if special_uses is not None:
                    ext_specialuse_data += '    {"%s", "%s"},\n' % (ext, special_uses.replace(',', ', '))
            ext_specialuse_data += '};\n'
            self.otwrite('hdr', ext_specialuse_data)

        OutputGenerator.endFile(self)
    #
    # Processing point at beginning of each extension definition
    def beginFeature(self, interface, emit):
        """Record protect and deprecation/special-use metadata for one feature.

        After chaining to OutputGenerator.beginFeature, this stores the result
        of GetFeatureProtect(interface) in self.featureExtraProtect. If the
        element carries a promotedto/obsoletedby/deprecatedby or specialuse
        attribute, it also stores [reason, target, specialuse] in
        self.extension_info under the element's 'name'.
        """
        OutputGenerator.beginFeature(self, interface, emit)
        self.featureExtraProtect = GetFeatureProtect(interface)
        attrs = interface.attrib
        name = attrs.get('name')
        special_use = attrs.get('specialuse')

        # The first attribute present wins: promoted > obsoleted > deprecated.
        reason, target = None, None
        for attr_key, label in (('promotedto', 'Promoted'),
                                ('obsoletedby', 'Obsoleted'),
                                ('deprecatedby', 'Deprecated')):
            replacement = attrs.get(attr_key)
            if replacement is not None:
                reason, target = label, replacement
                break

        if reason is not None or special_use is not None:
            self.extension_info[name] = [reason, target, special_use]

    #
    # Retrieve the type and name for a parameter
    def getTypeNameTuple(self, param):
        """Return the ``(type, name)`` text of a <param> element.

        Scans the element's direct children. For each of <type> and <name>,
        the text of the last matching child (passed through noneStr) is
        returned, or '' when no such child exists.
        """
        found = {'type': '', 'name': ''}
        for child in param:
            if child.tag in found:
                found[child.tag] = noneStr(child.text)
        return (found['type'], found['name'])
    #
    # Generate a PostCallRecord override that validates the VkResult returned by each command
    def genCmd(self, cmdinfo, cmdname, alias):
        """Generate a PostCallRecord override that validates a command's VkResult.

        Output depends on the command:

        - In no_autogen_list: only a "Skipping ..." comment goes to the .cpp.
        - Not returning VkResult, or no non-VK_SUCCESS result codes in the
          registry: nothing is written.
        - Otherwise: a declaration goes to the header and a definition to the
          .cpp, both wrapped in #ifdef/#endif when the feature has a protect
          macro.

        The generated definition calls:

        - ValidationStateTracker::PostCallRecord<Name>.
        - ManualPostCallRecord<Name>, for commands in
          manual_postcallrecord_list.
        - ValidateReturnCodes, only when result != VK_SUCCESS.

        Commands in extra_parameter_list get a trailing ``void* state_data``
        parameter that is passed through to both calls.
        """
        OutputGenerator.genCmd(self, cmdinfo, cmdname, alias)
        if cmdname in self.no_autogen_list:
            self.otwrite('cpp', '// Skipping %s for autogen as it has a manually created custom function or ignored.\n' % cmdname)
            return

        # The prototype looks like "VKAPI_ATTR VkResult VKAPI_CALL vkFoo(...);".
        prototype = self.makeCDecls(cmdinfo.elem)[0]
        return_type = prototype.split(' ')[1]
        if return_type != 'VkResult':
            return

        error_codes = cmdinfo.elem.attrib.get('errorcodes')
        success_codes = cmdinfo.elem.attrib.get('successcodes')
        if success_codes is not None:
            # Drop VK_SUCCESS token-by-token (not by substring). An empty
            # remainder becomes None, consistent with a missing 'errorcodes'.
            remaining = [code for code in success_codes.split(',') if code.strip() != 'VK_SUCCESS']
            success_codes = ','.join(remaining) or None
        if error_codes is None and success_codes is None:
            return

        extra_state = cmdname in self.extra_parameter_list

        # Strip the trailing ';' and the return/calling-convention prefix,
        # leaving "vkFoo(...)".
        signature = prototype[:-1].split('VKAPI_CALL ')[1]
        added_params = ',\n    VkResult                                    result'
        if extra_state:
            added_params += ',\n    void*                                       state_data'
        # Append the extra parameters before the closing parenthesis only.
        head, _, tail = signature.rpartition(')')
        signature = head + added_params + ') {\n' + tail
        # "vkFoo(" -> "void BestPractices::PostCallRecordFoo("
        definition = 'void BestPractices::PostCallRecord' + signature[2:]
        declaration = definition.replace(' {', ' override;\n').replace('BestPractices::', '')

        if self.featureExtraProtect is not None:
            self.otwrite('both', '#ifdef %s\n' % self.featureExtraProtect)
        self.otwrite('hdr', declaration)

        # Forward every original parameter, plus the added trailing ones.
        call_args = [self.getTypeNameTuple(param)[1] for param in cmdinfo.elem.findall('param')]
        call_args.append('result')
        if extra_state:
            call_args.append('state_data')
        call_suffix = '(' + ', '.join(call_args) + ');\n'

        body = definition
        body += '    ValidationStateTracker::PostCallRecord' + cmdname[2:] + call_suffix
        if cmdname in self.manual_postcallrecord_list:
            body += '    ManualPostCallRecord' + cmdname[2:] + call_suffix
        body += '    if (result != VK_SUCCESS) {\n'
        error_input = '{}'
        success_input = '{}'
        if error_codes is not None:
            body += '        constexpr std::array error_codes = {%s};\n' % error_codes
            error_input = 'error_codes'
        if success_codes is not None:
            body += '        constexpr std::array success_codes = {%s};\n' % success_codes
            success_input = 'success_codes'
        body += '        ValidateReturnCodes("%s", result, %s, %s);\n' % (cmdname, error_input, success_input)
        body += '    }\n'
        body += '}\n'
        self.otwrite('cpp', body)

        if self.featureExtraProtect is not None:
            self.otwrite('both', '#endif // %s\n' % self.featureExtraProtect)
