# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Codegen for FooJni.java files."""

import common
import java_types
import proxy


class _Context:
  """Shared state for generating one `<JavaClass>Jni` proxy implementation.

  Attributes:
    jni_obj: The parsed JNI object the proxy class is generated for.
    gen_jni_class: JavaClass of the GEN_JNI holder; added to the imports only
      when not in per-file mode.
    script_name: Name of the generating script.
    is_per_file: Whether natives are declared in the proxy class itself rather
      than being routed through gen_jni_class.
    interface_name: Dotted name of jni_obj's proxy interface.
    proxy_class: JavaClass for the generated `...Jni` class.
    type_resolver: Resolver rooted at proxy_class; its `imports` list drives
      both type-name rendering and the emitted import statements.
  """

  # jni_zero runtime classes appended to every generated file's imports.
  _RUNTIME_CLASS_PATHS = (
      'org/jni_zero/CheckDiscard',
      'org/jni_zero/JniTestInstanceHolder',
      'org/jni_zero/internal/NullUnmarked',
      'org/jni_zero/internal/Nullable',
  )

  def __init__(self, jni_obj, gen_jni_class, script_name, is_per_file):
    self.jni_obj = jni_obj
    self.gen_jni_class = gen_jni_class
    self.script_name = script_name
    self.is_per_file = is_per_file

    self.interface_name = jni_obj.proxy_interface.name_with_dots
    owner_path = jni_obj.java_class.full_name_with_slashes
    self.proxy_class = java_types.JavaClass(f'{owner_path}Jni')
    self.type_resolver = java_types.TypeResolver(self.proxy_class)

    # `+` produces a new list, so the append below never touches a list that
    # jni_obj might hold onto.
    imports = jni_obj.GetClassesToBeImported() + [
        java_types.JavaClass(path) for path in self._RUNTIME_CLASS_PATHS
    ]
    if not is_per_file:
      # Non-per-file proxies call into GEN_JNI, so it must be importable.
      imports.append(gen_jni_class)
    self.type_resolver.imports = imports


def _implicit_array_class_param(native, type_resolver):
  """Returns a Java class literal (`Foo.class`) for an array return's element.

  Args:
    native: Native method whose return type is an array (presumably only
      called when `needs_implicit_array_element_class_param` is set).
    type_resolver: Resolver used to render the element type's Java name.
  """
  element_type = native.return_type.to_array_element_type()
  return f'{element_type.to_java(type_resolver)}.class'


def _proxy_method(sb, ctx, native, method_fqn):
  """Emits an `@Override` method that forwards one proxy call to `method_fqn`.

  The generated body first emits Java `assert` statements (which only run
  when Java assertions are enabled): a non-zero check on the first parameter
  when `native.first_param_cpp_type` is set (presumably a native pointer), and
  a null check for every non-primitive parameter not marked nullable. It then
  calls `method_fqn` with all parameters, casting and returning the result for
  non-void methods.

  Args:
    sb: StringBuilder receiving the generated Java.
    ctx: _Context for the class being generated.
    native: The proxy native method to implement.
    method_fqn: Qualified name of the method the proxy delegates to.
  """
  resolver = ctx.type_resolver
  ret_java = native.return_type.to_java(resolver)
  declared_params = native.params.to_java_declaration(resolver)

  sb(f'\n@Override\npublic {ret_java} {native.name}({declared_params})')
  with sb.block():
    if native.first_param_cpp_type:
      sb(f'assert {native.params[0].name} != 0;\n')

    for param in native.params:
      jtype = param.java_type
      if jtype.is_primitive() or jtype.nullable:
        continue
      sb(f'assert {param.name} != null : "Parameter \\"{param.name}\\" was null. Add @Nullable to it?";\n')

    with sb.statement():
      if not native.return_type.is_void():
        sb(f'return ({ret_java}) ')
      sb(method_fqn)
      with sb.param_list() as call_args:
        call_args.extend(param.name for param in native.params)
        if native.needs_implicit_array_element_class_param:
          # The delegate also needs the array element's class literal.
          call_args.append(_implicit_array_class_param(native, resolver))


