# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Codegen for GEN_JNI.java."""

import common
import java_types


def _stub_for_missing_native(sb, native):
  """Emits a Java stub for a proxy native that has no native implementation.

  The stub is declared with the same `public static` proxy signature that a
  forwarding method would use, but its body unconditionally throws a
  RuntimeException.

  Args:
    sb: StringBuilder receiving the generated Java source.
    native: The proxy native method to stub out.
  """
  return_type = native.proxy_return_type.to_java()
  declarations = [param.to_java_declaration() for param in native.proxy_params]
  sb(f'public static {return_type} {native.proxy_name}')
  with sb.param_list() as signature:
    signature.extend(declarations)
  with sb.block():
    sb('throw new RuntimeException("Native method not present");\n')


def _forwarding_method(sb, jni_mode, native, target_class_name):
  """Emits a GEN_JNI method that forwards a proxy call to `target_class_name`.

  The generated method has the proxy's signature and a single statement that
  calls into the target class, prefixed with `return` unless the proxy return
  type is void.

  - Muxing mode: calls `native.muxed_name` with the switch number (omitted
    when it is -1) followed by the names of `native.muxed_params`.
  - Otherwise: calls `native.hashed_name`, passing the proxy parameters
    through by name.

  Args:
    sb: StringBuilder receiving the generated Java source.
    jni_mode: Active JNI mode; only `is_muxing` is consulted here.
    native: The proxy native method being forwarded.
    target_class_name: Dotted name of the class whose methods are called.
  """
  return_type = native.proxy_return_type.to_java()
  declarations = [param.to_java_declaration() for param in native.proxy_params]
  sb(f'public static {return_type} {native.proxy_name}')
  sb.param_list(declarations)
  with sb.block(), sb.statement():
    if not native.proxy_return_type.is_void():
      sb('return ')
    if not jni_mode.is_muxing:
      sb(f'{target_class_name}.{native.hashed_name}')
      sb.param_list([param.name for param in native.proxy_params])
    else:
      sb(f'{target_class_name}.{native.muxed_name}')
      with sb.param_list() as call_args:
        # A switch number of -1 means no selector argument is passed.
        if native.muxed_switch_num != -1:
          call_args.append(str(native.muxed_switch_num))
        call_args.extend(param.name for param in native.muxed_params)


def _native_method(sb, jni_mode, native):
  """Emits one `public static native` declaration for a boundary native.

  A leading comment records an alternate name for the method:
  - hashing mode: the original `Class#method` Java name;
  - neither hashing nor muxing: the `Java_J_N_<hashed_name>` symbol;
  - muxing only: no comment.

  In muxing mode the parameters are positional (`p0`, `p1`, ...) typed from
  `native.muxed_params`, preceded by `int _switchNum` unless the switch number
  is -1. Otherwise the proxy parameter declarations are used verbatim.

  Args:
    sb: StringBuilder receiving the generated Java source.
    jni_mode: Active JNI mode (`is_hashing` / `is_muxing`).
    native: The proxy native method being declared.
  """
  boundary = native.boundary_name(jni_mode)
  if jni_mode.is_hashing:
    sb('// Original name: '
       f'{native.java_class.full_name_with_dots}#{native.name}\n')
  elif not jni_mode.is_muxing:
    sb(f'// Hashed name: Java_J_N_{native.hashed_name}\n')

  if jni_mode.is_muxing:
    declared = [
        f'{param.java_type.to_java()} p{index}'
        for index, param in enumerate(native.muxed_params)
    ]
    if native.muxed_switch_num != -1:
      declared.insert(0, 'int _switchNum')
  else:
    declared = [param.to_java_declaration() for param in native.proxy_params]

  return_type = native.proxy_return_type.to_java()
  with sb.statement():
    sb(f'public static native {return_type} {boundary}')
    with sb.param_list() as signature:
      signature.extend(declared)


def generate_forwarding(jni_mode, script_name, full_gen_jni_class,
                        short_gen_jni_class, present_proxy_natives,
                        absent_proxy_natives):
  """GEN_JNI.java that forwards calls to native methods on N.java.

  Args:
    jni_mode: Active JNI mode, passed through to each forwarding method.
    script_name: Name of the generating script, recorded in the file header.
    full_gen_jni_class: Class whose package and simple name the output
      declares.
    short_gen_jni_class: Class whose dotted name the forwarding methods call.
    present_proxy_natives: Natives that get a forwarding method.
    absent_proxy_natives: Natives that get a stub that throws at runtime.

  Returns:
    The generated Java source as a string.
  """
  header = f"""\
//
// This file was generated by {script_name}
//

package {full_gen_jni_class.package_with_dots};

public class {full_gen_jni_class.name}"""
  target_class = short_gen_jni_class.full_name_with_dots

  sb = common.StringBuilder()
  sb(header)
  with sb.block():
    for present in present_proxy_natives:
      _forwarding_method(sb, jni_mode, present, target_class)

    for absent in absent_proxy_natives:
      _stub_for_missing_native(sb, absent)
  return sb.to_string()


def generate_impl(jni_mode,
                  script_name,
                  gen_jni_class,
                  boundary_proxy_natives,
                  absent_proxy_natives,
                  *,
                  whole_hash=None,
                  priority_hash=None):
  """GEN_JNI.java (or N.java) that has all "public static native" methods.

  Args:
    jni_mode: Active JNI mode; decides whether the hash constants are emitted
      and how each native declaration is rendered.
    script_name: Name of the generating script, recorded in the file header.
    gen_jni_class: Class whose package and simple name the output declares.
    boundary_proxy_natives: Natives emitted as `public static native` methods.
    absent_proxy_natives: Natives emitted as stubs that throw at runtime.
    whole_hash: Value of the WHOLE_HASH constant; required when muxing.
    priority_hash: Value of the PRIORITY_HASH constant; required when muxing.

  Returns:
    The generated Java source as a string.

  Raises:
    ValueError: If muxing is enabled and either hash was not supplied.
  """
  # The hashes are spliced in as Java `long` literals. Letting None through
  # would generate `NoneL`, which is only caught much later by javac.
  if jni_mode.is_muxing and (whole_hash is None or priority_hash is None):
    raise ValueError(
        'whole_hash and priority_hash are required when JNI muxing is enabled')

  sb = common.StringBuilder()
  sb(f"""\
//
// This file was generated by {script_name}
//

package {gen_jni_class.package_with_dots};

public class {gen_jni_class.name}""")
  with sb.block():
    if jni_mode.is_muxing:
      sb(f"""\
public static final long WHOLE_HASH = {whole_hash}L;
public static final long PRIORITY_HASH = {priority_hash}L;

""")

    for native in boundary_proxy_natives:
      _native_method(sb, jni_mode, native)
    for native in absent_proxy_natives:
      _stub_for_missing_native(sb, native)
  return sb.to_string()
