• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 The Chromium Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4"""Codegen for calling C++ methods from Java."""
5
6from codegen import convert_type
7from codegen import header_common
8import common
9
10
11def _return_type_cpp(java_type):
12  if converted_type := java_type.converted_type:
13    return converted_type
14  if java_type.is_primitive():
15    return java_type.to_cpp()
16  return f'jni_zero::ScopedJavaLocalRef<{java_type.to_cpp()}>'
17
18
19def _param_type_cpp(java_type):
20  if converted_type := java_type.converted_type:
21    # Drop & when the type is obviously a pointer to avoid "const char *&".
22    if not java_type.is_primitive() and not converted_type.endswith('*'):
23      converted_type += '&'
24    return converted_type
25
26  ret = java_type.to_cpp()
27  if java_type.is_primitive():
28    return ret
29  return f'const jni_zero::JavaParamRef<{ret}>&'
30
31
def _impl_forward_declaration(sb, native, params):
  """Emits a prototype for the hand-written JNI_<Class>_<Method> function.

  The definition is expected to come from the .cc file that includes the
  generated header. The generated entry point only needs this declaration
  so that it can call the function.

  Args:
    sb: Source builder that receives the generated C++.
    native: The native method whose implementation is being declared.
    params: Parameters to declare after `env` (and `jcaller` for
      non-static natives).
  """
  impl_name = f'JNI_{native.java_class.name}_{native.capitalized_name}'
  impl_return = _return_type_cpp(native.return_type)

  sb('// Forward declaration. To be implemented by the including .cc file.\n')
  with sb.statement():
    sb(f'static {impl_return} {impl_name}')
    with sb.param_list() as plist:
      plist.append('JNIEnv* env')
      if not native.static:
        # Instance natives also receive the Java object they were invoked on.
        plist.append('const jni_zero::JavaParamRef<jobject>& jcaller')
      plist.extend(
          [f'{_param_type_cpp(param.java_type)} {param.cpp_name()}'
           for param in params])
43
44
def _prep_param(sb, is_proxy, param):
  """Returns the snippet to use for the parameter.

  - Converted types are first materialized into a local variable named
    `<name>_converted`, whose declaration is emitted into `sb`. The local's
    name is returned.
  - Primitives are passed through unchanged.
  - Other object types are wrapped in a jni_zero::JavaParamRef. For proxy
    natives whose proxy signature uses a different JNI type (e.g. jobject
    where jstring is wanted), the value is static_cast first.
  """
  java_type = param.java_type
  jni_name = param.cpp_name()

  if java_type.converted_type:
    local_name = f'{param.name}_converted'
    with sb.statement():
      sb(f'{java_type.converted_type} {local_name} = ')
      convert_type.from_jni_expression(sb, jni_name, java_type)
    return local_name

  if java_type.is_primitive():
    return jni_name

  cpp_type = java_type.to_cpp()
  needs_cast = is_proxy and cpp_type != java_type.to_proxy().to_cpp()
  arg = f'static_cast<{cpp_type}>({jni_name})' if needs_cast else jni_name
  return f'jni_zero::JavaParamRef<{cpp_type}>(env, {arg})'
64
65
def entry_point_declaration(sb, jni_mode, jni_obj, native, gen_jni_class):
  """Emits the signature of a native's C++ entry point (no body).

  The entry point is one of two things:
  - For proxy natives when multiplexing is enabled: the function the
    multiplexing switch table jumps to, declared with
    JNI_ZERO_MUXED_ENTRYPOINT.
  - Otherwise: the symbol that JNI itself resolves, declared with
    JNI_ZERO_BOUNDARY_EXPORT.

  Args:
    sb: Source builder that receives the generated C++.
    jni_mode: Generation mode; decides muxing and the entry-point params.
    jni_obj: Unused here; accepted for signature parity with callers.
    native: The native method whose entry point is being declared.
    gen_jni_class: Forwarded to native.boundary_name_cpp().
  """
  is_muxed_entry_point = jni_mode.is_muxing and native.is_proxy
  if is_muxed_entry_point:
    # In this case, it's not the symbol that JNI resolves, but the one the
    # switch table jumps to.
    function_name = native.muxed_entry_point_name
    define = 'JNI_ZERO_MUXED_ENTRYPOINT'
  else:
    function_name = native.boundary_name_cpp(jni_mode,
                                             gen_jni_class=gen_jni_class)
    define = 'JNI_ZERO_BOUNDARY_EXPORT'
  return_type_cpp = native.entry_point_return_type.to_cpp()
  params = native.entry_point_params(jni_mode)
  sb(f'{define} {return_type_cpp} {function_name}')
  with sb.param_list() as plist:
    plist.append('JNIEnv* env')
    if not is_muxed_entry_point:
      # JNI always passes the jclass/jobject as the second argument to a
      # boundary method. That includes non-proxy natives when multiplexing
      # is enabled, so the parameter must be declared here or every
      # following parameter would be misaligned.
      # Muxed entry points are reached only from the switch table, which does
      # not forward it, so they omit it.
      jtype = 'jclass' if native.static else 'jobject'
      plist.append(f'{jtype} jcaller')
    plist.extend(f'{p.java_type.to_cpp()} {p.cpp_name()}' for p in params)
