# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Codegen for calling C++ methods from Java."""

from codegen import convert_type
from codegen import header_common
import common


def _return_type_cpp(java_type):
  """Returns the C++ return type of the user-implemented native function.

  A converted type is returned as-is, a primitive uses `to_cpp()`, and any
  other type is wrapped in `jni_zero::ScopedJavaLocalRef<...>`.
  """
  if converted_type := java_type.converted_type:
    return converted_type
  if java_type.is_primitive():
    return java_type.to_cpp()
  return f'jni_zero::ScopedJavaLocalRef<{java_type.to_cpp()}>'


def _param_type_cpp(java_type):
  """Returns the C++ parameter type of the user-implemented native function.

  A converted type gets a trailing '&' unless it is primitive or already a
  pointer. A primitive uses `to_cpp()`. Any other type is passed as
  `const jni_zero::JavaParamRef<...>&`.
  """
  if converted_type := java_type.converted_type:
    # Drop & when the type is obviously a pointer to avoid "const char *&".
    if not java_type.is_primitive() and not converted_type.endswith('*'):
      converted_type += '&'
    return converted_type

  ret = java_type.to_cpp()
  if java_type.is_primitive():
    return ret
  return f'const jni_zero::JavaParamRef<{ret}>&'


def _impl_forward_declaration(sb, native, params):
  """Emits a forward declaration of `JNI_<JavaClass>_<Method>`.

  The declared function takes `JNIEnv* env`, then `jcaller` for non-static
  natives, then one argument per entry in `params`. Its definition is
  supplied by the .cc file that includes the generated code.
  """
  sb('// Forward declaration. To be implemented by the including .cc file.\n')
  with sb.statement():
    name = f'JNI_{native.java_class.name}_{native.capitalized_name}'
    sb(f'static {_return_type_cpp(native.return_type)} {name}')
    with sb.param_list() as plist:
      plist.append('JNIEnv* env')
      if not native.static:
        plist.append('const jni_zero::JavaParamRef<jobject>& jcaller')
      plist.extend(f'{_param_type_cpp(p.java_type)} {p.cpp_name()}'
                   for p in params)


def _prep_param(sb, is_proxy, param):
  """Returns the snippet to use for the parameter.

  For a converted type, this also emits a local
  `<converted_type> <name>_converted = <from-JNI expression>` statement and
  returns that local's name. A primitive is passed through by name. Any
  other type is wrapped in `jni_zero::JavaParamRef<...>(env, ...)`.
  """
  orig_name = param.cpp_name()
  java_type = param.java_type

  if java_type.converted_type:
    ret = f'{param.name}_converted'
    with sb.statement():
      sb(f'{java_type.converted_type} {ret} = ')
      convert_type.from_jni_expression(sb, orig_name, java_type)
    return ret

  if java_type.is_primitive():
    return orig_name

  if is_proxy and java_type.to_cpp() != java_type.to_proxy().to_cpp():
    # The proxy's C++ type differs from the real one (e.g. jobject ->
    # jstring), so cast to the real type before wrapping.
    orig_name = f'static_cast<{java_type.to_cpp()}>({orig_name})'
  return f'jni_zero::JavaParamRef<{java_type.to_cpp()}>(env, {orig_name})'


def entry_point_declaration(sb, jni_mode, jni_obj, native, gen_jni_class):
  """Emits the signature (no body) of the C++ entry point for `native`.

  With multiplexing enabled, a proxy native is not the symbol JNI resolves.
  It gets a `JNI_ZERO_MUXED_ENTRYPOINT` function named
  `native.muxed_entry_point_name`, which the multiplexing switch table jumps
  to. Every other native gets a `JNI_ZERO_BOUNDARY_EXPORT` function named by
  `native.boundary_name_cpp()`.

  `jni_obj` is not used by this function.
  """
  # One condition decides both the symbol kind and whether the JNI jcaller
  # argument is present, so boundary (JNI) symbols always match the JNI
  # calling convention.
  is_muxed_entry_point = jni_mode.is_muxing and native.is_proxy
  if is_muxed_entry_point:
    # In this case, it's not the symbol that JNI resolves, but the one the
    # switch table jumps to.
    function_name = native.muxed_entry_point_name
    define = 'JNI_ZERO_MUXED_ENTRYPOINT'
  else:
    function_name = native.boundary_name_cpp(jni_mode,
                                             gen_jni_class=gen_jni_class)
    define = 'JNI_ZERO_BOUNDARY_EXPORT'
  return_type_cpp = native.entry_point_return_type.to_cpp()
  params = native.entry_point_params(jni_mode)
  sb(f'{define} {return_type_cpp} {function_name}')
  with sb.param_list() as plist:
    plist.append('JNIEnv* env')
    if not is_muxed_entry_point:
      # JNI always passes a jclass (static) or jobject (instance) argument to
      # boundary methods. Muxed entry points are not boundary (JNI) methods
      # and never use it, so it is omitted only for them.
      jtype = 'jclass' if native.static else 'jobject'
      plist.append(f'{jtype} jcaller')
    plist.extend(f'{p.java_type.to_cpp()} {p.cpp_name()}' for p in params)


