#!/usr/bin/python3

# ========================== begin_copyright_notice ============================
#
# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT
#
# =========================== end_copyright_notice =============================

import argparse
import json


# Banner placed at the very top of the generated file.
OUTPUT_HEADER = """// AUTOGENERATED FILE, DO NOT EDIT!
// Generated by GenerateTranslationCode.py script."""
# C++ declarations separator: one empty line between consecutive declarations.
INTERVAL_BETWEEN_DECLS = "\n\n"
# Prefix prepended to a builtin's "Name" to form the builtin function name that
# is stored in the generated BuiltinNames array.
BUILTIN_PREFIX = "__cm_cl_"

# The name of the enum with builtin IDs.
BUILTIN_ID = "BuiltinID"
# The name of the enum with operand kinds and the suffix of builtin operand
# kind arrays. It is also the key of the operand kind list inside the
# "HelperStructures" part of the JSON description.
OPERAND_KIND = "OperandKind"
# The suffix of builtin operand name enums.
OPERAND_NAME = "Operand"
# The name of the variable that holds a reference to the builtin call in the
# generated C++ functions.
BUILTIN_VARIABLE = "BiCall"
# The name of the variable that holds a reference to the IR builder in the
# generated C++ functions.
IRB_VARIABLE = "IRB"

# Section names. Every section of the generated file is wrapped into
# "#ifdef <section name>" ... "#endif // <section name>", so the code that
# includes the generated file selects a section by defining its macro.
BUILTIN_DESCS_SECTION = "CMCL_AUTOGEN_BUILTIN_DESCS"
TRANSLATION_DESCS_SECTION = "CMCL_AUTOGEN_TRANSLATION_DESCS"
TRANSLATION_IMPL_SECTION = "CMCL_AUTOGEN_TRANSLATION_IMPL"

# Command-line interface: --desc is the JSON description to read, --output is
# the C++ include file the generated code is written to.
parser = argparse.ArgumentParser(
  description="Generate translation code from JSON description.")
parser.add_argument("--desc", required=True,
  help="JSON file with a description", metavar="<input>.json")
parser.add_argument("--output", required=True, help="output file",
  metavar="<output>.inc")

# Opens the \p desc_filename JSON file and returns the parsed description
# (the structure produced by json.load).
# The file is explicitly decoded as UTF-8, the encoding RFC 8259 requires for
# exchanged JSON, instead of the locale-dependent default of open(), so the
# result does not depend on the locale of the build machine.
def get_description_from_json(desc_filename):
  with open(desc_filename, "r", encoding="utf-8") as desc_file:
    return json.load(desc_file)

# Emits a C++ enum wrapped into its own namespace:
# namespace name {
# enum Enum {
#   values[0],
#   values[1],
#   ...
# };
# } // namespace name
#
# \p values may be any iterable; the enumerators are written one per line in
# iteration order. The generated text is returned.
def generate_enum(name, values):
  enumerators = ",\n".join("  {}".format(value) for value in values)
  return "\n".join(["namespace {} {{".format(name),
                    "enum Enum {",
                    enumerators,
                    "};",
                    "}} // namespace {}".format(name)])

# Generates:
# constexpr c_type name[] = {
#   values[0],
#   values[1],
#   ...
# };
#
# \p values may be any iterable; every element is written on its own line in
# iteration order. The generated text is returned.
# Raises RuntimeError if \p values is empty: an array of unknown bound cannot
# be initialized with an empty brace list in C++, so the generated declaration
# would not compile.
def generate_array(c_type, name, values):
  # Materialize first so that generators are checked for emptiness correctly
  # (a generator object is always truthy).
  values = list(values)
  # An explicit check rather than assert, which "python -O" strips.
  if not values:
    raise RuntimeError("cannot generate an empty array {n}".format(n=name))
  text = "constexpr {t} {n}[] = {{\n".format(t=c_type, n=name)
  text += ",\n".join("  {v}".format(v=value) for value in values)
  return text + "\n};"

# Generates the enumerations that do not describe builtins themselves but whose
# values are used to describe builtins (e.g. the operand kinds). Each entry of
# \p helper_structures maps an enum name to its list of enumerators; one enum
# is emitted per entry, in dictionary order.
def generate_helper_enums(helper_structures):
  helper_enums = (generate_enum(enum_name, enumerators)
                  for enum_name, enumerators in helper_structures.items())
  return INTERVAL_BETWEEN_DECLS.join(helper_enums)

