• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "generate_rust.h"
18 
19 #include <android-base/stringprintf.h>
20 #include <android-base/strings.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string.h>
24 
25 #include <map>
26 #include <memory>
27 #include <sstream>
28 
29 #include "aidl_to_cpp_common.h"
30 #include "aidl_to_rust.h"
31 #include "code_writer.h"
32 #include "comments.h"
33 #include "logging.h"
34 
35 using android::base::Join;
36 using std::ostringstream;
37 using std::shared_ptr;
38 using std::string;
39 using std::unique_ptr;
40 using std::vector;
41 
42 namespace android {
43 namespace aidl {
44 namespace rust {
45 
// Prefix prepended to every AIDL argument name in the generated Rust code
// (e.g. `foo` becomes `_arg_foo`), keeping argument names apart from the
// `_aidl_*` locals the generator emits.
constexpr char kArgumentPrefix[] = "_arg_";
// Names of the compiler-added meta methods that receive special handling
// (default trait bodies and client-side caching) when they are not user-defined.
constexpr char kGetInterfaceVersion[] = "getInterfaceVersion";
constexpr char kGetInterfaceHash[] = "getInterfaceHash";
49 
GenerateMangledAlias(CodeWriter & out,const AidlDefinedType * type)50 void GenerateMangledAlias(CodeWriter& out, const AidlDefinedType* type) {
51   ostringstream alias;
52   for (const auto& component : type->GetSplitPackage()) {
53     alias << "_" << component.size() << "_" << component;
54   }
55   alias << "_" << type->GetName().size() << "_" << type->GetName();
56   out << "pub(crate) mod mangled { pub use super::" << type->GetName() << " as " << alias.str()
57       << "; }\n";
58 }
59 
BuildArg(const AidlArgument & arg,const AidlTypenames & typenames)60 string BuildArg(const AidlArgument& arg, const AidlTypenames& typenames) {
61   // We pass in parameters that are not primitives by const reference.
62   // Arrays get passed in as slices, which is handled in RustNameOf.
63   auto arg_mode = ArgumentStorageMode(arg, typenames);
64   auto arg_type = RustNameOf(arg.GetType(), typenames, arg_mode);
65   return kArgumentPrefix + arg.GetName() + ": " + arg_type;
66 }
67 
BuildMethod(const AidlMethod & method,const AidlTypenames & typenames)68 string BuildMethod(const AidlMethod& method, const AidlTypenames& typenames) {
69   auto method_type = RustNameOf(method.GetType(), typenames, StorageMode::VALUE);
70   auto return_type = string{"binder::public_api::Result<"} + method_type + ">";
71   string parameters = "&self";
72   for (const std::unique_ptr<AidlArgument>& arg : method.GetArguments()) {
73     parameters += ", ";
74     parameters += BuildArg(*arg, typenames);
75   }
76   return "fn " + method.GetName() + "(" + parameters + ") -> " + return_type;
77 }
78 
GenerateClientMethod(CodeWriter & out,const AidlInterface & iface,const AidlMethod & method,const AidlTypenames & typenames,const Options & options,const std::string & trait_name)79 void GenerateClientMethod(CodeWriter& out, const AidlInterface& iface, const AidlMethod& method,
80                           const AidlTypenames& typenames, const Options& options,
81                           const std::string& trait_name) {
82   // Generate the method
83   out << BuildMethod(method, typenames) << " {\n";
84   out.Indent();
85 
86   if (!method.IsUserDefined()) {
87     if (method.GetName() == kGetInterfaceVersion && options.Version() > 0) {
88       // Check if the version is in the cache
89       out << "let _aidl_version = "
90              "self.cached_version.load(std::sync::atomic::Ordering::Relaxed);\n";
91       out << "if _aidl_version != -1 { return Ok(_aidl_version); }\n";
92     }
93 
94     if (method.GetName() == kGetInterfaceHash && !options.Hash().empty()) {
95       out << "{\n";
96       out << "  let _aidl_hash_lock = self.cached_hash.lock().unwrap();\n";
97       out << "  if let Some(ref _aidl_hash) = *_aidl_hash_lock {\n";
98       out << "    return Ok(_aidl_hash.clone());\n";
99       out << "  }\n";
100       out << "}\n";
101     }
102   }
103 
104   // Call transact()
105   vector<string> flags;
106   if (method.IsOneway()) flags.push_back("binder::FLAG_ONEWAY");
107   if (iface.IsSensitiveData()) flags.push_back("binder::FLAG_CLEAR_BUF");
108   flags.push_back("binder::FLAG_PRIVATE_LOCAL");
109 
110   string transact_flags = flags.empty() ? "0" : Join(flags, " | ");
111   out << "let _aidl_reply = self.binder.transact("
112       << "transactions::" << method.GetName() << ", " << transact_flags << ", |_aidl_data| {\n";
113   out.Indent();
114 
115   if (iface.IsSensitiveData()) {
116     out << "_aidl_data.mark_sensitive();\n";
117   }
118 
119   // Arguments
120   for (const std::unique_ptr<AidlArgument>& arg : method.GetArguments()) {
121     auto arg_name = kArgumentPrefix + arg->GetName();
122     if (arg->IsIn()) {
123       // If the argument is already a reference, don't reference it again
124       // (unless we turned it into an Option<&T>)
125       auto ref_mode = ArgumentReferenceMode(*arg, typenames);
126       if (IsReference(ref_mode)) {
127         out << "_aidl_data.write(" << arg_name << ")?;\n";
128       } else {
129         out << "_aidl_data.write(&" << arg_name << ")?;\n";
130       }
131     } else if (arg->GetType().IsArray()) {
132       // For out-only arrays, send the array size
133       if (arg->GetType().IsNullable()) {
134         out << "_aidl_data.write_slice_size(" << arg_name << ".as_deref())?;\n";
135       } else {
136         out << "_aidl_data.write_slice_size(Some(" << arg_name << "))?;\n";
137       }
138     }
139   }
140 
141   // Return Ok(()) if all the `_aidl_data.write(...)?;` calls pass
142   out << "Ok(())\n";
143   out.Dedent();
144   out << "});\n";
145 
146   // Check for UNKNOWN_TRANSACTION and call the default impl
147   if (method.IsUserDefined()) {
148     string default_args;
149     for (const std::unique_ptr<AidlArgument>& arg : method.GetArguments()) {
150       if (!default_args.empty()) {
151         default_args += ", ";
152       }
153       default_args += kArgumentPrefix;
154       default_args += arg->GetName();
155     }
156     out << "if let Err(binder::StatusCode::UNKNOWN_TRANSACTION) = _aidl_reply {\n";
157     out << "  if let Some(_aidl_default_impl) = <Self as " << trait_name
158         << ">::getDefaultImpl() {\n";
159     out << "    return _aidl_default_impl." << method.GetName() << "(" << default_args << ");\n";
160     out << "  }\n";
161     out << "}\n";
162   }
163 
164   // Return all other errors
165   out << "let _aidl_reply = _aidl_reply?;\n";
166 
167   string return_val = "()";
168   if (!method.IsOneway()) {
169     // Check for errors
170     out << "let _aidl_status: binder::Status = _aidl_reply.read()?;\n";
171     out << "if !_aidl_status.is_ok() { return Err(_aidl_status); }\n";
172 
173     // Return reply value
174     if (method.GetType().GetName() != "void") {
175       auto return_type = RustNameOf(method.GetType(), typenames, StorageMode::VALUE);
176       out << "let _aidl_return: " << return_type << " = _aidl_reply.read()?;\n";
177       return_val = "_aidl_return";
178 
179       if (!method.IsUserDefined()) {
180         if (method.GetName() == kGetInterfaceVersion && options.Version() > 0) {
181           out << "self.cached_version.store(_aidl_return, std::sync::atomic::Ordering::Relaxed);\n";
182         }
183         if (method.GetName() == kGetInterfaceHash && !options.Hash().empty()) {
184           out << "*self.cached_hash.lock().unwrap() = Some(_aidl_return.clone());\n";
185         }
186       }
187     }
188 
189     for (const AidlArgument* arg : method.GetOutArguments()) {
190       out << "*" << kArgumentPrefix << arg->GetName() << " = _aidl_reply.read()?;\n";
191     }
192   }
193 
194   // Return the result
195   out << "Ok(" << return_val << ")\n";
196   out.Dedent();
197   out << "}\n";
198 }
199 
GenerateServerTransaction(CodeWriter & out,const AidlMethod & method,const AidlTypenames & typenames)200 void GenerateServerTransaction(CodeWriter& out, const AidlMethod& method,
201                                const AidlTypenames& typenames) {
202   out << "transactions::" << method.GetName() << " => {\n";
203   out.Indent();
204 
205   string args;
206   for (const auto& arg : method.GetArguments()) {
207     string arg_name = kArgumentPrefix + arg->GetName();
208     StorageMode arg_mode;
209     if (arg->IsIn()) {
210       arg_mode = StorageMode::VALUE;
211     } else {
212       // We need a value we can call Default::default() on
213       arg_mode = StorageMode::DEFAULT_VALUE;
214     }
215     auto arg_type = RustNameOf(arg->GetType(), typenames, arg_mode);
216 
217     string arg_mut = arg->IsOut() ? "mut " : "";
218     string arg_init = arg->IsIn() ? "_aidl_data.read()?" : "Default::default()";
219     out << "let " << arg_mut << arg_name << ": " << arg_type << " = " << arg_init << ";\n";
220     if (!arg->IsIn() && arg->GetType().IsArray()) {
221       // _aidl_data.resize_[nullable_]out_vec(&mut _arg_foo)?;
222       auto resize_name = arg->GetType().IsNullable() ? "resize_nullable_out_vec" : "resize_out_vec";
223       out << "_aidl_data." << resize_name << "(&mut " << arg_name << ")?;\n";
224     }
225 
226     auto ref_mode = ArgumentReferenceMode(*arg, typenames);
227     if (!args.empty()) {
228       args += ", ";
229     }
230     args += TakeReference(ref_mode, arg_name);
231   }
232   out << "let _aidl_return = _aidl_service." << method.GetName() << "(" << args << ");\n";
233 
234   if (!method.IsOneway()) {
235     out << "match &_aidl_return {\n";
236     out.Indent();
237     out << "Ok(_aidl_return) => {\n";
238     out.Indent();
239     out << "_aidl_reply.write(&binder::Status::from(binder::StatusCode::OK))?;\n";
240     if (method.GetType().GetName() != "void") {
241       out << "_aidl_reply.write(_aidl_return)?;\n";
242     }
243 
244     // Serialize out arguments
245     for (const AidlArgument* arg : method.GetOutArguments()) {
246       string arg_name = kArgumentPrefix + arg->GetName();
247 
248       auto& arg_type = arg->GetType();
249       if (!arg->IsIn() && arg_type.IsArray() && arg_type.GetName() == "ParcelFileDescriptor") {
250         // We represent arrays of ParcelFileDescriptor as
251         // Vec<Option<ParcelFileDescriptor>> when they're out-arguments,
252         // but we need all of them to be initialized to Some; if there's
253         // any None, return UNEXPECTED_NULL (this is what libbinder_ndk does)
254         out << "if " << arg_name << ".iter().any(Option::is_none) { "
255             << "return Err(binder::StatusCode::UNEXPECTED_NULL); }\n";
256       } else if (!arg->IsIn() && !TypeHasDefault(arg_type, typenames)) {
257         // Unwrap out-only arguments that we wrapped in Option<T>
258         out << "let " << arg_name << " = " << arg_name
259             << ".ok_or(binder::StatusCode::UNEXPECTED_NULL)?;\n";
260       }
261 
262       out << "_aidl_reply.write(&" << arg_name << ")?;\n";
263     }
264     out.Dedent();
265     out << "}\n";
266     out << "Err(_aidl_status) => _aidl_reply.write(_aidl_status)?\n";
267     out.Dedent();
268     out << "}\n";
269   }
270   out << "Ok(())\n";
271   out.Dedent();
272   out << "}\n";
273 }
274 
GenerateServerItems(CodeWriter & out,const AidlInterface * iface,const AidlTypenames & typenames)275 void GenerateServerItems(CodeWriter& out, const AidlInterface* iface,
276                          const AidlTypenames& typenames) {
277   auto trait_name = ClassName(*iface, cpp::ClassNames::INTERFACE);
278   auto server_name = ClassName(*iface, cpp::ClassNames::SERVER);
279 
280   // Forward all IFoo functions from Binder to the inner object
281   out << "impl " << trait_name << " for binder::Binder<" << server_name << "> {\n";
282   out.Indent();
283   for (const auto& method : iface->GetMethods()) {
284     string args;
285     for (const std::unique_ptr<AidlArgument>& arg : method->GetArguments()) {
286       if (!args.empty()) {
287         args += ", ";
288       }
289       args += kArgumentPrefix;
290       args += arg->GetName();
291     }
292     out << BuildMethod(*method, typenames) << " { "
293         << "self.0." << method->GetName() << "(" << args << ") }\n";
294   }
295   out.Dedent();
296   out << "}\n";
297 
298   out << "fn on_transact("
299          "_aidl_service: &dyn "
300       << trait_name
301       << ", "
302          "_aidl_code: binder::TransactionCode, "
303          "_aidl_data: &binder::parcel::Parcel, "
304          "_aidl_reply: &mut binder::parcel::Parcel) -> binder::Result<()> {\n";
305   out.Indent();
306   out << "match _aidl_code {\n";
307   out.Indent();
308   for (const auto& method : iface->GetMethods()) {
309     GenerateServerTransaction(out, *method, typenames);
310   }
311   out << "_ => Err(binder::StatusCode::UNKNOWN_TRANSACTION)\n";
312   out.Dedent();
313   out << "}\n";
314   out.Dedent();
315   out << "}\n";
316 }
317 
GenerateDeprecated(CodeWriter & out,const AidlCommentable & type)318 void GenerateDeprecated(CodeWriter& out, const AidlCommentable& type) {
319   if (auto deprecated = FindDeprecated(type.GetComments()); deprecated.has_value()) {
320     if (deprecated->note.empty()) {
321       out << "#[deprecated]\n";
322     } else {
323       out << "#[deprecated = " << QuotedEscape(deprecated->note) << "]\n";
324     }
325   }
326 }
327 
328 template <typename TypeWithConstants>
GenerateConstantDeclarations(CodeWriter & out,const TypeWithConstants & type,const AidlTypenames & typenames)329 void GenerateConstantDeclarations(CodeWriter& out, const TypeWithConstants& type,
330                                   const AidlTypenames& typenames) {
331   for (const auto& constant : type.GetConstantDeclarations()) {
332     const AidlTypeSpecifier& type = constant->GetType();
333     const AidlConstantValue& value = constant->GetValue();
334 
335     string const_type;
336     if (type.Signature() == "String") {
337       const_type = "&str";
338     } else if (type.Signature() == "byte" || type.Signature() == "int" ||
339                type.Signature() == "long") {
340       const_type = RustNameOf(type, typenames, StorageMode::VALUE);
341     } else {
342       AIDL_FATAL(value) << "Unrecognized constant type: " << type.Signature();
343     }
344 
345     GenerateDeprecated(out, *constant);
346     out << "pub const " << constant->GetName() << ": " << const_type << " = "
347         << constant->ValueString(ConstantValueDecoratorRef) << ";\n";
348   }
349 }
350 
GenerateRustInterface(const string & filename,const AidlInterface * iface,const AidlTypenames & typenames,const IoDelegate & io_delegate,const Options & options)351 bool GenerateRustInterface(const string& filename, const AidlInterface* iface,
352                            const AidlTypenames& typenames, const IoDelegate& io_delegate,
353                            const Options& options) {
354   CodeWriterPtr code_writer = io_delegate.GetCodeWriter(filename);
355 
356   *code_writer << "#![allow(non_upper_case_globals)]\n";
357   *code_writer << "#![allow(non_snake_case)]\n";
358   // Import IBinderInternal for transact()
359   *code_writer << "#[allow(unused_imports)] use binder::IBinderInternal;\n";
360 
361   auto trait_name = ClassName(*iface, cpp::ClassNames::INTERFACE);
362   auto client_name = ClassName(*iface, cpp::ClassNames::CLIENT);
363   auto server_name = ClassName(*iface, cpp::ClassNames::SERVER);
364   *code_writer << "use binder::declare_binder_interface;\n";
365   *code_writer << "declare_binder_interface! {\n";
366   code_writer->Indent();
367   *code_writer << trait_name << "[\"" << iface->GetDescriptor() << "\"] {\n";
368   code_writer->Indent();
369   *code_writer << "native: " << server_name << "(on_transact),\n";
370   *code_writer << "proxy: " << client_name << " {\n";
371   code_writer->Indent();
372   if (options.Version() > 0) {
373     string comma = options.Hash().empty() ? "" : ",";
374     *code_writer << "cached_version: "
375                     "std::sync::atomic::AtomicI32 = "
376                     "std::sync::atomic::AtomicI32::new(-1)"
377                  << comma << "\n";
378   }
379   if (!options.Hash().empty()) {
380     *code_writer << "cached_hash: "
381                     "std::sync::Mutex<Option<String>> = "
382                     "std::sync::Mutex::new(None)\n";
383   }
384   code_writer->Dedent();
385   *code_writer << "},\n";
386   code_writer->Dedent();
387   if (iface->IsVintfStability()) {
388     *code_writer << "stability: binder::Stability::Vintf,\n";
389   }
390   *code_writer << "}\n";
391   code_writer->Dedent();
392   *code_writer << "}\n";
393 
394   GenerateDeprecated(*code_writer, *iface);
395   *code_writer << "pub trait " << trait_name << ": binder::Interface + Send {\n";
396   code_writer->Indent();
397   *code_writer << "fn get_descriptor() -> &'static str where Self: Sized { \""
398                << iface->GetDescriptor() << "\" }\n";
399 
400   for (const auto& method : iface->GetMethods()) {
401     // Generate the method
402     GenerateDeprecated(*code_writer, *method);
403     if (method->IsUserDefined()) {
404       *code_writer << BuildMethod(*method, typenames) << ";\n";
405     } else {
406       // Generate default implementations for meta methods
407       *code_writer << BuildMethod(*method, typenames) << " {\n";
408       code_writer->Indent();
409       if (method->GetName() == kGetInterfaceVersion && options.Version() > 0) {
410         *code_writer << "Ok(VERSION)\n";
411       } else if (method->GetName() == kGetInterfaceHash && !options.Hash().empty()) {
412         *code_writer << "Ok(HASH.into())\n";
413       }
414       code_writer->Dedent();
415       *code_writer << "}\n";
416     }
417   }
418 
419   // Emit the default implementation code inside the trait
420   auto default_trait_name = ClassName(*iface, cpp::ClassNames::DEFAULT_IMPL);
421   auto default_ref_name = default_trait_name + "Ref";
422   *code_writer << "fn getDefaultImpl()"
423                << " -> " << default_ref_name << " where Self: Sized {\n";
424   *code_writer << "  DEFAULT_IMPL.lock().unwrap().clone()\n";
425   *code_writer << "}\n";
426   *code_writer << "fn setDefaultImpl(d: " << default_ref_name << ")"
427                << " -> " << default_ref_name << " where Self: Sized {\n";
428   *code_writer << "  std::mem::replace(&mut *DEFAULT_IMPL.lock().unwrap(), d)\n";
429   *code_writer << "}\n";
430   code_writer->Dedent();
431   *code_writer << "}\n";
432 
433   // Emit the default trait
434   *code_writer << "pub trait " << default_trait_name << ": Send + Sync {\n";
435   code_writer->Indent();
436   for (const auto& method : iface->GetMethods()) {
437     if (!method->IsUserDefined()) {
438       continue;
439     }
440 
441     // Generate the default method
442     *code_writer << BuildMethod(*method, typenames) << " {\n";
443     code_writer->Indent();
444     *code_writer << "Err(binder::StatusCode::UNKNOWN_TRANSACTION.into())\n";
445     code_writer->Dedent();
446     *code_writer << "}\n";
447   }
448   code_writer->Dedent();
449   *code_writer << "}\n";
450 
451   // Generate the transaction code constants
452   // The constants get their own sub-module to avoid conflicts
453   *code_writer << "pub mod transactions {\n";
454   code_writer->Indent();
455   for (const auto& method : iface->GetMethods()) {
456     // Generate the transaction code constant
457     *code_writer << "pub const " << method->GetName()
458                  << ": binder::TransactionCode = "
459                     "binder::FIRST_CALL_TRANSACTION + " +
460                         std::to_string(method->GetId()) + ";\n";
461   }
462   code_writer->Dedent();
463   *code_writer << "}\n";
464 
465   // Emit the default implementation code outside the trait
466   *code_writer << "pub type " << default_ref_name << " = Option<std::sync::Arc<dyn "
467                << default_trait_name << ">>;\n";
468   *code_writer << "use lazy_static::lazy_static;\n";
469   *code_writer << "lazy_static! {\n";
470   *code_writer << "  static ref DEFAULT_IMPL: std::sync::Mutex<" << default_ref_name
471                << "> = std::sync::Mutex::new(None);\n";
472   *code_writer << "}\n";
473 
474   // Emit the interface constants
475   GenerateConstantDeclarations(*code_writer, *iface, typenames);
476 
477   GenerateMangledAlias(*code_writer, iface);
478 
479   // Emit VERSION and HASH
480   // These need to be top-level item constants instead of associated consts
481   // because the latter are incompatible with trait objects, see
482   // https://doc.rust-lang.org/reference/items/traits.html#object-safety
483   if (options.Version() > 0) {
484     *code_writer << "pub const VERSION: i32 = " << std::to_string(options.Version()) << ";\n";
485   }
486   if (!options.Hash().empty()) {
487     *code_writer << "pub const HASH: &str = \"" << options.Hash() << "\";\n";
488   }
489 
490   // Generate the client-side methods
491   *code_writer << "impl " << trait_name << " for " << client_name << " {\n";
492   code_writer->Indent();
493   for (const auto& method : iface->GetMethods()) {
494     GenerateClientMethod(*code_writer, *iface, *method, typenames, options, trait_name);
495   }
496   code_writer->Dedent();
497   *code_writer << "}\n";
498 
499   // Generate the server-side methods
500   GenerateServerItems(*code_writer, iface, typenames);
501 
502   return true;
503 }
504 
GenerateParcelBody(CodeWriter & out,const AidlStructuredParcelable * parcel,const AidlTypenames & typenames)505 void GenerateParcelBody(CodeWriter& out, const AidlStructuredParcelable* parcel,
506                         const AidlTypenames& typenames) {
507   GenerateDeprecated(out, *parcel);
508   out << "pub struct " << parcel->GetName() << " {\n";
509   out.Indent();
510   for (const auto& variable : parcel->GetFields()) {
511     GenerateDeprecated(out, *variable);
512     auto field_type = RustNameOf(variable->GetType(), typenames, StorageMode::PARCELABLE_FIELD);
513     out << "pub " << variable->GetName() << ": " << field_type << ",\n";
514   }
515   out.Dedent();
516   out << "}\n";
517 }
518 
GenerateParcelDefault(CodeWriter & out,const AidlStructuredParcelable * parcel)519 void GenerateParcelDefault(CodeWriter& out, const AidlStructuredParcelable* parcel) {
520   out << "impl Default for " << parcel->GetName() << " {\n";
521   out.Indent();
522   out << "fn default() -> Self {\n";
523   out.Indent();
524   out << "Self {\n";
525   out.Indent();
526   for (const auto& variable : parcel->GetFields()) {
527     if (variable->GetDefaultValue()) {
528       out << variable->GetName() << ": " << variable->ValueString(ConstantValueDecorator) << ",\n";
529     } else {
530       out << variable->GetName() << ": Default::default(),\n";
531     }
532   }
533   out.Dedent();
534   out << "}\n";
535   out.Dedent();
536   out << "}\n";
537   out.Dedent();
538   out << "}\n";
539 }
540 
GenerateParcelSerializeBody(CodeWriter & out,const AidlStructuredParcelable * parcel,const AidlTypenames & typenames)541 void GenerateParcelSerializeBody(CodeWriter& out, const AidlStructuredParcelable* parcel,
542                                  const AidlTypenames& typenames) {
543   out << "parcel.sized_write(|subparcel| {\n";
544   out.Indent();
545   for (const auto& variable : parcel->GetFields()) {
546     if (!TypeHasDefault(variable->GetType(), typenames)) {
547       out << "let __field_ref = this." << variable->GetName()
548           << ".as_ref().ok_or(binder::StatusCode::UNEXPECTED_NULL)?;\n";
549       out << "subparcel.write(__field_ref)?;\n";
550     } else {
551       out << "subparcel.write(&this." << variable->GetName() << ")?;\n";
552     }
553   }
554   out << "Ok(())\n";
555   out.Dedent();
556   out << "})\n";
557 }
558 
GenerateParcelDeserializeBody(CodeWriter & out,const AidlStructuredParcelable * parcel,const AidlTypenames & typenames)559 void GenerateParcelDeserializeBody(CodeWriter& out, const AidlStructuredParcelable* parcel,
560                                    const AidlTypenames& typenames) {
561   out << "let start_pos = parcel.get_data_position();\n";
562   out << "let parcelable_size: i32 = parcel.read()?;\n";
563   out << "if parcelable_size < 0 { return Err(binder::StatusCode::BAD_VALUE); }\n";
564   out << "if start_pos.checked_add(parcelable_size).is_none() {\n";
565   out << "  return Err(binder::StatusCode::BAD_VALUE);\n";
566   out << "}\n";
567 
568   // Pre-emit the common field prologue code, shared between all fields:
569   ostringstream prologue;
570   prologue << "if (parcel.get_data_position() - start_pos) == parcelable_size {\n";
571   // We assume the lhs can never be > parcelable_size, because then the read
572   // immediately preceding this check would have returned NOT_ENOUGH_DATA
573   prologue << "  return Ok(Some(result));\n";
574   prologue << "}\n";
575   string prologue_str = prologue.str();
576 
577   out << "let mut result = Self::default();\n";
578   for (const auto& variable : parcel->GetFields()) {
579     out << prologue_str;
580     if (!TypeHasDefault(variable->GetType(), typenames)) {
581       out << "result." << variable->GetName() << " = Some(parcel.read()?);\n";
582     } else {
583       out << "result." << variable->GetName() << " = parcel.read()?;\n";
584     }
585   }
586   // Now we read all fields.
587   // Skip remaining data in case we're reading from a newer version
588   out << "unsafe {\n";
589   out << "  parcel.set_data_position(start_pos + parcelable_size)?;\n";
590   out << "}\n";
591   out << "Ok(Some(result))\n";
592 }
593 
GenerateParcelBody(CodeWriter & out,const AidlUnionDecl * parcel,const AidlTypenames & typenames)594 void GenerateParcelBody(CodeWriter& out, const AidlUnionDecl* parcel,
595                         const AidlTypenames& typenames) {
596   GenerateDeprecated(out, *parcel);
597   out << "pub enum " << parcel->GetName() << " {\n";
598   out.Indent();
599   for (const auto& variable : parcel->GetFields()) {
600     GenerateDeprecated(out, *variable);
601     auto field_type = RustNameOf(variable->GetType(), typenames, StorageMode::PARCELABLE_FIELD);
602     out << variable->GetCapitalizedName() << "(" << field_type << "),\n";
603   }
604   out.Dedent();
605   out << "}\n";
606 }
607 
GenerateParcelDefault(CodeWriter & out,const AidlUnionDecl * parcel)608 void GenerateParcelDefault(CodeWriter& out, const AidlUnionDecl* parcel) {
609   out << "impl Default for " << parcel->GetName() << " {\n";
610   out.Indent();
611   out << "fn default() -> Self {\n";
612   out.Indent();
613 
614   AIDL_FATAL_IF(parcel->GetFields().empty(), *parcel)
615       << "Union '" << parcel->GetName() << "' is empty.";
616   const auto& first_field = parcel->GetFields()[0];
617   const auto& first_value = first_field->ValueString(ConstantValueDecorator);
618 
619   out << "Self::";
620   if (first_field->GetDefaultValue()) {
621     out << first_field->GetCapitalizedName() << "(" << first_value << ")\n";
622   } else {
623     out << first_field->GetCapitalizedName() << "(Default::default())\n";
624   }
625 
626   out.Dedent();
627   out << "}\n";
628   out.Dedent();
629   out << "}\n";
630 }
631 
GenerateParcelSerializeBody(CodeWriter & out,const AidlUnionDecl * parcel,const AidlTypenames & typenames)632 void GenerateParcelSerializeBody(CodeWriter& out, const AidlUnionDecl* parcel,
633                                  const AidlTypenames& typenames) {
634   out << "match this {\n";
635   out.Indent();
636   int tag = 0;
637   for (const auto& variable : parcel->GetFields()) {
638     out << "Self::" << variable->GetCapitalizedName() << "(v) => {\n";
639     out.Indent();
640     out << "parcel.write(&" << std::to_string(tag++) << "i32)?;\n";
641     if (!TypeHasDefault(variable->GetType(), typenames)) {
642       out << "let __field_ref = v.as_ref().ok_or(binder::StatusCode::UNEXPECTED_NULL)?;\n";
643       out << "parcel.write(__field_ref)\n";
644     } else {
645       out << "parcel.write(v)\n";
646     }
647     out.Dedent();
648     out << "}\n";
649   }
650   out.Dedent();
651   out << "}\n";
652 }
653 
GenerateParcelDeserializeBody(CodeWriter & out,const AidlUnionDecl * parcel,const AidlTypenames & typenames)654 void GenerateParcelDeserializeBody(CodeWriter& out, const AidlUnionDecl* parcel,
655                                    const AidlTypenames& typenames) {
656   out << "let tag: i32 = parcel.read()?;\n";
657   out << "match tag {\n";
658   out.Indent();
659   int tag = 0;
660   for (const auto& variable : parcel->GetFields()) {
661     auto field_type = RustNameOf(variable->GetType(), typenames, StorageMode::PARCELABLE_FIELD);
662 
663     out << std::to_string(tag++) << " => {\n";
664     out.Indent();
665     out << "let value: " << field_type << " = ";
666     if (!TypeHasDefault(variable->GetType(), typenames)) {
667       out << "Some(parcel.read()?);\n";
668     } else {
669       out << "parcel.read()?;\n";
670     }
671     out << "Ok(Some(Self::" << variable->GetCapitalizedName() << "(value)))\n";
672     out.Dedent();
673     out << "}\n";
674   }
675   out << "_ => {\n";
676   out << "  Err(binder::StatusCode::BAD_VALUE)\n";
677   out << "}\n";
678   out.Dedent();
679   out << "}\n";
680 }
681 
682 template <typename ParcelableType>
GenerateParcelSerialize(CodeWriter & out,const ParcelableType * parcel,const AidlTypenames & typenames)683 void GenerateParcelSerialize(CodeWriter& out, const ParcelableType* parcel,
684                              const AidlTypenames& typenames) {
685   out << "impl binder::parcel::Serialize for " << parcel->GetName() << " {\n";
686   out << "  fn serialize(&self, parcel: &mut binder::parcel::Parcel) -> binder::Result<()> {\n";
687   out << "    <Self as binder::parcel::SerializeOption>::serialize_option(Some(self), parcel)\n";
688   out << "  }\n";
689   out << "}\n";
690 
691   out << "impl binder::parcel::SerializeArray for " << parcel->GetName() << " {}\n";
692 
693   out << "impl binder::parcel::SerializeOption for " << parcel->GetName() << " {\n";
694   out.Indent();
695   out << "fn serialize_option(this: Option<&Self>, parcel: &mut binder::parcel::Parcel) -> "
696          "binder::Result<()> {\n";
697   out.Indent();
698   out << "let this = if let Some(this) = this {\n";
699   out << "  parcel.write(&1i32)?;\n";
700   out << "  this\n";
701   out << "} else {\n";
702   out << "  return parcel.write(&0i32);\n";
703   out << "};\n";
704 
705   GenerateParcelSerializeBody(out, parcel, typenames);
706 
707   out.Dedent();
708   out << "}\n";
709   out.Dedent();
710   out << "}\n";
711 }
712 
713 template <typename ParcelableType>
GenerateParcelDeserialize(CodeWriter & out,const ParcelableType * parcel,const AidlTypenames & typenames)714 void GenerateParcelDeserialize(CodeWriter& out, const ParcelableType* parcel,
715                                const AidlTypenames& typenames) {
716   out << "impl binder::parcel::Deserialize for " << parcel->GetName() << " {\n";
717   out << "  fn deserialize(parcel: &binder::parcel::Parcel) -> binder::Result<Self> {\n";
718   out << "    <Self as binder::parcel::DeserializeOption>::deserialize_option(parcel)\n";
719   out << "       .transpose()\n";
720   out << "       .unwrap_or(Err(binder::StatusCode::UNEXPECTED_NULL))\n";
721   out << "  }\n";
722   out << "}\n";
723 
724   out << "impl binder::parcel::DeserializeArray for " << parcel->GetName() << " {}\n";
725 
726   out << "impl binder::parcel::DeserializeOption for " << parcel->GetName() << " {\n";
727   out.Indent();
728   out << "fn deserialize_option(parcel: &binder::parcel::Parcel) -> binder::Result<Option<Self>> "
729          "{\n";
730   out.Indent();
731   out << "let status: i32 = parcel.read()?;\n";
732   out << "if status == 0 { return Ok(None); }\n";
733 
734   GenerateParcelDeserializeBody(out, parcel, typenames);
735 
736   out.Dedent();
737   out << "}\n";
738   out.Dedent();
739   out << "}\n";
740 }
741 
742 template <typename ParcelableType>
GenerateRustParcel(const string & filename,const ParcelableType * parcel,const AidlTypenames & typenames,const IoDelegate & io_delegate)743 bool GenerateRustParcel(const string& filename, const ParcelableType* parcel,
744                         const AidlTypenames& typenames, const IoDelegate& io_delegate) {
745   CodeWriterPtr code_writer = io_delegate.GetCodeWriter(filename);
746 
747   // Debug is always derived because all Rust AIDL types implement it
748   // ParcelFileDescriptor doesn't support any of the others because
749   // it's a newtype over std::fs::File which only implements Debug
750   vector<string> derives{"Debug"};
751   const AidlAnnotation* derive_annotation = parcel->RustDerive();
752   if (derive_annotation != nullptr) {
753     for (const auto& name_and_param : derive_annotation->AnnotationParams(ConstantValueDecorator)) {
754       if (name_and_param.second == "true") {
755         derives.push_back(name_and_param.first);
756       }
757     }
758   }
759 
760   *code_writer << "#[derive(" << Join(derives, ", ") << ")]\n";
761   GenerateParcelBody(*code_writer, parcel, typenames);
762   GenerateConstantDeclarations(*code_writer, *parcel, typenames);
763   GenerateMangledAlias(*code_writer, parcel);
764   GenerateParcelDefault(*code_writer, parcel);
765   GenerateParcelSerialize(*code_writer, parcel, typenames);
766   GenerateParcelDeserialize(*code_writer, parcel, typenames);
767   return true;
768 }
769 
GenerateRustEnumDeclaration(const string & filename,const AidlEnumDeclaration * enum_decl,const AidlTypenames & typenames,const IoDelegate & io_delegate)770 bool GenerateRustEnumDeclaration(const string& filename, const AidlEnumDeclaration* enum_decl,
771                                  const AidlTypenames& typenames, const IoDelegate& io_delegate) {
772   CodeWriterPtr code_writer = io_delegate.GetCodeWriter(filename);
773 
774   const auto& aidl_backing_type = enum_decl->GetBackingType();
775   auto backing_type = RustNameOf(aidl_backing_type, typenames, StorageMode::VALUE);
776 
777   // TODO(b/177860423) support "deprecated" for enum types
778   *code_writer << "#![allow(non_upper_case_globals)]\n";
779   *code_writer << "use binder::declare_binder_enum;\n";
780   *code_writer << "declare_binder_enum! { " << enum_decl->GetName() << " : " << backing_type
781                << " {\n";
782   code_writer->Indent();
783   for (const auto& enumerator : enum_decl->GetEnumerators()) {
784     auto value = enumerator->GetValue()->ValueString(aidl_backing_type, ConstantValueDecorator);
785     *code_writer << enumerator->GetName() << " = " << value << ",\n";
786   }
787   code_writer->Dedent();
788   *code_writer << "} }\n";
789 
790   GenerateMangledAlias(*code_writer, enum_decl);
791 
792   return true;
793 }
794 
GenerateRust(const string & filename,const AidlDefinedType * defined_type,const AidlTypenames & typenames,const IoDelegate & io_delegate,const Options & options)795 bool GenerateRust(const string& filename, const AidlDefinedType* defined_type,
796                   const AidlTypenames& typenames, const IoDelegate& io_delegate,
797                   const Options& options) {
798   if (const AidlStructuredParcelable* parcelable = defined_type->AsStructuredParcelable();
799       parcelable != nullptr) {
800     return GenerateRustParcel(filename, parcelable, typenames, io_delegate);
801   }
802 
803   if (const AidlUnionDecl* parcelable = defined_type->AsUnionDeclaration(); parcelable != nullptr) {
804     return GenerateRustParcel(filename, parcelable, typenames, io_delegate);
805   }
806 
807   if (const AidlEnumDeclaration* enum_decl = defined_type->AsEnumDeclaration();
808       enum_decl != nullptr) {
809     return GenerateRustEnumDeclaration(filename, enum_decl, typenames, io_delegate);
810   }
811 
812   if (const AidlInterface* interface = defined_type->AsInterface(); interface != nullptr) {
813     return GenerateRustInterface(filename, interface, typenames, io_delegate, options);
814   }
815 
816   AIDL_FATAL(filename) << "Unrecognized type sent for Rust generation.";
817   return false;
818 }
819 
820 }  // namespace rust
821 }  // namespace aidl
822 }  // namespace android
823