def _native_method(sb, ctx, native, name):
  """Emits a `private static native` declaration for a per-file native.

  Parameters are rendered through ctx.type_resolver. The return type is
  rendered with a bare `to_java()` call, without the resolver.

  Args:
    sb: StringBuilder receiving the generated Java.
    ctx: _Context for the class being generated.
    native: Native whose proxy_params / proxy_return_type are declared.
    name: Java method name to declare.
  """
  params_decl = native.proxy_params.to_java_declaration(ctx.type_resolver)
  ret_java = native.proxy_return_type.to_java()
  sb('private static native ' + f'{ret_java} {name}({params_decl});\n')


def _class_body(sb, ctx):
  """Emits the body of the generated `...Jni` class.

  The body starts with a nullable `sOverride` test holder, a `get()` factory
  that returns the override value when set (otherwise a new proxy instance),
  and `setInstanceForTesting()`. It then emits one proxy method per native.
  In per-file mode each proxy is preceded by its own `private static native`
  declaration named `native.per_file_name`. Otherwise the proxy forwards to
  `<gen_jni_class name>.<native.proxy_name>`.

  Args:
    sb: StringBuilder receiving the generated Java.
    ctx: _Context for the class being generated.
  """
  iface = ctx.interface_name
  impl_name = ctx.proxy_class.name
  sb(f"""\
private static @Nullable JniTestInstanceHolder sOverride;

public static {iface} get() {{
  JniTestInstanceHolder holder = sOverride;
  if (holder != null && holder.value != null) {{
    return ({iface}) holder.value;
  }}
  return new {impl_name}();
}}

public static void setInstanceForTesting({iface} impl) {{
  if (sOverride == null) {{
    sOverride = JniTestInstanceHolder.create();
  }}
  sOverride.value = impl;
}}

""")

  for native in ctx.jni_obj.proxy_natives:
    if not ctx.is_per_file:
      target = f'{ctx.gen_jni_class.name}.{native.proxy_name}'
    else:
      target = native.per_file_name
      # Per-file mode: the native lives in this class, so declare it here.
      _native_method(sb, ctx, native, target)
    _proxy_method(sb, ctx, native, target)


def _imports(sb, ctx):
  """Emits one sorted, de-duplicated `import` line per class in the resolver.

  Args:
    sb: StringBuilder receiving the generated Java.
    ctx: _Context whose type_resolver.imports lists the classes to import.
  """

  def _import_name(java_class):
    # The generated class goes through jarjar, so class prefixes are dropped
    # for everything except GEN_JNI itself.
    if java_class is not ctx.gen_jni_class:
      java_class = java_class.class_without_prefix
    # Nested classes are referenced as OuterClass.InnerClass to reduce the
    # risk of naming collisions, so import the outer class instead.
    if java_class.is_nested:
      java_class = java_class.get_outer_class()
    return java_class.full_name_with_dots

  unique_names = {_import_name(c) for c in ctx.type_resolver.imports}
  for dotted_name in sorted(unique_names):
    sb(f'import {dotted_name};\n')


def Generate(jni_mode, jni_obj, *, gen_jni_class, script_name):
  """Returns the Java source of the `...Jni` proxy class for `jni_obj`.

  Args:
    jni_mode: Mode object; only its `is_per_file` flag is read here.
    jni_obj: The parsed JNI object to generate the proxy for.
    gen_jni_class: JavaClass of the GEN_JNI holder (used when not per-file).
    script_name: Written into the "generated by" header comment.
  """
  ctx = _Context(jni_obj, gen_jni_class, script_name, jni_mode.is_per_file)

  sb = common.StringBuilder()
  package = jni_obj.java_class.class_without_prefix.package_with_dots
  sb('//\n'
     f'// This file was generated by {script_name}\n'
     '//\n'
     f'package {package};\n'
     '\n')
  _imports(sb, ctx)
  sb('\n')

  modifier = ''
  if jni_obj.proxy_visibility == 'public':
    modifier = 'public '
  if not ctx.is_per_file:
    sb('@CheckDiscard("crbug.com/993421")\n')
  sb('@NullUnmarked\n')
  sb(f'{modifier}class {ctx.proxy_class.name} implements {ctx.interface_name}')
  with sb.block():
    _class_body(sb, ctx)
  return sb.to_string()