# Checks that every operand of the builtin \p builtin_name (described by
# \p desc) uses an operand kind listed in the OperandKind helper structure.
# Raises RuntimeError naming the builtin and every offending kind otherwise.
def validate_builtin_desc(builtin_name, desc, helper_structures):
  known_kinds = helper_structures[OPERAND_KIND]
  illegal_kinds = [operand["Kind"] for operand in desc["Operands"]
                   if operand["Kind"] not in known_kinds]
  if illegal_kinds:
    raise RuntimeError("Builtin {b} uses operand kinds that are not present "
                       "in the {ok} list: {k}".format(
                         b=builtin_name, ok=OPERAND_KIND,
                         k=", ".join(str(kind) for kind in illegal_kinds)))

# Checks every builtin description against the helper structures.
# Raises RuntimeError (from validate_builtin_desc) on the first inconsistent
# builtin found.
def validate_description(builtin_descs, helper_structures):
  for builtin_name, desc in builtin_descs.items():
    validate_builtin_desc(builtin_name, desc, helper_structures)

# Returns a new list with the items of \p lst followed by an additional "Size"
# element. \p lst may be any iterable (dict keys, a generator, ...) and is not
# modified.
def append_size(lst):
  extended = list(lst)
  extended.append("Size")
  return extended

# Generates an array with the full names of all the builtins, in the same
# order as the builtin descriptions (and thus as the BuiltinID enumerators):
# constexpr const char* BuiltinNames[] = {
#   "__cm_cl_builtin0",
#   "__cm_cl_builtin1",
#   ...
# };
def generate_builtin_names_array(builtin_descs):
  quoted_names = []
  for desc in builtin_descs.values():
    quoted_names.append('"{}"'.format(BUILTIN_PREFIX + desc["Name"]))
  return generate_array("const char*", "BuiltinNames", quoted_names)

# Generates the enum with the operand names of the builtin \p builtin, in
# declaration order and terminated by a "Size" enumerator, whose value is the
# number of operands:
# namespace BuiltinOperand {
# enum Enum {
#   OperandName0,
#   OperandName1,
#   ...
#   Size
# };
# } // namespace BuiltinOperand
def generate_operand_names_enum(builtin, desc):
  enum_name = builtin + OPERAND_NAME
  operand_names = [operand["Name"] for operand in desc["Operands"]]
  return generate_enum(enum_name, append_size(operand_names))

# Generates the operand-name enum for every builtin, so that operand names can
# later be used as indices.
# Simplified output:
# enum Builtin0Operand { SRC, Size };
# enum Builtin1Operand { DST, SRC, Size };
# ...
def generate_operand_names_enums(builtin_descs):
  enums = (generate_operand_names_enum(builtin_name, desc)
           for builtin_name, desc in builtin_descs.items())
  return INTERVAL_BETWEEN_DECLS.join(enums)

# Generates an array with the number of operands of every builtin, in the
# order of the builtin descriptions:
# constexpr int BuiltinOperandSize[] = {
#   Builtin0Operand::Size,
#   Builtin1Operand::Size,
#   ...
# };
def generate_operand_size_array(builtin_descs):
  size_refs = ["{}{}::Size".format(builtin_name, OPERAND_NAME)
               for builtin_name in builtin_descs]
  return generate_array("int", "BuiltinOperandSize", size_refs)

# Generates the array with the operand kinds of the builtin \p builtin, one
# entry per operand in declaration order:
# constexpr OperandKind::Enum BuiltinOperandKind[] = {
#   OperandKind::Kind0,
#   OperandKind::Kind1,
#   ...
# };
# Only meaningful for a builtin with operands: generate_array does not accept
# an empty list of values.
def generate_operand_kinds_array(builtin, desc):
  element_type = OPERAND_KIND + "::Enum"
  array_name = builtin + OPERAND_KIND
  kinds = []
  for operand in desc["Operands"]:
    kinds.append(OPERAND_KIND + "::" + operand["Kind"])
  return generate_array(element_type, array_name, kinds)