def entry_point_method(sb, jni_mode, jni_obj, native, gen_jni_class):
  """Emits the full C++ entry point for `native`: declaration plus body.

  The body prepares each argument with `_prep_param`, then calls one of two
  targets:
    * `JNI_<Class>_<Method>`, forward-declared here for non-class methods, or
    * `<cpp_class>::<Method>` on the object whose pointer is the first param.
  Finally it returns the result converted to the entry point's JNI type.
  """
  params = native.params
  cpp_class = native.first_param_cpp_type
  if cpp_class:
    # The first param holds the C++ object pointer used as the call target;
    # it is not forwarded as an argument.
    params = params[1:]

  # Only non-class methods need to be forward-declared.
  if not cpp_class:
    _impl_forward_declaration(sb, native, params)
    sb('\n')

  entry_point_declaration(sb, jni_mode, jni_obj, native, gen_jni_class)

  entry_point_return_type = native.entry_point_return_type
  return_type = native.return_type
  with sb.block(after='\n'):
    # Conversion statements (if any) are emitted here, before the call.
    param_rvalues = [
        _prep_param(sb, native.is_proxy, param) for param in params
    ]

    with sb.statement():
      if not return_type.is_void():
        sb('auto _ret = ')
      if cpp_class:
        sb(f'reinterpret_cast<{cpp_class}*>({native.params[0].cpp_name()})'
           f'->{native.capitalized_name}')
      else:
        sb(f'JNI_{native.java_class.name}_{native.capitalized_name}')
      with sb.param_list() as plist:
        plist.append('env')
        if not native.static:
          plist.append('jni_zero::JavaParamRef<jobject>(env, jcaller)')
        plist.extend(param_rvalues)

    if return_type.is_void():
      return

    if not return_type.converted_type:
      if return_type.is_primitive():
        sb('return _ret;\n')
      else:
        # Use ReleaseLocal() to ensure we are not calling .Release() on a
        # global ref. https://crbug.com/40944912
        sb('return _ret.ReleaseLocal();\n')
      return

    with sb.statement():
      sb('jobject converted_ret = ')
      if native.needs_implicit_array_element_class_param:
        # Refer to the implicit class param by the same name the declaration
        # gives it (entry_point_declaration uses cpp_name()).
        clazz_param_name = native.proxy_params[-1].cpp_name()
        clazz_snippet = f'static_cast<jclass>({clazz_param_name})'
      else:
        clazz_snippet = None
      convert_type.to_jni_expression(sb,
                                     '_ret',
                                     return_type,
                                     clazz_snippet=clazz_snippet)
      sb('.Release()')

    with sb.statement():
      sb('return ')
      if entry_point_return_type.to_cpp() != 'jobject':
        sb(f'static_cast<{entry_point_return_type.to_cpp()}>(converted_ret)')
      else:
        sb('converted_ret')


def multiplexing_boundary_method(sb, muxed_aliases, gen_jni_class):
  """Emits the method called by JNI when multiplexing is enabled.

  The signature comes from `muxed_aliases[0].muxed_signature`; all aliases
  are presumably expected to share it. When the aliases carry switch numbers
  (`muxed_switch_num != -1`), the method takes an extra `jint switch_num` and
  dispatches on it to the matching muxed entry point. Otherwise it forwards
  straight to the single entry point.
  """
  native = muxed_aliases[0]
  sig = native.muxed_signature
  has_switch_num = native.muxed_switch_num != -1
  boundary_name_cpp = native.boundary_name_cpp(common.JniMode.MUXING,
                                               gen_jni_class=gen_jni_class)
  sb(f'JNI_ZERO_BOUNDARY_EXPORT {sig.return_type.to_cpp()} {boundary_name_cpp}')
  # Arguments forwarded to the muxed entry points. These omit jcaller and
  # switch_num, matching entry_point_declaration for muxed entry points.
  param_names = []
  with sb.param_list() as plist:
    plist += ['JNIEnv* env', 'jclass jcaller']
    if has_switch_num:
      plist.append('jint switch_num')
    param_names += ['env']
    for i, p in enumerate(sig.param_list):
      plist.append(f'{p.java_type.to_cpp()} p{i}')
      param_names.append(f'p{i}')

  param_call_str = ', '.join(param_names)
  with sb.block():
    if not has_switch_num:
      sb(f'return {native.muxed_entry_point_name}({param_call_str});\n')
    else:
      num_aliases = len(muxed_aliases)
      sb(f'JNI_ZERO_DCHECK(switch_num >= 0 && switch_num < {num_aliases});\n')
      sb('switch (switch_num)')
      with sb.block():
        for alias in muxed_aliases:
          sb(f'case {alias.muxed_switch_num}:\n')
          sb(f'  return {alias.muxed_entry_point_name}({param_call_str});\n')
        sb('default:\n')
        sb('  __builtin_unreachable();\n')