// Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //! Code generation for the `call_method!` series of proc-macros. This module is used by creating a //! [`MethodCall`] instance and calling [`MethodCall::generate`]. The generated code will vary //! for each macro based on the contained [`Receiver`] value. If there is a detected issue with the //! [`MethodCall`] provided, then `generate` will return a [`CodegenError`] with information on //! what is wrong. This error can be converted into a [`syn::Error`] for reporting any failures to //! the user. use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned, TokenStreamExt}; use syn::{parse_quote, spanned::Spanned, Expr, Ident, ItemStatic, LitStr, Type}; use crate::type_parser::{JavaType, MethodSig, NonArray, Primitive, ReturnType}; /// The errors that can be encountered during codegen. Used in [`CodegenError`]. #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum ErrorKind { InvalidArgsLength { expected: usize, found: usize }, ConstructorRetValShouldBeVoid, InvalidTypeSignature, } /// An error encountered during codegen with span information. Can be converted to [`syn::Error`]. /// using [`From`]. #[derive(Clone, Debug)] pub struct CodegenError(pub Span, pub ErrorKind); impl From for syn::Error { fn from(CodegenError(span, kind): CodegenError) -> syn::Error { use ErrorKind::*; match kind { InvalidArgsLength { expected, found } => { syn::Error::new(span, format!("The number of args does not match the type signature provided: expected={expected}, found={found}")) } ConstructorRetValShouldBeVoid => { syn::Error::new(span, "Return type should be `void` (`V`) for constructor methods") } InvalidTypeSignature => syn::Error::new(span, "Failed to parse type signature"), } } } /// Codegen can fail with [`CodegenError`]. pub type CodegenResult = Result; /// Describes a method that will be generated. Create one with [`MethodCall::new`] and generate /// code with [`MethodCall::generate`]. This should be given AST nodes from the macro input so that /// errors are associated properly to the input span. pub struct MethodCall { env: Expr, method: MethodInfo, receiver: Receiver, arg_exprs: Vec, } impl MethodCall { /// Create a new MethodCall instance pub fn new(env: Expr, method: MethodInfo, receiver: Receiver, arg_exprs: Vec) -> Self { Self { env, method, receiver, arg_exprs, } } /// Generate code to call the described method. pub fn generate(&self) -> CodegenResult { // Needs to be threaded manually to other methods since self-referential structs can't // exist. let sig = self.method.sig()?; let args = self.generate_args(&sig)?; let method_call = self .receiver .generate_call(&self.env, &self.method, &sig, &args)?; // Wrap the generated code in a closure so that we can access the outer scope but any // variables we define aren't accessible by the outer scope. There is small hygiene issue // where the arg exprs have our `env` variable in scope. If this becomes an issue we can // refactor these exprs to be passed as closure parameters instead. Ok(quote! { (|| { #method_call })() }) } /// Generate the `&[jni::jvalue]` arguments slice that will be passed to the `jni` method call. /// This validates the argument count and types. fn generate_args(&self, sig: &MethodSig<'_>) -> CodegenResult { // Safety: must check that arg count matches the signature if self.arg_exprs.len() != sig.args.len() { return Err(CodegenError( self.method.sig.span(), ErrorKind::InvalidArgsLength { expected: sig.args.len(), found: self.arg_exprs.len(), }, )); } // Create each `jvalue` expression let type_expr_pairs = core::iter::zip(sig.args.iter().copied(), self.arg_exprs.iter()); let jvalues = type_expr_pairs.map(|(ty, expr)| generate_jvalue(ty, expr)); // Put the `jvalue` expressions in a slice. Ok(parse_quote! { &[#(#jvalues),*] }) } } /// The receiver of the method call and the type of the method. pub enum Receiver { /// A constructor. Constructor, /// A static method. Static, /// An instance method. The `Expr` here is the `this` object. Instance(Expr), } impl Receiver { /// Generate the code that performs the JNI call. fn generate_call( &self, env: &Expr, method_info: &MethodInfo, sig: &MethodSig<'_>, args: &Expr, ) -> CodegenResult { // Constructors are void methods. Validate this fact. if matches!(*self, Receiver::Constructor) && !sig.ret.is_void() { return Err(CodegenError( method_info.sig.span(), ErrorKind::ConstructorRetValShouldBeVoid, )); } // The static item containing the `pourover::[Static]MethodDesc`. let method_desc = self.generate_method_desc(method_info); // The `jni::signature::ReturnType` that the `jni` crate uses to perform the correct native // call. let return_type = return_type_from_sig(sig.ret); // A conversion expression to convert from `jni::object::JValueOwned` to the actual return // type. We have this information from the parsed method signature whereas the `jni` crate // only knows this at runtime. let conversion = return_value_conversion_from_sig(sig.ret); // This preamble is used to evaluate all the client-provided expressions outside of the // `unsafe` block. This is the same for all receiver kinds. let mut method_call = quote! { #method_desc let env: &mut ::jni::JNIEnv = #env; let method_id = ::jni::descriptors::Desc::lookup(&METHOD_DESC, env)?; let args: &[::jni::sys::jvalue] = #args; }; // Generate the unsafe JNI call. // // Safety: `args` contains the arguments to this method. The type signature of this // method is `#sig`. // // `args` must adhere to the following: // - `args.len()` must match the number of arguments given in the type signature. // - The union value of each arg in `args` must match the type specified in the type // signature. // // These conditions are upheld by this proc macro and a compile error will be caused if // they are broken. No user-provided code is executed within the `unsafe` block. method_call.append_all(match self { Self::Constructor => quote! { unsafe { env.new_object_unchecked( METHOD_DESC.cls(), method_id, args, ) } }, Self::Static => quote! { unsafe { env.call_static_method_unchecked( METHOD_DESC.cls(), method_id, #return_type, args, ) }#conversion }, Self::Instance(this) => quote! { let this_obj: &JObject = #this; unsafe { env.call_method_unchecked( this_obj, method_id, #return_type, args, ) }#conversion }, }); Ok(method_call) } fn generate_method_desc(&self, MethodInfo { cls, name, sig, .. }: &MethodInfo) -> ItemStatic { match self { Self::Constructor => parse_quote! { static METHOD_DESC: ::pourover::desc::MethodDesc = (#cls).constructor(#sig); }, Self::Static => parse_quote! { static METHOD_DESC: ::pourover::desc::StaticMethodDesc = (#cls).static_method(#name, #sig); }, Self::Instance(_) => parse_quote! { static METHOD_DESC: ::pourover::desc::MethodDesc = (#cls).method(#name, #sig); }, } } } /// Information about the method being called pub struct MethodInfo { cls: Expr, name: Expr, sig: LitStr, /// Derived from `sig.value()`. This string must be stored in the struct so that we can return /// a `MethodSig` instance that references it from `MethodInfo::sig()`. sig_str: String, } impl MethodInfo { pub fn new(cls: Expr, name: Expr, sig: LitStr) -> Self { let sig_str = sig.value(); Self { cls, name, sig, sig_str, } } /// Parse the type signature from `sig`. Will return a [`CodegenError`] if the signature cannot /// be parsed. fn sig(&self) -> CodegenResult> { MethodSig::try_from_str(&self.sig_str) .ok_or_else(|| CodegenError(self.sig.span(), ErrorKind::InvalidTypeSignature)) } } /// Generate a `jni::sys::jvalue` instance given a Java type and a Rust value. /// /// Safety: The generated `jvalue` must match the given type `ty`. fn generate_jvalue(ty: JavaType<'_>, expr: &Expr) -> TokenStream { // The `jvalue` field to inhabit let union_field: Ident; // The expected input type let type_name: Type; // Whether we need to call `JObject::as_raw()` on the input type let needs_as_raw: bool; // Fill the above values based the type signature. match ty { JavaType::Array { depth, ty } => { union_field = parse_quote![l]; if let NonArray::Primitive(p) = ty { if depth.get() == 1 { let prim_type = prim_to_sys_type(p); type_name = parse_quote![::jni::objects::JPrimitiveArray<'_, #prim_type>] } else { type_name = parse_quote![&::jni::objects::JObjectArray<'_>]; } } else { type_name = parse_quote![&::jni::objects::JObjectArray<'_>]; } needs_as_raw = true; } JavaType::NonArray(NonArray::Object { cls }) => { union_field = parse_quote![l]; type_name = match cls { "java/lang/String" => parse_quote![&::jni::objects::JString<'_>], "java/util/List" => parse_quote![&::jni::objects::JList<'_>], "java/util/Map" => parse_quote![&::jni::objects::JMap<'_>], _ => parse_quote![&::jni::objects::JObject<'_>], }; needs_as_raw = true; } JavaType::NonArray(NonArray::Primitive(p)) => { union_field = prim_to_union_field(p); type_name = prim_to_sys_type(p); needs_as_raw = false; } } // The as_raw() tokens if required. let as_raw = if needs_as_raw { quote! { .as_raw() } } else { quote![] }; // Create the `jvalue` expression. This uses `identity` to produce nice type error messages. quote_spanned! { expr.span() => ::jni::sys::jvalue { #union_field: ::core::convert::identity::<#type_name>(#expr) #as_raw } } } /// Get a `::jni::signature::ReturnType` expression from a [`crate::type_parser::ReturnType`]. This /// value is passed to the `jni` crate so that it knows which JNI method to call. fn return_type_from_sig(ret: ReturnType<'_>) -> Expr { let prim_type = |prim| parse_quote![::jni::signature::ReturnType::Primitive(::jni::signature::Primitive::#prim)]; use crate::type_parser::{JavaType::*, NonArray::*, Primitive::*}; match ret { ReturnType::Void => prim_type(quote![Void]), ReturnType::Returns(NonArray(Primitive(Boolean))) => prim_type(quote![Boolean]), ReturnType::Returns(NonArray(Primitive(Byte))) => prim_type(quote![Byte]), ReturnType::Returns(NonArray(Primitive(Char))) => prim_type(quote![Char]), ReturnType::Returns(NonArray(Primitive(Double))) => prim_type(quote![Double]), ReturnType::Returns(NonArray(Primitive(Float))) => prim_type(quote![Float]), ReturnType::Returns(NonArray(Primitive(Int))) => prim_type(quote![Int]), ReturnType::Returns(NonArray(Primitive(Long))) => prim_type(quote![Long]), ReturnType::Returns(NonArray(Primitive(Short))) => prim_type(quote![Short]), ReturnType::Returns(NonArray(Object { .. })) => { parse_quote![::jni::signature::ReturnType::Object] } ReturnType::Returns(Array { .. }) => parse_quote![::jni::signature::ReturnType::Array], } } /// A postfix call on a `jni::objects::JValueOwned` instance to convert it to the type specified by /// `ret`. Since we have this information from the type signature we can perform this conversion /// in the macro. fn return_value_conversion_from_sig(ret: ReturnType<'_>) -> TokenStream { use crate::type_parser::{JavaType::*, NonArray::*}; match ret { ReturnType::Void => quote! { .and_then(::jni::objects::JValueOwned::v) }, ReturnType::Returns(NonArray(Primitive(p))) => { let prim = prim_to_union_field(p); quote! { .and_then(::jni::objects::JValueOwned::#prim) } } ReturnType::Returns(NonArray(Object { cls })) => { let mut conversion = quote! { .and_then(::jni::objects::JValueOwned::l) }; match cls { "java/lang/String" => { conversion.append_all(quote! { .map(::jni::objects::JString::from) }); } "java/util/List" => { conversion.append_all(quote! { .map(::jni::objects::JList::from) }); } "java/util/Map" => { conversion.append_all(quote! { .map(::jni::objects::JMap::from) }); } _ => { // Already a JObject, so we are good here } } conversion } ReturnType::Returns(Array { depth, ty: Primitive(p), }) if depth.get() == 1 => { let sys_type = prim_to_sys_type(p); quote! { .and_then(::jni::objects::JValueOwned::l) .map(::jni::objects::JPrimitiveArray::<#sys_type>::from) } } ReturnType::Returns(Array { .. }) => quote! { .and_then(::jni::objects::JValueOwned::l) .map(::jni::objects::JObjectArray::from) }, } } /// From a [`Primitive`], this gets the `jni::sys::jvalue` union field name for that type. This is /// also the `jni::objects::JValueGen` getter name. fn prim_to_union_field(p: Primitive) -> Ident { quote::format_ident!("{}", p.as_char().to_ascii_lowercase()) } /// From a [`Primitive`], this gets the matching `jvalue::sys` type. fn prim_to_sys_type(p: Primitive) -> Type { match p { Primitive::Boolean => parse_quote![::jni::sys::jboolean], Primitive::Byte => parse_quote![::jni::sys::jbyte], Primitive::Char => parse_quote![::jni::sys::jchar], Primitive::Double => parse_quote![::jni::sys::jdouble], Primitive::Float => parse_quote![::jni::sys::jfloat], Primitive::Int => parse_quote![::jni::sys::jint], Primitive::Long => parse_quote![::jni::sys::jlong], Primitive::Short => parse_quote![::jni::sys::jshort], } } #[cfg(test)] #[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)] mod tests { use super::*; use crate::test_util::contains_ident; use quote::ToTokens; use syn::parse_quote; fn example_method_call() -> MethodCall { MethodCall::new( parse_quote![&mut env], MethodInfo::new( parse_quote![&FOO_CLS], parse_quote!["example"], parse_quote!["(II)I"], ), Receiver::Instance(parse_quote![&foo]), vec![parse_quote![123], parse_quote![2 + 3]], ) } #[test] fn args_are_counted() { let mut call = example_method_call(); call.arg_exprs.push(parse_quote![too_many]); let CodegenError(_span, kind) = call.generate().unwrap_err(); assert_eq!( ErrorKind::InvalidArgsLength { expected: 2, found: 3 }, kind ); } #[test] fn constructor_return_type_is_void() { let mut call = example_method_call(); call.receiver = Receiver::Constructor; let CodegenError(_span, kind) = call.generate().unwrap_err(); assert_eq!(ErrorKind::ConstructorRetValShouldBeVoid, kind); } #[test] fn invalid_type_sig_is_error() { let mut call = example_method_call(); call.method.sig = parse_quote!["L"]; call.method.sig_str = call.method.sig.value(); let CodegenError(_span, kind) = call.generate().unwrap_err(); assert_eq!(ErrorKind::InvalidTypeSignature, kind); } #[test] fn jni_types_are_used_for_stdlib_classes_input() { let types = [ ("Ljava/lang/String;", "JString"), ("Ljava/util/Map;", "JMap"), ("Ljava/util/List;", "JList"), ("[Ljava/lang/String;", "JObjectArray"), ("[[I", "JObjectArray"), ("[I", "JPrimitiveArray"), ("Lcom/example/MyObject;", "JObject"), ("Z", "jboolean"), ("C", "jchar"), ("B", "jbyte"), ("S", "jshort"), ("I", "jint"), ("J", "jlong"), ("F", "jfloat"), ("D", "jdouble"), ]; for (desc, jni_type) in types { let jt = JavaType::try_from_str(desc).unwrap(); let expr = parse_quote![some_value]; let jvalue = generate_jvalue(jt, &expr); assert!( contains_ident(jvalue, jni_type), "desc: {desc}, jni_type: {jni_type}" ); } } #[test] fn jni_types_are_used_for_stdlib_classes_output() { let types = [ ("Ljava/lang/String;", "JString"), ("Ljava/util/Map;", "JMap"), ("Ljava/util/List;", "JList"), ("[Ljava/lang/String;", "JObjectArray"), ("[[I", "JObjectArray"), ("[I", "JPrimitiveArray"), ]; for (desc, jni_type) in types { let rt = ReturnType::try_from_str(desc).unwrap(); let conversion = return_value_conversion_from_sig(rt); assert!( contains_ident(conversion, jni_type), "desc: {desc}, jni_type: {jni_type}" ); } } #[test] fn return_type_passed_to_jni_is_correct() { let types = [ ("Ljava/lang/String;", "Object"), ("Ljava/util/Map;", "Object"), ("Ljava/util/List;", "Object"), ("[Ljava/lang/String;", "Array"), ("[[I", "Array"), ("[I", "Array"), ("V", "Void"), ("Z", "Boolean"), ("C", "Char"), ("B", "Byte"), ("S", "Short"), ("I", "Int"), ("J", "Long"), ("F", "Float"), ("D", "Double"), ]; for (desc, return_type) in types { let rt = ReturnType::try_from_str(desc).unwrap(); let expr = return_type_from_sig(rt).into_token_stream(); assert!( contains_ident(expr, return_type), "desc: {desc}, return_type: {return_type}" ); } } #[test] fn method_desc_is_correct() { let mut call = example_method_call(); call.method.sig = parse_quote!["(II)V"]; call.method.sig_str = call.method.sig.value(); let tests = [ (Receiver::Constructor, "constructor"), (Receiver::Static, "static_method"), (Receiver::Instance(parse_quote![this_value]), "method"), ]; for (receiver, method_ident) in tests { let desc = receiver.generate_method_desc(&call.method); let rhs = desc.expr.into_token_stream(); assert!(contains_ident(rhs, method_ident), "method: {method_ident}"); } } #[test] fn jni_call_is_correct() { let mut call = example_method_call(); call.method.sig = parse_quote!["(II)V"]; call.method.sig_str = call.method.sig.value(); let sig = call.method.sig().unwrap(); let args = parse_quote![test_stub]; let tests = [ (Receiver::Constructor, "new_object_unchecked"), (Receiver::Static, "call_static_method_unchecked"), ( Receiver::Instance(parse_quote![this_value]), "call_method_unchecked", ), ]; for (receiver, method_ident) in tests { let call = receiver .generate_call(&call.env, &call.method, &sig, &args) .unwrap(); assert!(contains_ident(call, method_ident), "method: {method_ident}"); } } }