# Generates the operand kinds array for every builtin that has operands.
# Builtins without operands get no array; the combined array refers to them
# with nullptr instead.
# Simplified output:
# constexpr OperandKind::Enum Builtin0OperandKind[] = {OperandKind::VectorIn};
# constexpr OperandKind::Enum Builtin1OperandKind[] = {
#   OperandKind::VectorOut, OperandKind::VectorIn};
def generate_operand_kinds_arrays(builtin_descs):
  arrays = []
  for builtin, desc in builtin_descs.items():
    if not desc["Operands"]:
      continue
    arrays.append(generate_operand_kinds_array(builtin, desc))
  return INTERVAL_BETWEEN_DECLS.join(arrays)

# Returns the C++ expression used as the operand kinds pointer of a builtin.
# If the builtin has operands, this is the name of its operand kinds array
# (the array name decays to a pointer). Otherwise it is "nullptr", because no
# operand kinds array is generated for a builtin without operands.
def get_operand_kinds_array_pointer(builtin, desc):
  return builtin + OPERAND_KIND if desc["Operands"] else "nullptr"

# Generates an array of pointers to the per-builtin operand kinds arrays, so
# the kind of BuiltinN's M-th operand is BuiltinOperandKind[BuiltinN][M].
# Builtins without operands are represented by nullptr.
# Output:
# constexpr const OperandKind::Enum* BuiltinOperandKind[] = {
#   Builtin0OperandKind,
#   Builtin1OperandKind,
#   nullptr,
#   ...
# };
def generate_combined_operand_kinds_array(builtin_descs):
  pointer_type = "const {}::Enum*".format(OPERAND_KIND)
  array_name = "Builtin" + OPERAND_KIND
  pointers = [get_operand_kinds_array_pointer(builtin, desc)
              for builtin, desc in builtin_descs.items()]
  return generate_array(pointer_type, array_name, pointers)

# Generates all enums and arrays that describe CMCL builtins, in this order:
# the BuiltinID enum, builtin names, per-builtin operand name enums, operand
# counts, per-builtin operand kinds arrays and the combined operand kinds
# array.
def generate_builtin_descriptions(builtin_descs):
  builtin_id_enum = generate_enum(BUILTIN_ID,
                                  append_size(builtin_descs.keys()))
  generators = (generate_builtin_names_array,
                generate_operand_names_enums,
                generate_operand_size_array,
                generate_operand_kinds_arrays,
                generate_combined_operand_kinds_array)
  decls = [builtin_id_enum]
  decls.extend(generate(builtin_descs) for generate in generators)
  return INTERVAL_BETWEEN_DECLS.join(decls)

# Returns the preprocessor line that opens the section \p section_name.
def begin_section(section_name):
  return "".join(("#ifdef ", section_name))

# Returns the preprocessor line that closes the section \p section_name.
def end_section(section_name):
  return "".join(("#endif // ", section_name))

# Takes the section content (an iterable of text fragments) and the section
# name, and returns a new list with the section opening line first, then the
# content, then the section closing line. The input is not modified.
def frame_section(section_content, section_name):
  framed = [begin_section(section_name)]
  framed.extend(section_content)
  framed.append(end_section(section_name))
  return framed

# Generates the section of the output file that describes builtins: the helper
# enums followed by the builtin descriptions, framed by the
# BUILTIN_DESCS_SECTION preprocessor guard.
def generate_builtin_descs_section(whole_desc):
  helper_enums = generate_helper_enums(whole_desc["HelperStructures"])
  builtin_decls = generate_builtin_descriptions(
      whole_desc["BuiltinDescriptions"])
  framed = frame_section([helper_enums, builtin_decls], BUILTIN_DESCS_SECTION)
  return INTERVAL_BETWEEN_DECLS.join(framed)

# Generates the array of builtin call handlers: one handleBuiltinCall
# specialization per builtin, in the order of the builtin descriptions.
# Output:
# constexpr BuiltinCallHandler BuiltinCallHandlers[] = {
#   handleBuiltinCall<BuiltinID::Builtin0>,
#   handleBuiltinCall<BuiltinID::Builtin1>,
#   ...
# };
def generate_handlers_array(builtin_descs):
  handler_pattern = "handleBuiltinCall<" + BUILTIN_ID + "::{}>"
  handlers = [handler_pattern.format(builtin) for builtin in builtin_descs]
  return generate_array("BuiltinCallHandler", "BuiltinCallHandlers", handlers)