88
89
def entry_point_method(sb, jni_mode, jni_obj, native, gen_jni_class):
  """Emits the full definition of a native's C++ entry point.

  The entry point is the method called by JNI, or by multiplexing methods.
  The generated function:
  1. Prepares each incoming JNI argument (see _prep_param).
  2. Calls the implementation. That is either a method on the C++ object
     whose pointer arrives as the first parameter, or the free function
     JNI_<Class>_<Method>.
  3. Converts the result back to the entry point's JNI return type.

  Args:
    sb: Source builder that receives the generated C++.
    jni_mode: Passed through to entry_point_declaration().
    jni_obj: Passed through to entry_point_declaration().
    native: The native method to generate.
    gen_jni_class: Passed through to entry_point_declaration().
  """
  params = native.params
  cpp_class = native.first_param_cpp_type
  if cpp_class:
    # The first parameter is the receiver pointer (reinterpret_cast below),
    # not an argument to forward to the implementation.
    params = params[1:]

  # Only non-class methods need to be forward-declared.
  # (Class methods are presumably declared by the C++ class itself.)
  if not cpp_class:
    _impl_forward_declaration(sb, native, params)
    sb('\n')

  entry_point_declaration(sb, jni_mode, jni_obj, native, gen_jni_class)

  entry_point_return_type = native.entry_point_return_type
  return_type = native.return_type
  with sb.block(after='\n'):
    # Prepared before the call, because preparing converted params emits local
    # variable declarations into the body.
    param_rvalues = [
        _prep_param(sb, native.is_proxy, param) for param in params
    ]

    with sb.statement():
      if not return_type.is_void():
        sb('auto _ret = ')
      if cpp_class:
        sb(f'reinterpret_cast<{cpp_class}*>({native.params[0].cpp_name()})'
           f'->{native.capitalized_name}')
      else:
        sb(f'JNI_{native.java_class.name}_{native.capitalized_name}')
      with sb.param_list() as plist:
        plist.append('env')
        if not native.static:
          plist.append('jni_zero::JavaParamRef<jobject>(env, jcaller)')
        plist.extend(param_rvalues)

    if return_type.is_void():
      return

    # Unconverted results are returned directly: primitives by value, object
    # refs by releasing the local reference.
    if not return_type.converted_type:
      if return_type.is_primitive():
        sb('return _ret;\n')
      else:
        # Use ReleaseLocal() to ensure we are not calling .Release() on a
        # global ref. https://crbug.com/40944912
        sb('return _ret.ReleaseLocal();\n')
      return

    # Converted results are turned back into a raw jobject first.
    with sb.statement():
      sb('jobject converted_ret = ')
      if native.needs_implicit_array_element_class_param:
        # The element class is passed as the last proxy param.
        # NOTE(review): this uses `.name`, while entry_point_declaration
        # declares params via `cpp_name()`. Confirm the two always agree for
        # this implicit param.
        clazz_snippet = f'static_cast<jclass>({native.proxy_params[-1].name})'
      else:
        clazz_snippet = None
      convert_type.to_jni_expression(sb,
                                     '_ret',
                                     return_type,
                                     clazz_snippet=clazz_snippet)
      # Presumably to_jni_expression yields a scoped ref; .Release() extracts
      # the raw jobject. Confirm against convert_type.
      sb('.Release()')

    # converted_ret is declared as jobject; narrow it when the entry point
    # returns a more specific JNI type.
    with sb.statement():
      sb('return ')
      if entry_point_return_type.to_cpp() != 'jobject':
        sb(f'static_cast<{entry_point_return_type.to_cpp()}>(converted_ret)')
      else:
        sb('converted_ret')
155
156
def multiplexing_boundary_method(sb, muxed_aliases, gen_jni_class):
  """Emits the exported JNI function shared by a group of muxed natives.

  All aliases share the first alias's muxed signature.
  - If the group uses a switch number, the generated function dispatches on
    `switch_num` to each alias's muxed entry point.
  - Otherwise it forwards straight to the first alias's entry point.
  Only `env` and the positional `p<i>` arguments are forwarded; `jcaller` is
  accepted but dropped.

  Args:
    sb: Source builder that receives the generated C++.
    muxed_aliases: Natives sharing this boundary; the first is representative.
    gen_jni_class: Forwarded to boundary_name_cpp().
  """
  first = muxed_aliases[0]
  signature = first.muxed_signature
  uses_switch = first.muxed_switch_num != -1
  export_name = first.boundary_name_cpp(common.JniMode.MUXING,
                                        gen_jni_class=gen_jni_class)
  ret_cpp = signature.return_type.to_cpp()
  sb(f'JNI_ZERO_BOUNDARY_EXPORT {ret_cpp} {export_name}')

  forwarded = ['env']
  with sb.param_list() as plist:
    plist += ['JNIEnv* env', 'jclass jcaller']
    if uses_switch:
      plist.append('jint switch_num')
    for index, param in enumerate(signature.param_list):
      arg_name = f'p{index}'
      plist.append(f'{param.java_type.to_cpp()} {arg_name}')
      forwarded.append(arg_name)

  call_args = ', '.join(forwarded)
  with sb.block():
    if uses_switch:
      alias_count = len(muxed_aliases)
      sb(f'JNI_ZERO_DCHECK(switch_num >= 0 && switch_num < {alias_count});\n')
      sb('switch (switch_num)')
      with sb.block():
        for alias in muxed_aliases:
          sb(f'case {alias.muxed_switch_num}:\n')
          sb(f'  return {alias.muxed_entry_point_name}({call_args});\n')
        sb('default:\n')
        sb('  __builtin_unreachable();\n')
    else:
      sb(f'return {first.muxed_entry_point_name}({call_args});\n')
189