1 /*
2 * Copyright (C) 2020, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "generate_rust.h"
18
19 #include <android-base/stringprintf.h>
20 #include <android-base/strings.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string.h>
24
25 #include <map>
26 #include <memory>
27 #include <sstream>
28
29 #include "aidl_to_common.h"
30 #include "aidl_to_cpp_common.h"
31 #include "aidl_to_rust.h"
32 #include "code_writer.h"
33 #include "comments.h"
34 #include "logging.h"
35
36 using android::base::Join;
37 using android::base::Split;
38 using std::ostringstream;
39 using std::shared_ptr;
40 using std::string;
41 using std::unique_ptr;
42 using std::vector;
43
44 namespace android {
45 namespace aidl {
46 namespace rust {
47
// Prefix prepended to each AIDL argument name to form the name used for that
// argument in the generated Rust code.
static constexpr const char kArgumentPrefix[] = "_arg_";
// Method names that the generator compares against for non-user-defined
// methods in order to emit the version/hash specific code paths.
static constexpr const char kGetInterfaceVersion[] = "getInterfaceVersion";
static constexpr const char kGetInterfaceHash[] = "getInterfaceHash";
51
52 struct MangledAliasVisitor : AidlVisitor {
53 CodeWriter& out;
MangledAliasVisitorandroid::aidl::rust::MangledAliasVisitor54 MangledAliasVisitor(CodeWriter& out) : out(out) {}
Visitandroid::aidl::rust::MangledAliasVisitor55 void Visit(const AidlStructuredParcelable& type) override { VisitType(type); }
Visitandroid::aidl::rust::MangledAliasVisitor56 void Visit(const AidlInterface& type) override { VisitType(type); }
Visitandroid::aidl::rust::MangledAliasVisitor57 void Visit(const AidlEnumDeclaration& type) override { VisitType(type); }
Visitandroid::aidl::rust::MangledAliasVisitor58 void Visit(const AidlUnionDecl& type) override { VisitType(type); }
59 template <typename T>
VisitTypeandroid::aidl::rust::MangledAliasVisitor60 void VisitType(const T& type) {
61 out << " pub use " << Qname(type) << " as " << Mangled(type) << ";\n";
62 }
63 // Return a mangled name for a type (including AIDL package)
64 template <typename T>
Mangledandroid::aidl::rust::MangledAliasVisitor65 string Mangled(const T& type) const {
66 ostringstream alias;
67 for (const auto& component : Split(type.GetCanonicalName(), ".")) {
68 alias << "_" << component.size() << "_" << component;
69 }
70 return alias.str();
71 }
72 template <typename T>
Typenameandroid::aidl::rust::MangledAliasVisitor73 string Typename(const T& type) const {
74 if constexpr (std::is_same_v<T, AidlInterface>) {
75 return ClassName(type, cpp::ClassNames::INTERFACE);
76 } else {
77 return type.GetName();
78 }
79 }
80 // Return a fully qualified name for a type in the current file (excluding AIDL package)
81 template <typename T>
Qnameandroid::aidl::rust::MangledAliasVisitor82 string Qname(const T& type) const {
83 return Module(type) + "::r#" + Typename(type);
84 }
85 // Return a module name for a type (relative to the file)
86 template <typename T>
Moduleandroid::aidl::rust::MangledAliasVisitor87 string Module(const T& type) const {
88 if (type.GetParentType()) {
89 return Module(*type.GetParentType()) + "::r#" + type.GetName();
90 } else {
91 return "super";
92 }
93 }
94 };
95
GenerateMangledAliases(CodeWriter & out,const AidlDefinedType & type)96 void GenerateMangledAliases(CodeWriter& out, const AidlDefinedType& type) {
97 MangledAliasVisitor v(out);
98 out << "pub(crate) mod mangled {\n";
99 VisitTopDown(v, type);
100 out << "}\n";
101 }
102
BuildArg(const AidlArgument & arg,const AidlTypenames & typenames,Lifetime lifetime)103 string BuildArg(const AidlArgument& arg, const AidlTypenames& typenames, Lifetime lifetime) {
104 // We pass in parameters that are not primitives by const reference.
105 // Arrays get passed in as slices, which is handled in RustNameOf.
106 auto arg_mode = ArgumentStorageMode(arg, typenames);
107 auto arg_type = RustNameOf(arg.GetType(), typenames, arg_mode, lifetime);
108 return kArgumentPrefix + arg.GetName() + ": " + arg_type;
109 }
110
// Selects the shape of the Rust function signature that BuildMethod produces
// for an AIDL method.
enum class MethodKind {
  // This is a normal non-async method. The return type is a plain
  // `binder::Result<...>`.
  NORMAL,
  // This is an async method. Identical to NORMAL except that async is added
  // in front of `fn`.
  ASYNC,
  // This is an async function, but using a boxed future instead of the async
  // keyword. The return type is wrapped in `binder::BoxFuture<'a, ...>` and the
  // arguments are marked with the matching lifetime.
  BOXED_FUTURE,
  // This could have been a non-async method, but it returns a Future so that
  // it would not be breaking to make the function do async stuff in the future.
  // The return type is wrapped in `std::future::Ready<...>`.
  READY_FUTURE,
};
124
BuildMethod(const AidlMethod & method,const AidlTypenames & typenames,const MethodKind kind=MethodKind::NORMAL)125 string BuildMethod(const AidlMethod& method, const AidlTypenames& typenames,
126 const MethodKind kind = MethodKind::NORMAL) {
127 // We need to mark the arguments with a lifetime only when returning a future that
128 // actually captures the arguments.
129 Lifetime lifetime;
130 switch (kind) {
131 case MethodKind::NORMAL:
132 case MethodKind::ASYNC:
133 case MethodKind::READY_FUTURE:
134 lifetime = Lifetime::NONE;
135 break;
136 case MethodKind::BOXED_FUTURE:
137 lifetime = Lifetime::A;
138 break;
139 }
140
141 auto method_type = RustNameOf(method.GetType(), typenames, StorageMode::VALUE, lifetime);
142 auto return_type = string{"binder::Result<"} + method_type + ">";
143 auto fn_prefix = string{""};
144
145 switch (kind) {
146 case MethodKind::NORMAL:
147 // Don't wrap the return type in anything.
148 break;
149 case MethodKind::ASYNC:
150 fn_prefix = "async ";
151 break;
152 case MethodKind::BOXED_FUTURE:
153 return_type = "binder::BoxFuture<'a, " + return_type + ">";
154 break;
155 case MethodKind::READY_FUTURE:
156 return_type = "std::future::Ready<" + return_type + ">";
157 break;
158 }
159
160 string parameters = "&" + RustLifetimeName(lifetime) + "self";
161 string lifetime_str = RustLifetimeGeneric(lifetime);
162
163 for (const std::unique_ptr<AidlArgument>& arg : method.GetArguments()) {
164 parameters += ", ";
165 parameters += BuildArg(*arg, typenames, lifetime);
166 }
167
168 return fn_prefix + "fn r#" + method.GetName() + lifetime_str + "(" + parameters + ") -> " +
169 return_type;
170 }
171
GenerateClientMethodHelpers(CodeWriter & out,const AidlInterface & iface,const AidlMethod & method,const AidlTypenames & typenames,const Options & options,const std::string & default_trait_name)172 void GenerateClientMethodHelpers(CodeWriter& out, const AidlInterface& iface,
173 const AidlMethod& method, const AidlTypenames& typenames,
174 const Options& options, const std::string& default_trait_name) {
175 string parameters = "&self";
176 for (const std::unique_ptr<AidlArgument>& arg : method.GetArguments()) {
177 parameters += ", ";
178 parameters += BuildArg(*arg, typenames, Lifetime::NONE);
179 }
180
181 // Generate build_parcel helper.
182 out << "fn build_parcel_" + method.GetName() + "(" + parameters +
183 ") -> binder::Result<binder::binder_impl::Parcel> {\n";
184 out.Indent();
185
186 out << "let mut aidl_data = self.binder.prepare_transact()?;\n";
187
188 if (iface.IsSensitiveData()) {
189 out << "aidl_data.mark_sensitive();\n";
190 }
191
192 // Arguments
193 for (const std::unique_ptr<AidlArgument>& arg : method.GetArguments()) {
194 auto arg_name = kArgumentPrefix + arg->GetName();
195 if (arg->IsIn()) {
196 // If the argument is already a reference, don't reference it again
197 // (unless we turned it into an Option<&T>)
198 auto ref_mode = ArgumentReferenceMode(*arg, typenames);
199 if (IsReference(ref_mode)) {
200 out << "aidl_data.write(" << arg_name << ")?;\n";
201 } else {
202 out << "aidl_data.write(&" << arg_name << ")?;\n";
203 }
204 } else if (arg->GetType().IsDynamicArray()) {
205 // For out-only arrays, send the array size
206 if (arg->GetType().IsNullable()) {
207 out << "aidl_data.write_slice_size(" << arg_name << ".as_deref())?;\n";
208 } else {
209 out << "aidl_data.write_slice_size(Some(" << arg_name << "))?;\n";
210 }
211 }
212 }
213
214 out << "Ok(aidl_data)\n";
215 out.Dedent();
216 out << "}\n";
217
218 // Generate read_response helper.
219 auto return_type = RustNameOf(method.GetType(), typenames, StorageMode::VALUE, Lifetime::NONE);
220 out << "fn read_response_" + method.GetName() + "(" + parameters +
221 ", _aidl_reply: std::result::Result<binder::binder_impl::Parcel, "
222 "binder::StatusCode>) -> binder::Result<" +
223 return_type + "> {\n";
224 out.Indent();
225
226 // Check for UNKNOWN_TRANSACTION and call the default impl
227 if (method.IsUserDefined()) {
228 string default_args;
229 for (const std::unique_ptr<AidlArgument>& arg : method.GetArguments()) {
230 if (!default_args.empty()) {
231 default_args += ", ";
232 }
233 default_args += kArgumentPrefix;
234 default_args += arg->GetName();
235 }
236 out << "if let Err(binder::StatusCode::UNKNOWN_TRANSACTION) = _aidl_reply {\n";
237 out << " if let Some(_aidl_default_impl) = <Self as " << default_trait_name
238 << ">::getDefaultImpl() {\n";
239 out << " return _aidl_default_impl.r#" << method.GetName() << "(" << default_args << ");\n";
240 out << " }\n";
241 out << "}\n";
242 }
243
244 // Return all other errors
245 out << "let _aidl_reply = _aidl_reply?;\n";
246
247 string return_val = "()";
248 if (!method.IsOneway()) {
249 // Check for errors
250 out << "let _aidl_status: binder::Status = _aidl_reply.read()?;\n";
251 out << "if !_aidl_status.is_ok() { return Err(_aidl_status); }\n";
252
253 // Return reply value
254 if (method.GetType().GetName() != "void") {
255 auto return_type =
256 RustNameOf(method.GetType(), typenames, StorageMode::VALUE, Lifetime::NONE);
257 out << "let _aidl_return: " << return_type << " = _aidl_reply.read()?;\n";
258 return_val = "_aidl_return";
259
260 if (!method.IsUserDefined()) {
261 if (method.GetName() == kGetInterfaceVersion && options.Version() > 0) {
262 out << "self.cached_version.store(_aidl_return, std::sync::atomic::Ordering::Relaxed);\n";
263 }
264 if (method.GetName() == kGetInterfaceHash && !options.Hash().empty()) {
265 out << "*self.cached_hash.lock().unwrap() = Some(_aidl_return.clone());\n";
266 }
267 }
268 }
269
270 for (const AidlArgument* arg : method.GetOutArguments()) {
271 out << "_aidl_reply.read_onto(" << kArgumentPrefix << arg->GetName() << ")?;\n";
272 }
273 }
274
275 // Return the result
276 out << "Ok(" << return_val << ")\n";
277
278 out.Dedent();
279 out << "}\n";
280 }
281
GenerateClientMethod(CodeWriter & out,const AidlInterface & iface,const AidlMethod & method,const AidlTypenames & typenames,const Options & options,const MethodKind kind)282 void GenerateClientMethod(CodeWriter& out, const AidlInterface& iface, const AidlMethod& method,
283 const AidlTypenames& typenames, const Options& options,
284 const MethodKind kind) {
285 // Generate the method
286 out << BuildMethod(method, typenames, kind) << " {\n";
287 out.Indent();
288
289 if (!method.IsUserDefined()) {
290 if (method.GetName() == kGetInterfaceVersion && options.Version() > 0) {
291 // Check if the version is in the cache
292 out << "let _aidl_version = "
293 "self.cached_version.load(std::sync::atomic::Ordering::Relaxed);\n";
294 switch (kind) {
295 case MethodKind::NORMAL:
296 case MethodKind::ASYNC:
297 out << "if _aidl_version != -1 { return Ok(_aidl_version); }\n";
298 break;
299 case MethodKind::BOXED_FUTURE:
300 out << "if _aidl_version != -1 { return Box::pin(std::future::ready(Ok(_aidl_version))); "
301 "}\n";
302 break;
303 case MethodKind::READY_FUTURE:
304 out << "if _aidl_version != -1 { return std::future::ready(Ok(_aidl_version)); }\n";
305 break;
306 }
307 }
308
309 if (method.GetName() == kGetInterfaceHash && !options.Hash().empty()) {
310 out << "{\n";
311 out << " let _aidl_hash_lock = self.cached_hash.lock().unwrap();\n";
312 out << " if let Some(ref _aidl_hash) = *_aidl_hash_lock {\n";
313 switch (kind) {
314 case MethodKind::NORMAL:
315 case MethodKind::ASYNC:
316 out << " return Ok(_aidl_hash.clone());\n";
317 break;
318 case MethodKind::BOXED_FUTURE:
319 out << " return Box::pin(std::future::ready(Ok(_aidl_hash.clone())));\n";
320 break;
321 case MethodKind::READY_FUTURE:
322 out << " return std::future::ready(Ok(_aidl_hash.clone()));\n";
323 break;
324 }
325 out << " }\n";
326 out << "}\n";
327 }
328 }
329
330 string build_parcel_args;
331 for (const std::unique_ptr<AidlArgument>& arg : method.GetArguments()) {
332 if (!build_parcel_args.empty()) {
333 build_parcel_args += ", ";
334 }
335 build_parcel_args += kArgumentPrefix;
336 build_parcel_args += arg->GetName();
337 }
338
339 string read_response_args =
340 build_parcel_args.empty() ? "_aidl_reply" : build_parcel_args + ", _aidl_reply";
341
342 vector<string> flags;
343 if (method.IsOneway()) flags.push_back("binder::binder_impl::FLAG_ONEWAY");
344 if (iface.IsSensitiveData()) flags.push_back("binder::binder_impl::FLAG_CLEAR_BUF");
345 flags.push_back("binder::binder_impl::FLAG_PRIVATE_LOCAL");
346
347 string transact_flags = flags.empty() ? "0" : Join(flags, " | ");
348
349 switch (kind) {
350 case MethodKind::NORMAL:
351 case MethodKind::ASYNC:
352 if (method.IsNew() && ShouldForceDowngradeFor(CommunicationSide::WRITE) &&
353 method.IsUserDefined()) {
354 out << "if (true) {\n";
355 out << " return Err(binder::Status::from(binder::StatusCode::UNKNOWN_TRANSACTION));\n";
356 out << "} else {\n";
357 out.Indent();
358 }
359 // Prepare transaction.
360 out << "let _aidl_data = self.build_parcel_" + method.GetName() + "(" + build_parcel_args +
361 ")?;\n";
362 // Submit transaction.
363 out << "let _aidl_reply = self.binder.submit_transact(transactions::r#" << method.GetName()
364 << ", _aidl_data, " << transact_flags << ");\n";
365 // Deserialize response.
366 out << "self.read_response_" + method.GetName() + "(" + read_response_args + ")\n";
367 break;
368 case MethodKind::READY_FUTURE:
369 if (method.IsNew() && ShouldForceDowngradeFor(CommunicationSide::WRITE) &&
370 method.IsUserDefined()) {
371 out << "if (true) {\n";
372 out << " return "
373 "std::future::ready(Err(binder::Status::from(binder::StatusCode::UNKNOWN_"
374 "TRANSACTION)));\n";
375 out << "} else {\n";
376 out.Indent();
377 }
378 // Prepare transaction.
379 out << "let _aidl_data = match self.build_parcel_" + method.GetName() + "(" +
380 build_parcel_args + ") {\n";
381 out.Indent();
382 out << "Ok(_aidl_data) => _aidl_data,\n";
383 out << "Err(err) => return std::future::ready(Err(err)),\n";
384 out.Dedent();
385 out << "};\n";
386 // Submit transaction.
387 out << "let _aidl_reply = self.binder.submit_transact(transactions::r#" << method.GetName()
388 << ", _aidl_data, " << transact_flags << ");\n";
389 // Deserialize response.
390 out << "std::future::ready(self.read_response_" + method.GetName() + "(" +
391 read_response_args + "))\n";
392 break;
393 case MethodKind::BOXED_FUTURE:
394 if (method.IsNew() && ShouldForceDowngradeFor(CommunicationSide::WRITE) &&
395 method.IsUserDefined()) {
396 out << "if (true) {\n";
397 out << " return "
398 "Box::pin(std::future::ready(Err(binder::Status::from(binder::StatusCode::UNKNOWN_"
399 "TRANSACTION))));\n";
400 out << "} else {\n";
401 out.Indent();
402 }
403 // Prepare transaction.
404 out << "let _aidl_data = match self.build_parcel_" + method.GetName() + "(" +
405 build_parcel_args + ") {\n";
406 out.Indent();
407 out << "Ok(_aidl_data) => _aidl_data,\n";
408 out << "Err(err) => return Box::pin(std::future::ready(Err(err))),\n";
409 out.Dedent();
410 out << "};\n";
411 // Submit transaction.
412 out << "let binder = self.binder.clone();\n";
413 out << "P::spawn(\n";
414 out.Indent();
415 out << "move || binder.submit_transact(transactions::r#" << method.GetName()
416 << ", _aidl_data, " << transact_flags << "),\n";
417 out << "move |_aidl_reply| async move {\n";
418 out.Indent();
419 // Deserialize response.
420 out << "self.read_response_" + method.GetName() + "(" + read_response_args + ")\n";
421 out.Dedent();
422 out << "}\n";
423 out.Dedent();
424 out << ")\n";
425 break;
426 }
427
428 if (method.IsNew() && ShouldForceDowngradeFor(CommunicationSide::WRITE) &&
429 method.IsUserDefined()) {
430 out.Dedent();
431 out << "}\n";
432 }
433 out.Dedent();
434 out << "}\n";
435 }
436
GenerateServerTransaction(CodeWriter & out,const AidlInterface & interface,const AidlMethod & method,const AidlTypenames & typenames)437 void GenerateServerTransaction(CodeWriter& out, const AidlInterface& interface,
438 const AidlMethod& method, const AidlTypenames& typenames) {
439 out << "transactions::r#" << method.GetName() << " => {\n";
440 out.Indent();
441 if (method.IsUserDefined() && method.IsNew() &&
442 ShouldForceDowngradeFor(CommunicationSide::READ)) {
443 out << "if (true) {\n";
444 out << " Err(binder::StatusCode::UNKNOWN_TRANSACTION)\n";
445 out << "} else {\n";
446 out.Indent();
447 }
448
449 if (interface.EnforceExpression() || method.GetType().EnforceExpression()) {
450 out << "compile_error!(\"Permission checks not support for the Rust backend\");\n";
451 }
452
453 string args;
454 for (const auto& arg : method.GetArguments()) {
455 string arg_name = kArgumentPrefix + arg->GetName();
456 StorageMode arg_mode;
457 if (arg->IsIn()) {
458 arg_mode = StorageMode::VALUE;
459 } else {
460 // We need a value we can call Default::default() on
461 arg_mode = StorageMode::DEFAULT_VALUE;
462 }
463 auto arg_type = RustNameOf(arg->GetType(), typenames, arg_mode, Lifetime::NONE);
464
465 string arg_mut = arg->IsOut() ? "mut " : "";
466 string arg_init = arg->IsIn() ? "_aidl_data.read()?" : "Default::default()";
467 out << "let " << arg_mut << arg_name << ": " << arg_type << " = " << arg_init << ";\n";
468 if (!arg->IsIn() && arg->GetType().IsDynamicArray()) {
469 // _aidl_data.resize_[nullable_]out_vec(&mut _arg_foo)?;
470 auto resize_name = arg->GetType().IsNullable() ? "resize_nullable_out_vec" : "resize_out_vec";
471 out << "_aidl_data." << resize_name << "(&mut " << arg_name << ")?;\n";
472 }
473
474 auto ref_mode = ArgumentReferenceMode(*arg, typenames);
475 if (!args.empty()) {
476 args += ", ";
477 }
478 args += TakeReference(ref_mode, arg_name);
479 }
480 out << "let _aidl_return = _aidl_service.r#" << method.GetName() << "(" << args << ");\n";
481
482 if (!method.IsOneway()) {
483 out << "match &_aidl_return {\n";
484 out.Indent();
485 out << "Ok(_aidl_return) => {\n";
486 out.Indent();
487 out << "_aidl_reply.write(&binder::Status::from(binder::StatusCode::OK))?;\n";
488 if (method.GetType().GetName() != "void") {
489 out << "_aidl_reply.write(_aidl_return)?;\n";
490 }
491
492 // Serialize out arguments
493 for (const AidlArgument* arg : method.GetOutArguments()) {
494 string arg_name = kArgumentPrefix + arg->GetName();
495
496 auto& arg_type = arg->GetType();
497 if (!arg->IsIn() && arg_type.IsArray() && arg_type.GetName() == "ParcelFileDescriptor") {
498 // We represent arrays of ParcelFileDescriptor as
499 // Vec<Option<ParcelFileDescriptor>> when they're out-arguments,
500 // but we need all of them to be initialized to Some; if there's
501 // any None, return UNEXPECTED_NULL (this is what libbinder_ndk does)
502 out << "if " << arg_name << ".iter().any(Option::is_none) { "
503 << "return Err(binder::StatusCode::UNEXPECTED_NULL); }\n";
504 } else if (!arg->IsIn() && TypeNeedsOption(arg_type, typenames)) {
505 // Unwrap out-only arguments that we wrapped in Option<T>
506 out << "let " << arg_name << " = " << arg_name
507 << ".ok_or(binder::StatusCode::UNEXPECTED_NULL)?;\n";
508 }
509
510 out << "_aidl_reply.write(&" << arg_name << ")?;\n";
511 }
512 out.Dedent();
513 out << "}\n";
514 out << "Err(_aidl_status) => _aidl_reply.write(_aidl_status)?\n";
515 out.Dedent();
516 out << "}\n";
517 }
518 out << "Ok(())\n";
519 if (method.IsUserDefined() && method.IsNew() &&
520 ShouldForceDowngradeFor(CommunicationSide::READ)) {
521 out.Dedent();
522 out << "}\n";
523 }
524 out.Dedent();
525 out << "}\n";
526 }
527
GenerateServerItems(CodeWriter & out,const AidlInterface * iface,const AidlTypenames & typenames)528 void GenerateServerItems(CodeWriter& out, const AidlInterface* iface,
529 const AidlTypenames& typenames) {
530 auto trait_name = ClassName(*iface, cpp::ClassNames::INTERFACE);
531 auto server_name = ClassName(*iface, cpp::ClassNames::SERVER);
532
533 // Forward all IFoo functions from Binder to the inner object
534 out << "impl " << trait_name << " for binder::binder_impl::Binder<" << server_name << "> {\n";
535 out.Indent();
536 for (const auto& method : iface->GetMethods()) {
537 string args;
538 for (const std::unique_ptr<AidlArgument>& arg : method->GetArguments()) {
539 if (!args.empty()) {
540 args += ", ";
541 }
542 args += kArgumentPrefix;
543 args += arg->GetName();
544 }
545 out << BuildMethod(*method, typenames) << " { "
546 << "self.0.r#" << method->GetName() << "(" << args << ") }\n";
547 }
548 out.Dedent();
549 out << "}\n";
550
551 out << "fn on_transact("
552 "_aidl_service: &dyn "
553 << trait_name
554 << ", "
555 "_aidl_code: binder::binder_impl::TransactionCode, "
556 "_aidl_data: &binder::binder_impl::BorrowedParcel<'_>, "
557 "_aidl_reply: &mut binder::binder_impl::BorrowedParcel<'_>) -> std::result::Result<(), "
558 "binder::StatusCode> "
559 "{\n";
560 out.Indent();
561 out << "match _aidl_code {\n";
562 out.Indent();
563 for (const auto& method : iface->GetMethods()) {
564 GenerateServerTransaction(out, *iface, *method, typenames);
565 }
566 out << "_ => Err(binder::StatusCode::UNKNOWN_TRANSACTION)\n";
567 out.Dedent();
568 out << "}\n";
569 out.Dedent();
570 out << "}\n";
571 }
572
GenerateDeprecated(CodeWriter & out,const AidlCommentable & type)573 void GenerateDeprecated(CodeWriter& out, const AidlCommentable& type) {
574 if (auto deprecated = FindDeprecated(type.GetComments()); deprecated.has_value()) {
575 if (deprecated->note.empty()) {
576 out << "#[deprecated]\n";
577 } else {
578 out << "#[deprecated = " << QuotedEscape(deprecated->note) << "]\n";
579 }
580 }
581 }
582
583 template <typename TypeWithConstants>
GenerateConstantDeclarations(CodeWriter & out,const TypeWithConstants & type,const AidlTypenames & typenames)584 void GenerateConstantDeclarations(CodeWriter& out, const TypeWithConstants& type,
585 const AidlTypenames& typenames) {
586 for (const auto& constant : type.GetConstantDeclarations()) {
587 const AidlTypeSpecifier& type = constant->GetType();
588 const AidlConstantValue& value = constant->GetValue();
589
590 string const_type;
591 if (type.Signature() == "String") {
592 const_type = "&str";
593 } else if (type.Signature() == "byte" || type.Signature() == "int" ||
594 type.Signature() == "long" || type.Signature() == "float" ||
595 type.Signature() == "double") {
596 const_type = RustNameOf(type, typenames, StorageMode::VALUE, Lifetime::NONE);
597 } else {
598 AIDL_FATAL(value) << "Unrecognized constant type: " << type.Signature();
599 }
600
601 GenerateDeprecated(out, *constant);
602 out << "pub const r#" << constant->GetName() << ": " << const_type << " = "
603 << constant->ValueString(ConstantValueDecoratorRef) << ";\n";
604 }
605 }
606
GenerateRustInterface(CodeWriter * code_writer,const AidlInterface * iface,const AidlTypenames & typenames,const Options & options)607 void GenerateRustInterface(CodeWriter* code_writer, const AidlInterface* iface,
608 const AidlTypenames& typenames, const Options& options) {
609 *code_writer << "#![allow(non_upper_case_globals)]\n";
610 *code_writer << "#![allow(non_snake_case)]\n";
611 // Import IBinderInternal for transact()
612 *code_writer << "#[allow(unused_imports)] use binder::binder_impl::IBinderInternal;\n";
613
614 auto trait_name = ClassName(*iface, cpp::ClassNames::INTERFACE);
615 auto trait_name_async = trait_name + "Async";
616 auto trait_name_async_server = trait_name + "AsyncServer";
617 auto client_name = ClassName(*iface, cpp::ClassNames::CLIENT);
618 auto server_name = ClassName(*iface, cpp::ClassNames::SERVER);
619 *code_writer << "use binder::declare_binder_interface;\n";
620 *code_writer << "declare_binder_interface! {\n";
621 code_writer->Indent();
622 *code_writer << trait_name << "[\"" << iface->GetDescriptor() << "\"] {\n";
623 code_writer->Indent();
624 *code_writer << "native: " << server_name << "(on_transact),\n";
625 *code_writer << "proxy: " << client_name << " {\n";
626 code_writer->Indent();
627 if (options.Version() > 0) {
628 string comma = options.Hash().empty() ? "" : ",";
629 *code_writer << "cached_version: "
630 "std::sync::atomic::AtomicI32 = "
631 "std::sync::atomic::AtomicI32::new(-1)"
632 << comma << "\n";
633 }
634 if (!options.Hash().empty()) {
635 *code_writer << "cached_hash: "
636 "std::sync::Mutex<Option<String>> = "
637 "std::sync::Mutex::new(None)\n";
638 }
639 code_writer->Dedent();
640 *code_writer << "},\n";
641 *code_writer << "async: " << trait_name_async << ",\n";
642 if (iface->IsVintfStability()) {
643 *code_writer << "stability: binder::binder_impl::Stability::Vintf,\n";
644 }
645 code_writer->Dedent();
646 *code_writer << "}\n";
647 code_writer->Dedent();
648 *code_writer << "}\n";
649
650 // Emit the trait.
651 GenerateDeprecated(*code_writer, *iface);
652 *code_writer << "pub trait " << trait_name << ": binder::Interface + Send {\n";
653 code_writer->Indent();
654 *code_writer << "fn get_descriptor() -> &'static str where Self: Sized { \""
655 << iface->GetDescriptor() << "\" }\n";
656
657 for (const auto& method : iface->GetMethods()) {
658 // Generate the method
659 GenerateDeprecated(*code_writer, *method);
660 if (method->IsUserDefined()) {
661 *code_writer << BuildMethod(*method, typenames) << ";\n";
662 } else {
663 // Generate default implementations for meta methods
664 *code_writer << BuildMethod(*method, typenames) << " {\n";
665 code_writer->Indent();
666 if (method->GetName() == kGetInterfaceVersion && options.Version() > 0) {
667 *code_writer << "Ok(VERSION)\n";
668 } else if (method->GetName() == kGetInterfaceHash && !options.Hash().empty()) {
669 *code_writer << "Ok(HASH.into())\n";
670 }
671 code_writer->Dedent();
672 *code_writer << "}\n";
673 }
674 }
675
676 // Emit the default implementation code inside the trait
677 auto default_trait_name = ClassName(*iface, cpp::ClassNames::DEFAULT_IMPL);
678 auto default_ref_name = default_trait_name + "Ref";
679 *code_writer << "fn getDefaultImpl()"
680 << " -> " << default_ref_name << " where Self: Sized {\n";
681 *code_writer << " DEFAULT_IMPL.lock().unwrap().clone()\n";
682 *code_writer << "}\n";
683 *code_writer << "fn setDefaultImpl(d: " << default_ref_name << ")"
684 << " -> " << default_ref_name << " where Self: Sized {\n";
685 *code_writer << " std::mem::replace(&mut *DEFAULT_IMPL.lock().unwrap(), d)\n";
686 *code_writer << "}\n";
687 code_writer->Dedent();
688 *code_writer << "}\n";
689
690 // Emit the async trait.
691 GenerateDeprecated(*code_writer, *iface);
692 *code_writer << "pub trait " << trait_name_async << "<P>: binder::Interface + Send {\n";
693 code_writer->Indent();
694 *code_writer << "fn get_descriptor() -> &'static str where Self: Sized { \""
695 << iface->GetDescriptor() << "\" }\n";
696
697 for (const auto& method : iface->GetMethods()) {
698 // Generate the method
699 GenerateDeprecated(*code_writer, *method);
700
701 MethodKind kind = method->IsOneway() ? MethodKind::READY_FUTURE : MethodKind::BOXED_FUTURE;
702
703 if (method->IsUserDefined()) {
704 *code_writer << BuildMethod(*method, typenames, kind) << ";\n";
705 } else {
706 // Generate default implementations for meta methods
707 *code_writer << BuildMethod(*method, typenames, kind) << " {\n";
708 code_writer->Indent();
709 if (method->GetName() == kGetInterfaceVersion && options.Version() > 0) {
710 *code_writer << "Box::pin(async move { Ok(VERSION) })\n";
711 } else if (method->GetName() == kGetInterfaceHash && !options.Hash().empty()) {
712 *code_writer << "Box::pin(async move { Ok(HASH.into()) })\n";
713 }
714 code_writer->Dedent();
715 *code_writer << "}\n";
716 }
717 }
718 code_writer->Dedent();
719 *code_writer << "}\n";
720
721 // Emit the async server trait.
722 GenerateDeprecated(*code_writer, *iface);
723 *code_writer << "#[::async_trait::async_trait]\n";
724 *code_writer << "pub trait " << trait_name_async_server << ": binder::Interface + Send {\n";
725 code_writer->Indent();
726 *code_writer << "fn get_descriptor() -> &'static str where Self: Sized { \""
727 << iface->GetDescriptor() << "\" }\n";
728
729 for (const auto& method : iface->GetMethods()) {
730 // Generate the method
731 if (method->IsUserDefined()) {
732 GenerateDeprecated(*code_writer, *method);
733 *code_writer << BuildMethod(*method, typenames, MethodKind::ASYNC) << ";\n";
734 }
735 }
736 code_writer->Dedent();
737 *code_writer << "}\n";
738
739 // Emit a new_async_binder method for binding an async server.
740 *code_writer << "impl " << server_name << " {\n";
741 code_writer->Indent();
742 *code_writer << "/// Create a new async binder service.\n";
743 *code_writer << "pub fn new_async_binder<T, R>(inner: T, rt: R, features: "
744 "binder::BinderFeatures) -> binder::Strong<dyn "
745 << trait_name << ">\n";
746 *code_writer << "where\n";
747 code_writer->Indent();
748 *code_writer << "T: " << trait_name_async_server
749 << " + binder::Interface + Send + Sync + 'static,\n";
750 *code_writer << "R: binder::binder_impl::BinderAsyncRuntime + Send + Sync + 'static,\n";
751 code_writer->Dedent();
752 *code_writer << "{\n";
753 code_writer->Indent();
754 // Define a wrapper struct that implements the non-async trait by calling block_on.
755 *code_writer << "struct Wrapper<T, R> {\n";
756 code_writer->Indent();
757 *code_writer << "_inner: T,\n";
758 *code_writer << "_rt: R,\n";
759 code_writer->Dedent();
760 *code_writer << "}\n";
761 *code_writer << "impl<T, R> binder::Interface for Wrapper<T, R> where T: binder::Interface, R: "
762 "Send + Sync + 'static {\n";
763 code_writer->Indent();
764 *code_writer << "fn as_binder(&self) -> binder::SpIBinder { self._inner.as_binder() }\n";
765 *code_writer
766 << "fn dump(&self, _writer: &mut dyn std::io::Write, _args: "
767 "&[&std::ffi::CStr]) -> "
768 "std::result::Result<(), binder::StatusCode> { self._inner.dump(_writer, _args) }\n";
769 code_writer->Dedent();
770 *code_writer << "}\n";
771 *code_writer << "impl<T, R> " << trait_name << " for Wrapper<T, R>\n";
772 *code_writer << "where\n";
773 code_writer->Indent();
774 *code_writer << "T: " << trait_name_async_server << " + Send + Sync + 'static,\n";
775 *code_writer << "R: binder::binder_impl::BinderAsyncRuntime + Send + Sync + 'static,\n";
776 code_writer->Dedent();
777 *code_writer << "{\n";
778 code_writer->Indent();
779 for (const auto& method : iface->GetMethods()) {
780 // Generate the method
781 if (method->IsUserDefined()) {
782 string args = "";
783 for (const std::unique_ptr<AidlArgument>& arg : method->GetArguments()) {
784 if (!args.empty()) {
785 args += ", ";
786 }
787 args += kArgumentPrefix;
788 args += arg->GetName();
789 }
790
791 *code_writer << BuildMethod(*method, typenames) << " {\n";
792 code_writer->Indent();
793 *code_writer << "self._rt.block_on(self._inner.r#" << method->GetName() << "(" << args
794 << "))\n";
795 code_writer->Dedent();
796 *code_writer << "}\n";
797 }
798 }
799 code_writer->Dedent();
800 *code_writer << "}\n";
801
802 *code_writer << "let wrapped = Wrapper { _inner: inner, _rt: rt };\n";
803 *code_writer << "Self::new_binder(wrapped, features)\n";
804
805 code_writer->Dedent();
806 *code_writer << "}\n";
807 code_writer->Dedent();
808 *code_writer << "}\n";
809
810 // Emit the default trait
811 *code_writer << "pub trait " << default_trait_name << ": Send + Sync {\n";
812 code_writer->Indent();
813 for (const auto& method : iface->GetMethods()) {
814 if (!method->IsUserDefined()) {
815 continue;
816 }
817
818 // Generate the default method
819 *code_writer << BuildMethod(*method, typenames) << " {\n";
820 code_writer->Indent();
821 *code_writer << "Err(binder::StatusCode::UNKNOWN_TRANSACTION.into())\n";
822 code_writer->Dedent();
823 *code_writer << "}\n";
824 }
825 code_writer->Dedent();
826 *code_writer << "}\n";
827
828 // Generate the transaction code constants
829 // The constants get their own sub-module to avoid conflicts
830 *code_writer << "pub mod transactions {\n";
831 code_writer->Indent();
832 for (const auto& method : iface->GetMethods()) {
833 // Generate the transaction code constant
834 *code_writer << "pub const r#" << method->GetName()
835 << ": binder::binder_impl::TransactionCode = "
836 "binder::binder_impl::FIRST_CALL_TRANSACTION + " +
837 std::to_string(method->GetId()) + ";\n";
838 }
839 code_writer->Dedent();
840 *code_writer << "}\n";
841
842 // Emit the default implementation code outside the trait
843 *code_writer << "pub type " << default_ref_name << " = Option<std::sync::Arc<dyn "
844 << default_trait_name << ">>;\n";
845 *code_writer << "static DEFAULT_IMPL: std::sync::Mutex<" << default_ref_name
846 << "> = std::sync::Mutex::new(None);\n";
847
848 // Emit the interface constants
849 GenerateConstantDeclarations(*code_writer, *iface, typenames);
850
851 // Emit VERSION and HASH
852 // These need to be top-level item constants instead of associated consts
853 // because the latter are incompatible with trait objects, see
854 // https://doc.rust-lang.org/reference/items/traits.html#object-safety
855 if (options.Version() > 0) {
856 if (options.IsLatestUnfrozenVersion()) {
857 *code_writer << "pub const VERSION: i32 = if true {"
858 << std::to_string(options.PreviousVersion()) << "} else {"
859 << std::to_string(options.Version()) << "};\n";
860 } else {
861 *code_writer << "pub const VERSION: i32 = " << std::to_string(options.Version()) << ";\n";
862 }
863 }
864 if (!options.Hash().empty() || options.IsLatestUnfrozenVersion()) {
865 if (options.IsLatestUnfrozenVersion()) {
866 *code_writer << "pub const HASH: &str = if true {\"" << options.PreviousHash()
867 << "\"} else {\"" << options.Hash() << "\"};\n";
868 } else {
869 *code_writer << "pub const HASH: &str = \"" << options.Hash() << "\";\n";
870 }
871 }
872
873 // Generate the client-side method helpers
874 //
875 // The methods in this block are not marked pub, so they are not accessible from outside the
876 // AIDL generated code.
877 *code_writer << "impl " << client_name << " {\n";
878 code_writer->Indent();
879 for (const auto& method : iface->GetMethods()) {
880 GenerateClientMethodHelpers(*code_writer, *iface, *method, typenames, options, trait_name);
881 }
882 code_writer->Dedent();
883 *code_writer << "}\n";
884
885 // Generate the client-side methods
886 *code_writer << "impl " << trait_name << " for " << client_name << " {\n";
887 code_writer->Indent();
888 for (const auto& method : iface->GetMethods()) {
889 GenerateClientMethod(*code_writer, *iface, *method, typenames, options, MethodKind::NORMAL);
890 }
891 code_writer->Dedent();
892 *code_writer << "}\n";
893
894 // Generate the async client-side methods
895 *code_writer << "impl<P: binder::BinderAsyncPool> " << trait_name_async << "<P> for "
896 << client_name << " {\n";
897 code_writer->Indent();
898 for (const auto& method : iface->GetMethods()) {
899 MethodKind kind = method->IsOneway() ? MethodKind::READY_FUTURE : MethodKind::BOXED_FUTURE;
900 GenerateClientMethod(*code_writer, *iface, *method, typenames, options, kind);
901 }
902 code_writer->Dedent();
903 *code_writer << "}\n";
904
905 // Generate the server-side methods
906 GenerateServerItems(*code_writer, iface, typenames);
907 }
908
RemoveUsed(std::set<std::string> * params,const AidlTypeSpecifier & type)909 void RemoveUsed(std::set<std::string>* params, const AidlTypeSpecifier& type) {
910 if (!type.IsResolved()) {
911 params->erase(type.GetName());
912 }
913 if (type.IsGeneric()) {
914 for (const auto& param : type.GetTypeParameters()) {
915 RemoveUsed(params, *param);
916 }
917 }
918 }
919
FreeParams(const AidlStructuredParcelable * parcel)920 std::set<std::string> FreeParams(const AidlStructuredParcelable* parcel) {
921 if (!parcel->IsGeneric()) {
922 return std::set<std::string>();
923 }
924 auto typeParams = parcel->GetTypeParameters();
925 std::set<std::string> unusedParams(typeParams.begin(), typeParams.end());
926 for (const auto& variable : parcel->GetFields()) {
927 RemoveUsed(&unusedParams, variable->GetType());
928 }
929 return unusedParams;
930 }
931
WriteParams(CodeWriter & out,const AidlParameterizable<std::string> * parcel,std::string extra)932 void WriteParams(CodeWriter& out, const AidlParameterizable<std::string>* parcel,
933 std::string extra) {
934 if (parcel->IsGeneric()) {
935 out << "<";
936 for (const auto& param : parcel->GetTypeParameters()) {
937 out << param << extra << ",";
938 }
939 out << ">";
940 }
941 }
942
WriteParams(CodeWriter & out,const AidlParameterizable<std::string> * parcel)943 void WriteParams(CodeWriter& out, const AidlParameterizable<std::string>* parcel) {
944 WriteParams(out, parcel, "");
945 }
946
GeneratePaddingField(CodeWriter & out,const std::string & field_type,size_t struct_size,size_t & padding_index,const std::string & padding_element)947 static void GeneratePaddingField(CodeWriter& out, const std::string& field_type, size_t struct_size,
948 size_t& padding_index, const std::string& padding_element) {
949 // If current field is i64 or f64, generate padding for previous field. AIDL enums
950 // backed by these types have structs with alignment attributes generated so we only need to
951 // take primitive types that have variable alignment across archs into account here.
952 if (field_type == "i64" || field_type == "f64") {
953 // Align total struct size to 8 bytes since current field should have 8 byte alignment
954 auto padding_size = cpp::AlignTo(struct_size, 8) - struct_size;
955 if (padding_size != 0) {
956 out << "_pad_" << std::to_string(padding_index) << ": [" << padding_element << "; "
957 << std::to_string(padding_size) << "],\n";
958 padding_index += 1;
959 }
960 }
961 }
962
GenerateParcelBody(CodeWriter & out,const AidlStructuredParcelable * parcel,const AidlTypenames & typenames)963 void GenerateParcelBody(CodeWriter& out, const AidlStructuredParcelable* parcel,
964 const AidlTypenames& typenames) {
965 GenerateDeprecated(out, *parcel);
966 auto parcelable_alignment = cpp::AlignmentOfDefinedType(*parcel, typenames);
967 if (parcelable_alignment || parcel->IsFixedSize()) {
968 AIDL_FATAL_IF(!parcel->IsFixedSize(), parcel);
969 AIDL_FATAL_IF(parcelable_alignment == std::nullopt, parcel);
970 // i64/f64 are aligned to 4 bytes on x86 which may underalign the whole struct if it's the
971 // largest field so we need to set the alignment manually as if these types were aligned to 8
972 // bytes.
973 out << "#[repr(C, align(" << std::to_string(*parcelable_alignment) << "))]\n";
974 }
975 out << "pub struct r#" << parcel->GetName();
976 WriteParams(out, parcel);
977 out << " {\n";
978 out.Indent();
979 const auto& fields = parcel->GetFields();
980 // empty structs in C++ are 1 byte so generate an unused field in this case to make the layouts
981 // match
982 if (fields.size() == 0 && parcel->IsFixedSize()) {
983 out << "_unused: u8,\n";
984 } else {
985 size_t padding_index = 0;
986 size_t struct_size = 0;
987 for (const auto& variable : fields) {
988 GenerateDeprecated(out, *variable);
989 const auto& var_type = variable->GetType();
990 auto field_type =
991 RustNameOf(var_type, typenames, StorageMode::PARCELABLE_FIELD, Lifetime::NONE);
992 if (parcel->IsFixedSize()) {
993 GeneratePaddingField(out, field_type, struct_size, padding_index, "u8");
994
995 auto alignment = cpp::AlignmentOf(var_type, typenames);
996 AIDL_FATAL_IF(alignment == std::nullopt, var_type);
997 struct_size = cpp::AlignTo(struct_size, *alignment);
998 auto var_size = cpp::SizeOf(var_type, typenames);
999 AIDL_FATAL_IF(var_size == std::nullopt, var_type);
1000 struct_size += *var_size;
1001 }
1002 out << "pub r#" << variable->GetName() << ": " << field_type << ",\n";
1003 }
1004 for (const auto& unused_param : FreeParams(parcel)) {
1005 out << "_phantom_" << unused_param << ": std::marker::PhantomData<" << unused_param << ">,\n";
1006 }
1007 }
1008 out.Dedent();
1009 out << "}\n";
1010 if (parcel->IsFixedSize()) {
1011 size_t variable_offset = 0;
1012 for (const auto& variable : fields) {
1013 const auto& var_type = variable->GetType();
1014 // Assert the offset of each field within the struct
1015 auto alignment = cpp::AlignmentOf(var_type, typenames);
1016 AIDL_FATAL_IF(alignment == std::nullopt, var_type);
1017 variable_offset = cpp::AlignTo(variable_offset, *alignment);
1018 out << "static_assertions::const_assert_eq!(std::mem::offset_of!(" << parcel->GetName()
1019 << ", r#" << variable->GetName() << "), " << std::to_string(variable_offset) << ");\n";
1020
1021 // Assert the size of each field
1022 auto variable_size = cpp::SizeOf(var_type, typenames);
1023 AIDL_FATAL_IF(variable_size == std::nullopt, var_type);
1024 std::string rust_type =
1025 RustNameOf(var_type, typenames, StorageMode::PARCELABLE_FIELD, Lifetime::NONE);
1026 out << "static_assertions::const_assert_eq!(std::mem::size_of::<" << rust_type << ">(), "
1027 << std::to_string(*variable_size) << ");\n";
1028
1029 variable_offset += *variable_size;
1030 }
1031 // Assert the alignment of the struct
1032 auto parcelable_alignment = cpp::AlignmentOfDefinedType(*parcel, typenames);
1033 AIDL_FATAL_IF(parcelable_alignment == std::nullopt, *parcel);
1034 out << "static_assertions::const_assert_eq!(std::mem::align_of::<" << parcel->GetName()
1035 << ">(), " << std::to_string(*parcelable_alignment) << ");\n";
1036
1037 // Assert the size of the struct
1038 auto parcelable_size = cpp::SizeOfDefinedType(*parcel, typenames);
1039 AIDL_FATAL_IF(parcelable_size == std::nullopt, *parcel);
1040 out << "static_assertions::const_assert_eq!(std::mem::size_of::<" << parcel->GetName()
1041 << ">(), " << std::to_string(*parcelable_size) << ");\n";
1042 }
1043 }
1044
GenerateParcelDefault(CodeWriter & out,const AidlStructuredParcelable * parcel,const AidlTypenames & typenames)1045 void GenerateParcelDefault(CodeWriter& out, const AidlStructuredParcelable* parcel,
1046 const AidlTypenames& typenames) {
1047 out << "impl";
1048 WriteParams(out, parcel, ": Default");
1049 out << " Default for r#" << parcel->GetName();
1050 WriteParams(out, parcel);
1051 out << " {\n";
1052 out.Indent();
1053 out << "fn default() -> Self {\n";
1054 out.Indent();
1055 out << "Self {\n";
1056 out.Indent();
1057 size_t padding_index = 0;
1058 size_t struct_size = 0;
1059 const auto& fields = parcel->GetFields();
1060 if (fields.size() == 0 && parcel->IsFixedSize()) {
1061 out << "_unused: 0,\n";
1062 } else {
1063 for (const auto& variable : fields) {
1064 const auto& var_type = variable->GetType();
1065 // Generate initializer for padding for previous field if current field is i64 or f64
1066 if (parcel->IsFixedSize()) {
1067 auto field_type =
1068 RustNameOf(var_type, typenames, StorageMode::PARCELABLE_FIELD, Lifetime::NONE);
1069 GeneratePaddingField(out, field_type, struct_size, padding_index, "0");
1070
1071 auto alignment = cpp::AlignmentOf(var_type, typenames);
1072 AIDL_FATAL_IF(alignment == std::nullopt, var_type);
1073 struct_size = cpp::AlignTo(struct_size, *alignment);
1074
1075 auto var_size = cpp::SizeOf(var_type, typenames);
1076 AIDL_FATAL_IF(var_size == std::nullopt, var_type);
1077 struct_size += *var_size;
1078 }
1079
1080 out << "r#" << variable->GetName() << ": ";
1081 if (variable->GetDefaultValue()) {
1082 out << variable->ValueString(ConstantValueDecorator);
1083 } else {
1084 // Some types don't implement "Default".
1085 // - ParcelableHolder
1086 // - Arrays
1087 if (variable->GetType().GetName() == "ParcelableHolder") {
1088 out << "binder::ParcelableHolder::new(";
1089 if (parcel->IsVintfStability()) {
1090 out << "binder::binder_impl::Stability::Vintf";
1091 } else {
1092 out << "binder::binder_impl::Stability::Local";
1093 }
1094 out << ")";
1095 } else if (variable->GetType().IsFixedSizeArray() && !variable->GetType().IsNullable()) {
1096 out << ArrayDefaultValue(variable->GetType());
1097 } else {
1098 out << "Default::default()";
1099 }
1100 }
1101 out << ",\n";
1102 }
1103 for (const auto& unused_param : FreeParams(parcel)) {
1104 out << "r#_phantom_" << unused_param << ": Default::default(),\n";
1105 }
1106 }
1107 out.Dedent();
1108 out << "}\n";
1109 out.Dedent();
1110 out << "}\n";
1111 out.Dedent();
1112 out << "}\n";
1113 }
1114
GenerateParcelSerializeBody(CodeWriter & out,const AidlStructuredParcelable * parcel,const AidlTypenames & typenames)1115 void GenerateParcelSerializeBody(CodeWriter& out, const AidlStructuredParcelable* parcel,
1116 const AidlTypenames& typenames) {
1117 out << "parcel.sized_write(|subparcel| {\n";
1118 out.Indent();
1119 for (const auto& variable : parcel->GetFields()) {
1120 if (variable->IsNew() && ShouldForceDowngradeFor(CommunicationSide::WRITE)) {
1121 out << "if (false) {\n";
1122 out.Indent();
1123 }
1124 if (TypeNeedsOption(variable->GetType(), typenames)) {
1125 out << "let __field_ref = self.r#" << variable->GetName()
1126 << ".as_ref().ok_or(binder::StatusCode::UNEXPECTED_NULL)?;\n";
1127 out << "subparcel.write(__field_ref)?;\n";
1128 } else {
1129 out << "subparcel.write(&self.r#" << variable->GetName() << ")?;\n";
1130 }
1131 if (variable->IsNew() && ShouldForceDowngradeFor(CommunicationSide::WRITE)) {
1132 out.Dedent();
1133 out << "}\n";
1134 }
1135 }
1136 out << "Ok(())\n";
1137 out.Dedent();
1138 out << "})\n";
1139 }
1140
GenerateParcelDeserializeBody(CodeWriter & out,const AidlStructuredParcelable * parcel,const AidlTypenames & typenames)1141 void GenerateParcelDeserializeBody(CodeWriter& out, const AidlStructuredParcelable* parcel,
1142 const AidlTypenames& typenames) {
1143 out << "parcel.sized_read(|subparcel| {\n";
1144 out.Indent();
1145
1146 for (const auto& variable : parcel->GetFields()) {
1147 if (variable->IsNew() && ShouldForceDowngradeFor(CommunicationSide::READ)) {
1148 out << "if (false) {\n";
1149 out.Indent();
1150 }
1151 out << "if subparcel.has_more_data() {\n";
1152 out.Indent();
1153 if (TypeNeedsOption(variable->GetType(), typenames)) {
1154 out << "self.r#" << variable->GetName() << " = Some(subparcel.read()?);\n";
1155 } else {
1156 out << "self.r#" << variable->GetName() << " = subparcel.read()?;\n";
1157 }
1158 if (variable->IsNew() && ShouldForceDowngradeFor(CommunicationSide::READ)) {
1159 out.Dedent();
1160 out << "}\n";
1161 }
1162 out.Dedent();
1163 out << "}\n";
1164 }
1165 out << "Ok(())\n";
1166 out.Dedent();
1167 out << "})\n";
1168 }
1169
GenerateParcelBody(CodeWriter & out,const AidlUnionDecl * parcel,const AidlTypenames & typenames)1170 void GenerateParcelBody(CodeWriter& out, const AidlUnionDecl* parcel,
1171 const AidlTypenames& typenames) {
1172 GenerateDeprecated(out, *parcel);
1173 auto alignment = cpp::AlignmentOfDefinedType(*parcel, typenames);
1174 if (parcel->IsFixedSize()) {
1175 AIDL_FATAL_IF(alignment == std::nullopt, *parcel);
1176 auto tag = std::to_string(*alignment * 8);
1177 // This repr may use a tag larger than u8 to make sure the tag padding takes into account that
1178 // the overall alignment is computed as if i64/f64 were always 8-byte aligned
1179 out << "#[repr(C, u" << tag << ", align(" << std::to_string(*alignment) << "))]\n";
1180 }
1181 out << "pub enum r#" << parcel->GetName() << " {\n";
1182 out.Indent();
1183 for (const auto& variable : parcel->GetFields()) {
1184 GenerateDeprecated(out, *variable);
1185 auto field_type =
1186 RustNameOf(variable->GetType(), typenames, StorageMode::PARCELABLE_FIELD, Lifetime::NONE);
1187 out << variable->GetCapitalizedName() << "(" << field_type << "),\n";
1188 }
1189 out.Dedent();
1190 out << "}\n";
1191 if (parcel->IsFixedSize()) {
1192 for (const auto& variable : parcel->GetFields()) {
1193 const auto& var_type = variable->GetType();
1194 std::string rust_type =
1195 RustNameOf(var_type, typenames, StorageMode::PARCELABLE_FIELD, Lifetime::NONE);
1196 // Assert the size of each enum variant's payload
1197 auto variable_size = cpp::SizeOf(var_type, typenames);
1198 AIDL_FATAL_IF(variable_size == std::nullopt, var_type);
1199 out << "static_assertions::const_assert_eq!(std::mem::size_of::<" << rust_type << ">(), "
1200 << std::to_string(*variable_size) << ");\n";
1201 }
1202 // Assert the alignment of the enum
1203 AIDL_FATAL_IF(alignment == std::nullopt, *parcel);
1204 out << "static_assertions::const_assert_eq!(std::mem::align_of::<" << parcel->GetName()
1205 << ">(), " << std::to_string(*alignment) << ");\n";
1206
1207 // Assert the size of the enum, taking into the tag and its padding into account
1208 auto union_size = cpp::SizeOfDefinedType(*parcel, typenames);
1209 AIDL_FATAL_IF(union_size == std::nullopt, *parcel);
1210 out << "static_assertions::const_assert_eq!(std::mem::size_of::<" << parcel->GetName()
1211 << ">(), " << std::to_string(*union_size) << ");\n";
1212 }
1213 }
1214
GenerateParcelDefault(CodeWriter & out,const AidlUnionDecl * parcel,const AidlTypenames & typenames)1215 void GenerateParcelDefault(CodeWriter& out, const AidlUnionDecl* parcel,
1216 const AidlTypenames& typenames __attribute__((unused))) {
1217 out << "impl";
1218 WriteParams(out, parcel, ": Default");
1219 out << " Default for r#" << parcel->GetName();
1220 WriteParams(out, parcel);
1221 out << " {\n";
1222 out.Indent();
1223 out << "fn default() -> Self {\n";
1224 out.Indent();
1225
1226 AIDL_FATAL_IF(parcel->GetFields().empty(), *parcel)
1227 << "Union '" << parcel->GetName() << "' is empty.";
1228 const auto& first_field = parcel->GetFields()[0];
1229 const auto& first_value = first_field->ValueString(ConstantValueDecorator);
1230
1231 out << "Self::";
1232 if (first_field->GetDefaultValue()) {
1233 out << first_field->GetCapitalizedName() << "(" << first_value << ")\n";
1234 } else {
1235 out << first_field->GetCapitalizedName() << "(Default::default())\n";
1236 }
1237
1238 out.Dedent();
1239 out << "}\n";
1240 out.Dedent();
1241 out << "}\n";
1242 }
1243
GenerateParcelSerializeBody(CodeWriter & out,const AidlUnionDecl * parcel,const AidlTypenames & typenames)1244 void GenerateParcelSerializeBody(CodeWriter& out, const AidlUnionDecl* parcel,
1245 const AidlTypenames& typenames) {
1246 out << "match self {\n";
1247 out.Indent();
1248 int tag = 0;
1249 for (const auto& variable : parcel->GetFields()) {
1250 out << "Self::" << variable->GetCapitalizedName() << "(v) => {\n";
1251 out.Indent();
1252 if (variable->IsNew() && ShouldForceDowngradeFor(CommunicationSide::WRITE)) {
1253 out << "if (true) {\n";
1254 out << " Err(binder::StatusCode::BAD_VALUE)\n";
1255 out << "} else {\n";
1256 out.Indent();
1257 }
1258 out << "parcel.write(&" << std::to_string(tag++) << "i32)?;\n";
1259 if (TypeNeedsOption(variable->GetType(), typenames)) {
1260 out << "let __field_ref = v.as_ref().ok_or(binder::StatusCode::UNEXPECTED_NULL)?;\n";
1261 out << "parcel.write(__field_ref)\n";
1262 } else {
1263 out << "parcel.write(v)\n";
1264 }
1265 if (variable->IsNew() && ShouldForceDowngradeFor(CommunicationSide::WRITE)) {
1266 out.Dedent();
1267 out << "}\n";
1268 }
1269 out.Dedent();
1270 out << "}\n";
1271 }
1272 out.Dedent();
1273 out << "}\n";
1274 }
1275
GenerateParcelDeserializeBody(CodeWriter & out,const AidlUnionDecl * parcel,const AidlTypenames & typenames)1276 void GenerateParcelDeserializeBody(CodeWriter& out, const AidlUnionDecl* parcel,
1277 const AidlTypenames& typenames) {
1278 out << "let tag: i32 = parcel.read()?;\n";
1279 out << "match tag {\n";
1280 out.Indent();
1281 int tag = 0;
1282 for (const auto& variable : parcel->GetFields()) {
1283 auto field_type =
1284 RustNameOf(variable->GetType(), typenames, StorageMode::PARCELABLE_FIELD, Lifetime::NONE);
1285
1286 out << std::to_string(tag++) << " => {\n";
1287 out.Indent();
1288 if (variable->IsNew() && ShouldForceDowngradeFor(CommunicationSide::READ)) {
1289 out << "if (true) {\n";
1290 out << " Err(binder::StatusCode::BAD_VALUE)\n";
1291 out << "} else {\n";
1292 out.Indent();
1293 }
1294 out << "let value: " << field_type << " = ";
1295 if (TypeNeedsOption(variable->GetType(), typenames)) {
1296 out << "Some(parcel.read()?);\n";
1297 } else {
1298 out << "parcel.read()?;\n";
1299 }
1300 out << "*self = Self::" << variable->GetCapitalizedName() << "(value);\n";
1301 out << "Ok(())\n";
1302 if (variable->IsNew() && ShouldForceDowngradeFor(CommunicationSide::READ)) {
1303 out.Dedent();
1304 out << "}\n";
1305 }
1306 out.Dedent();
1307 out << "}\n";
1308 }
1309 out << "_ => {\n";
1310 out << " Err(binder::StatusCode::BAD_VALUE)\n";
1311 out << "}\n";
1312 out.Dedent();
1313 out << "}\n";
1314 }
1315
1316 template <typename ParcelableType>
GenerateParcelableTrait(CodeWriter & out,const ParcelableType * parcel,const AidlTypenames & typenames)1317 void GenerateParcelableTrait(CodeWriter& out, const ParcelableType* parcel,
1318 const AidlTypenames& typenames) {
1319 out << "impl";
1320 WriteParams(out, parcel);
1321 out << " binder::Parcelable for r#" << parcel->GetName();
1322 WriteParams(out, parcel);
1323 out << " {\n";
1324 out.Indent();
1325
1326 out << "fn write_to_parcel(&self, "
1327 "parcel: &mut binder::binder_impl::BorrowedParcel) -> std::result::Result<(), "
1328 "binder::StatusCode> "
1329 "{\n";
1330 out.Indent();
1331 GenerateParcelSerializeBody(out, parcel, typenames);
1332 out.Dedent();
1333 out << "}\n";
1334
1335 out << "fn read_from_parcel(&mut self, "
1336 "parcel: &binder::binder_impl::BorrowedParcel) -> std::result::Result<(), "
1337 "binder::StatusCode> {\n";
1338 out.Indent();
1339 GenerateParcelDeserializeBody(out, parcel, typenames);
1340 out.Dedent();
1341 out << "}\n";
1342
1343 out.Dedent();
1344 out << "}\n";
1345
1346 // Emit the outer (de)serialization traits
1347 out << "binder::impl_serialize_for_parcelable!(r#" << parcel->GetName();
1348 WriteParams(out, parcel);
1349 out << ");\n";
1350 out << "binder::impl_deserialize_for_parcelable!(r#" << parcel->GetName();
1351 WriteParams(out, parcel);
1352 out << ");\n";
1353 }
1354
1355 template <typename ParcelableType>
GenerateMetadataTrait(CodeWriter & out,const ParcelableType * parcel)1356 void GenerateMetadataTrait(CodeWriter& out, const ParcelableType* parcel) {
1357 out << "impl";
1358 WriteParams(out, parcel);
1359 out << " binder::binder_impl::ParcelableMetadata for r#" << parcel->GetName();
1360 WriteParams(out, parcel);
1361 out << " {\n";
1362 out.Indent();
1363
1364 out << "fn get_descriptor() -> &'static str { \"" << parcel->GetCanonicalName() << "\" }\n";
1365
1366 if (parcel->IsVintfStability()) {
1367 out << "fn get_stability(&self) -> binder::binder_impl::Stability { "
1368 "binder::binder_impl::Stability::Vintf }\n";
1369 }
1370
1371 out.Dedent();
1372 out << "}\n";
1373 }
1374
1375 template <typename ParcelableType>
GenerateRustParcel(CodeWriter * code_writer,const ParcelableType * parcel,const AidlTypenames & typenames)1376 void GenerateRustParcel(CodeWriter* code_writer, const ParcelableType* parcel,
1377 const AidlTypenames& typenames) {
1378 vector<string> derives = parcel->RustDerive();
1379
1380 // Debug is always derived because all Rust AIDL types implement it
1381 // ParcelFileDescriptor doesn't support any of the others because
1382 // it's a newtype over std::fs::File which only implements Debug
1383 derives.insert(derives.begin(), "Debug");
1384
1385 *code_writer << "#[derive(" << Join(derives, ", ") << ")]\n";
1386 GenerateParcelBody(*code_writer, parcel, typenames);
1387 GenerateConstantDeclarations(*code_writer, *parcel, typenames);
1388 GenerateParcelDefault(*code_writer, parcel, typenames);
1389 GenerateParcelableTrait(*code_writer, parcel, typenames);
1390 GenerateMetadataTrait(*code_writer, parcel);
1391 }
1392
GenerateRustEnumDeclaration(CodeWriter * code_writer,const AidlEnumDeclaration * enum_decl,const AidlTypenames & typenames)1393 void GenerateRustEnumDeclaration(CodeWriter* code_writer, const AidlEnumDeclaration* enum_decl,
1394 const AidlTypenames& typenames) {
1395 const auto& aidl_backing_type = enum_decl->GetBackingType();
1396 auto backing_type = RustNameOf(aidl_backing_type, typenames, StorageMode::VALUE, Lifetime::NONE);
1397
1398 *code_writer << "#![allow(non_upper_case_globals)]\n";
1399 *code_writer << "use binder::declare_binder_enum;\n";
1400 *code_writer << "declare_binder_enum! {\n";
1401 code_writer->Indent();
1402
1403 GenerateDeprecated(*code_writer, *enum_decl);
1404 auto alignment = cpp::AlignmentOf(aidl_backing_type, typenames);
1405 AIDL_FATAL_IF(alignment == std::nullopt, *enum_decl);
1406 // u64 is aligned to 4 bytes on x86 which may underalign the whole struct if it's the backing type
1407 // so we need to set the alignment manually as if u64 were aligned to 8 bytes.
1408 *code_writer << "#[repr(C, align(" << std::to_string(*alignment) << "))]\n";
1409 *code_writer << "r#" << enum_decl->GetName() << " : [" << backing_type << "; "
1410 << std::to_string(enum_decl->GetEnumerators().size()) << "] {\n";
1411 code_writer->Indent();
1412 for (const auto& enumerator : enum_decl->GetEnumerators()) {
1413 auto value = enumerator->GetValue()->ValueString(aidl_backing_type, ConstantValueDecorator);
1414 GenerateDeprecated(*code_writer, *enumerator);
1415 *code_writer << "r#" << enumerator->GetName() << " = " << value << ",\n";
1416 }
1417 code_writer->Dedent();
1418 *code_writer << "}\n";
1419
1420 code_writer->Dedent();
1421 *code_writer << "}\n";
1422 }
1423
GenerateClass(CodeWriter * code_writer,const AidlDefinedType & defined_type,const AidlTypenames & types,const Options & options)1424 void GenerateClass(CodeWriter* code_writer, const AidlDefinedType& defined_type,
1425 const AidlTypenames& types, const Options& options) {
1426 if (const AidlStructuredParcelable* parcelable = defined_type.AsStructuredParcelable();
1427 parcelable != nullptr) {
1428 GenerateRustParcel(code_writer, parcelable, types);
1429 } else if (const AidlEnumDeclaration* enum_decl = defined_type.AsEnumDeclaration();
1430 enum_decl != nullptr) {
1431 GenerateRustEnumDeclaration(code_writer, enum_decl, types);
1432 } else if (const AidlInterface* interface = defined_type.AsInterface(); interface != nullptr) {
1433 GenerateRustInterface(code_writer, interface, types, options);
1434 } else if (const AidlUnionDecl* union_decl = defined_type.AsUnionDeclaration();
1435 union_decl != nullptr) {
1436 GenerateRustParcel(code_writer, union_decl, types);
1437 } else {
1438 AIDL_FATAL(defined_type) << "Unrecognized type sent for Rust generation.";
1439 }
1440
1441 for (const auto& nested : defined_type.GetNestedTypes()) {
1442 (*code_writer) << "pub mod r#" << nested->GetName() << " {\n";
1443 code_writer->Indent();
1444 GenerateClass(code_writer, *nested, types, options);
1445 code_writer->Dedent();
1446 (*code_writer) << "}\n";
1447 }
1448 }
1449
GenerateRust(const string & filename,const Options & options,const AidlTypenames & types,const AidlDefinedType & defined_type,const IoDelegate & io_delegate)1450 void GenerateRust(const string& filename, const Options& options, const AidlTypenames& types,
1451 const AidlDefinedType& defined_type, const IoDelegate& io_delegate) {
1452 CodeWriterPtr code_writer = io_delegate.GetCodeWriter(filename);
1453
1454 GenerateAutoGenHeader(*code_writer, options);
1455
1456 // Forbid the use of unsafe in auto-generated code.
1457 // Unsafe code should only be allowed in libbinder_rs.
1458 *code_writer << "#![forbid(unsafe_code)]\n";
1459 // Disable rustfmt on auto-generated files, including the golden outputs
1460 *code_writer << "#![cfg_attr(rustfmt, rustfmt_skip)]\n";
1461 GenerateClass(code_writer.get(), defined_type, types, options);
1462 GenerateMangledAliases(*code_writer, defined_type);
1463
1464 AIDL_FATAL_IF(!code_writer->Close(), defined_type) << "I/O Error!";
1465 }
1466
1467 } // namespace rust
1468 } // namespace aidl
1469 } // namespace android
1470