# Returns the C++ spelling of the intrinsic ID named in the "TranslateInto"
# section of a builtin description (passed as \p translation_desc).
# "VC-Intrinsic" takes precedence over "LLVM-Intrinsic". If neither is given,
# "~0u" is returned.
def get_intrinsic_id(translation_desc):
  candidates = (("VC-Intrinsic", "GenXIntrinsic::"),
                ("LLVM-Intrinsic", "Intrinsic::"))
  for key, scope in candidates:
    if key in translation_desc:
      return scope + translation_desc[key]
  return "~0u"

# Generates an array that maps a builtin ID (array index) to the ID of the
# intrinsic the builtin is translated into, or ~0u if there is none.
# Output:
# constexpr unsigned IntrinsicForBuiltin[] = {
#   IntrinsicForBuiltin0,
#   IntrinsicForBuiltin1,
#   ~0u, // Builtin2 has no corresponding intrinsic
#   ...
# };
def generate_intrinsics_array(builtin_descs):
  intrinsic_ids = []
  for desc in builtin_descs.values():
    intrinsic_ids.append(get_intrinsic_id(desc["TranslateInto"]))
  return generate_array("unsigned", "IntrinsicForBuiltin", intrinsic_ids)

# Generates the section of the output file with the structures needed for the
# translation (the handlers array and the builtin-to-intrinsic map), framed by
# the TRANSLATION_DESCS_SECTION preprocessor guard.
def generate_translation_descs_section(builtin_descs):
  handlers = generate_handlers_array(builtin_descs)
  intrinsics = generate_intrinsics_array(builtin_descs)
  framed = frame_section([handlers, intrinsics], TRANSLATION_DESCS_SECTION)
  return INTERVAL_BETWEEN_DECLS.join(framed)

# Generates code for a GetBuiltinReturnType node: the type of the builtin call.
# The node must have no arguments; \p args is passed only to check that.
# Raises RuntimeError if \p args is not empty.
def generate_builtin_return_type_expression(builtin_name, args):
  if args:
    raise RuntimeError("Builtin {bi} has invalid expression tree description: "
                       "GetBuiltinReturnType node must have no arguments."
                       .format(bi=builtin_name))
  return "*{}.getType()".format(BUILTIN_VARIABLE)

# Returns the single argument of the \p function_name node of builtin
# \p builtin_name's expression tree, after checking that it names one of the
# operands listed in \p builtin_desc["Operands"].
# Raises RuntimeError if the node does not have exactly one argument or the
# argument is not an operand of the builtin.
def get_single_operand_from_expression(function_name, builtin_name, args,
                                       builtin_desc):
  if len(args) != 1:
    raise RuntimeError("Builtin {bi} has invalid expression tree description: "
                       "{func} node must have only one argument.".format(
                         bi=builtin_name, func=function_name))
  operand = args[0]
  builtin_operands = [op["Name"] for op in builtin_desc["Operands"]]
  if operand not in builtin_operands:
    raise RuntimeError("Builtin {bi} has invalid expression tree description: "
                       "{func} argument {op!r} is not an operand of this "
                       "builtin.".format(bi=builtin_name, func=function_name,
                                         op=operand))
  return operand

# Generates code for a GetBuiltinOperandType node, which must have exactly one
# argument: the name of a builtin operand. The result is a call to
# getTypeFromBuiltinOperand specialized for the builtin, taking the builtin
# call and the operand index.
def generate_builtin_operand_type_expression(builtin_name, args, builtin_desc):
  operand = get_single_operand_from_expression(
      "GetBuiltinOperandType", builtin_name, args, builtin_desc)
  builtin_id = "{}::{}".format(BUILTIN_ID, builtin_name)
  operand_index = "{}{}::{}".format(builtin_name, OPERAND_NAME, operand)
  return "getTypeFromBuiltinOperand<{}>({}, {})".format(
      builtin_id, BUILTIN_VARIABLE, operand_index)

# Generates code for a GetBuiltinOperand node (a leaf of the expression tree),
# which must have exactly one argument: the name of a builtin operand. The
# result is a readValueFromBuiltinOp call specialized for the builtin, taking
# the builtin call, the operand index and the IR builder.
def generate_builtin_operand_expression(builtin_name, args, builtin_desc):
  operand = get_single_operand_from_expression(
      "GetBuiltinOperand", builtin_name, args, builtin_desc)
  call_args = [BUILTIN_VARIABLE,
               "{}{}::{}".format(builtin_name, OPERAND_NAME, operand),
               IRB_VARIABLE]
  return "readValueFromBuiltinOp<{}::{}>({})".format(
      BUILTIN_ID, builtin_name, ", ".join(call_args))

# Generates code for a Code node.
# If the node has a single argument, that argument is a string holding the
# generated code verbatim; it is not passed through str.format, so braces need
# no escaping. If the node has more arguments, the first one is a Python format
# string and the rest are expression trees whose generated code fills the
# format string's positional replacement fields.
# Raises RuntimeError on a malformed node, including a format string that does
# not match its arguments.
def generate_code_expression(builtin_name, args, builtin_desc):
  if not args:
    raise RuntimeError("Builtin {bi} has invalid expression tree description: "
                       "Code node must have at least one argument.".format(
                         bi=builtin_name))
  code = args[0]
  if not isinstance(code, str):
    raise RuntimeError("Builtin {bi} has invalid expression tree description: "
                       "Code node must have a string as the first argument."
                       .format(bi=builtin_name))
  if len(args) == 1:
    return code
  replacements = [generate_expression_tree(builtin_name, arg, builtin_desc)
                  for arg in args[1:]]
  try:
    return code.format(*replacements)
  except (IndexError, KeyError, ValueError) as err:
    # str.format signals a format string that does not fit its arguments
    # (missing positional field, named field, stray brace) with these
    # exceptions. Report it like the other description errors so the
    # offending builtin is named.
    raise RuntimeError("Builtin {bi} has invalid expression tree description: "
                       "Code node format string {c!r} does not match its {n} "
                       "argument(s): {e}".format(bi=builtin_name, c=code,
                                                  n=len(replacements),
                                                  e=err)) from err

# Generates a (possibly nested) C++ expression from the description in
# \p tree_desc. Each tree node is an Object of the form {name: [args]}, where
# name is the node kind and args is the list of node arguments. Arguments may
# themselves be nodes, so the whole structure is a tree.
# Supported node kinds: GetBuiltinReturnType, GetBuiltinOperandType,
# GetBuiltinOperand and Code.
# \p builtin_desc is passed for validation.
# Raises RuntimeError if the node is malformed or of an unknown kind.
def generate_expression_tree(builtin_name, tree_desc, builtin_desc):
  # Check the shape explicitly: a string or list node would otherwise fail
  # below with an unrelated TypeError.
  if not isinstance(tree_desc, dict):
    raise RuntimeError("Builtin {bi} has invalid expression tree description: "
                       "Node must be an Object, got {node!r}.".format(
                         bi=builtin_name, node=tree_desc))
  if len(tree_desc) != 1:
    raise RuntimeError("Builtin {bi} has invalid expression tree description: "
                       "Object must have only one entry.".format(
                         bi=builtin_name))
  function, args = next(iter(tree_desc.items()))
  if not isinstance(args, list):
    raise RuntimeError("Builtin {bi} has invalid expression tree description: "
                       "Object entry item must be an Array.".format(
                         bi=builtin_name))
  if function == "GetBuiltinReturnType":
    return generate_builtin_return_type_expression(builtin_name, args)
  if function == "GetBuiltinOperandType":
    return generate_builtin_operand_type_expression(builtin_name, args,
                                                    builtin_desc)
  if function == "GetBuiltinOperand":
    return generate_builtin_operand_expression(builtin_name, args,
                                               builtin_desc)
  if function == "Code":
    return generate_code_expression(builtin_name, args, builtin_desc)
  raise RuntimeError("Builtin {bi} has invalid expression tree description: "
                     "Unknown node {func!r}.".format(bi=builtin_name,
                                                     func=function))

# Generates the getTranslatedBuiltinType specialization for the builtin
# \p builtin_name. The returned type expression is built from the
# TranslateInto:ReturnType expression tree of \p desc.
# Output:
# template <>
# Type &getTranslatedBuiltinType<BuiltinID::Builtin>(CallInst &BiCall) {
#   return .....;
# }
def generate_return_type_specialization(builtin_name, desc):
  return_expr = generate_expression_tree(builtin_name,
                                         desc["TranslateInto"]["ReturnType"],
                                         desc)
  signature = "Type &getTranslatedBuiltinType<{}::{}>(CallInst &{}) {{".format(
      BUILTIN_ID, builtin_name, BUILTIN_VARIABLE)
  return "\n".join(["template <>",
                    signature,
                    "  return {};".format(return_expr),
                    "}"])

# Generates the getTranslatedBuiltinOperands specialization for the builtin
# \p builtin_name. Each operand is generated from its expression tree in the
# TranslateInto:Operands section of the JSON description, and its address is
# taken to fill the returned std::vector<Value *>.
#
# Output:
# template <>
# std::vector<Value *>
# getTranslatedBuiltinOperands<BuiltinID::Builtin>(CallInst &BiCall,
#                                                  IRBuilder<> &IRB) {
#   return {.....};
# }
def generate_operand_specialization(builtin_name, desc):
  operand_exprs = []
  for op_desc in desc["TranslateInto"]["Operands"]:
    # "&" turns each operand expression into a Value *; presumably every
    # expression yields an lvalue reference (e.g. readValueFromBuiltinOp's
    # result) - confirm when adding new node kinds.
    operand_exprs.append(
        "&" + generate_expression_tree(builtin_name, op_desc, desc))
  header = ("template <>\n"
            "std::vector<Value *>\n"
            "getTranslatedBuiltinOperands<{}::{}>(CallInst &{}, "
            "IRBuilder<> &{}) {{\n").format(BUILTIN_ID, builtin_name,
                                              BUILTIN_VARIABLE, IRB_VARIABLE)
  body = "  return {" + ",\n          ".join(operand_exprs) + "};\n"
  return header + body + "}"

# Generates the getTranslatedBuiltinType specialization for every builtin,
# in description order.
def generate_return_type_function(builtin_descs):
  specializations = []
  for builtin_name, desc in builtin_descs.items():
    specializations.append(
        generate_return_type_specialization(builtin_name, desc))
  return INTERVAL_BETWEEN_DECLS.join(specializations)

# Generates the getTranslatedBuiltinOperands specialization for every builtin,
# in description order.
def generate_operand_function(builtin_descs):
  specializations = map(generate_operand_specialization,
                        builtin_descs.keys(), builtin_descs.values())
  return INTERVAL_BETWEEN_DECLS.join(specializations)

# Generates the section with the translation implementation: the
# getTranslatedBuiltinType and getTranslatedBuiltinOperands specializations,
# framed by the TRANSLATION_IMPL_SECTION preprocessor guard.
def generate_translation_impl_section(builtin_descs):
  body = [generate_return_type_function(builtin_descs),
          generate_operand_function(builtin_descs)]
  return INTERVAL_BETWEEN_DECLS.join(
      frame_section(body, TRANSLATION_IMPL_SECTION))

# Generates the whole output file text: the autogeneration banner followed by
# the builtin descriptions, translation descriptions and translation
# implementation sections. The description is validated before any section is
# generated.
def generate_file(whole_desc):
  builtin_descs = whole_desc["BuiltinDescriptions"]
  validate_description(builtin_descs, whole_desc["HelperStructures"])
  return INTERVAL_BETWEEN_DECLS.join((
      OUTPUT_HEADER,
      generate_builtin_descs_section(whole_desc),
      generate_translation_descs_section(builtin_descs),
      generate_translation_impl_section(builtin_descs)))

# Script entry point: reads the JSON description given by --desc and writes the
# generated C++ code to the file given by --output.
# Guarded so that importing this module (e.g. to test the generators) does not
# parse the command line or touch the file system.
if __name__ == "__main__":
  args = parser.parse_args()
  whole_desc = get_description_from_json(args.desc)
  output_str = generate_file(whole_desc)
  # Explicit encoding so the output does not depend on the build machine's
  # locale.
  with open(args.output, "w", encoding="utf-8") as output_file:
    output_file.write(output_str)
