• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2023 Google LLC.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 #include "google/protobuf/compiler/rust/oneof.h"
9 
10 #include <string>
11 
12 #include "absl/log/absl_check.h"
13 #include "absl/log/absl_log.h"
14 #include "absl/strings/str_cat.h"
15 #include "absl/strings/string_view.h"
16 #include "google/protobuf/compiler/cpp/helpers.h"
17 #include "google/protobuf/compiler/rust/accessors/accessor_case.h"
18 #include "google/protobuf/compiler/rust/context.h"
19 #include "google/protobuf/compiler/rust/naming.h"
20 #include "google/protobuf/compiler/rust/rust_field_type.h"
21 #include "google/protobuf/compiler/rust/upb_helpers.h"
22 #include "google/protobuf/descriptor.h"
23 
24 namespace google {
25 namespace protobuf {
26 namespace compiler {
27 namespace rust {
28 
29 // For each oneof we emit two Rust enums with corresponding accessors:
30 // -  An enum acting as a tagged union in which each case holds a View<> of
31 //    the corresponding field. Named as the oneof name in CamelCase.
32 // -  A simple 'which oneof field is set' enum which directly maps to the
33 //    underlying enum used for the 'case' accessor in C++ or upb. Named as the
34 //    oneof name in CamelCase with "Case" appended.
35 //
36 // Example:
37 // For this oneof:
38 // message SomeMsg {
39 //   oneof some_oneof {
40 //     int32 field_a = 7;
41 //     SomeMsg field_b = 9;
42 //   }
43 // }
44 //
45 // This will emit as the exposed API:
46 // pub mod some_msg {
47 //   pub enum SomeOneof<'msg> {
48 //     FieldA(i32) = 7,
49 //     FieldB(View<'msg, SomeMsg>) = 9,
50 //     not_set(std::marker::PhantomData<&'msg ()>) = 0
51 //   }
52 //
53 //   #[repr(C)]
54 //   pub enum SomeOneofCase {
55 //     FieldA = 7,
56 //     FieldB = 9,
57 //     not_set = 0
58 //   }
59 // }
60 // impl SomeMsg {
61 //   pub fn some_oneof(&self) -> SomeOneof {...}
62 //   pub fn some_oneof_case(&self) -> SomeOneofCase {...}
63 // }
64 // impl SomeMsgMut {
65 //   pub fn some_oneof(&self) -> SomeOneof {...}
66 //   pub fn some_oneof_case(&self) -> SomeOneofCase {...}
67 // }
68 // impl SomeMsgView {
69 //   pub fn some_oneof(self) -> SomeOneof {...}
70 //   pub fn some_oneof_case(self) -> SomeOneofCase {...}
71 // }
72 //
73 // The "Case" enum's discriminants are the corresponding field numbers, which
74 // lets it be used directly across the FFI: it exactly matches the case enum
75 // that both cpp and upb generate.
76 //
77 // #[repr(C)] pub enum SomeOneofCase {
78 //   FieldA = 7,
79 //   FieldB = 9,
80 //   not_set = 0
81 // }
82 
83 namespace {
84 // A user-friendly rust type for a view of this field with lifetime 'msg.
RsTypeNameView(Context & ctx,const FieldDescriptor & field)85 std::string RsTypeNameView(Context& ctx, const FieldDescriptor& field) {
86   if (field.options().has_ctype()) {
87     return "";  // TODO: b/308792377 - ctype fields not supported yet.
88   }
89   switch (GetRustFieldType(field.type())) {
90     case RustFieldType::INT32:
91     case RustFieldType::INT64:
92     case RustFieldType::UINT32:
93     case RustFieldType::UINT64:
94     case RustFieldType::FLOAT:
95     case RustFieldType::DOUBLE:
96     case RustFieldType::BOOL:
97       return RsTypePath(ctx, field);
98     case RustFieldType::BYTES:
99       return "&'msg [u8]";
100     case RustFieldType::STRING:
101       return "&'msg ::__pb::ProtoStr";
102     case RustFieldType::MESSAGE:
103       return absl::StrCat("::__pb::View<'msg, ", RsTypePath(ctx, field), ">");
104     case RustFieldType::ENUM:
105       return absl::StrCat("::__pb::View<'msg, ", RsTypePath(ctx, field), ">");
106   }
107 
108   ABSL_LOG(FATAL) << "Unexpected field type: " << field.type_name();
109   return "";
110 }
111 
112 }  // namespace
113 
GenerateOneofDefinition(Context & ctx,const OneofDescriptor & oneof)114 void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) {
115   ctx.Emit(
116       {
117           {"view_enum_name", OneofViewEnumRsName(oneof)},
118           {"view_fields",
119            [&] {
120              for (int i = 0; i < oneof.field_count(); ++i) {
121                auto& field = *oneof.field(i);
122                std::string rs_type = RsTypeNameView(ctx, field);
123                if (rs_type.empty()) {
124                  continue;
125                }
126                ctx.Emit({{"name", OneofCaseRsName(field)},
127                          {"type", rs_type},
128                          {"number", std::to_string(field.number())}},
129                         R"rs($name$($type$) = $number$,
130                 )rs");
131              }
132            }},
133       },
134       // TODO: Revisit if isize is the optimal repr for this enum.
135       // Note: This enum deliberately has a 'msg lifetime associated with it
136       // even if all fields were scalars; we could conditionally exclude the
137       // lifetime under that case, but it would mean changing the .proto file
138       // to add an additional string or message-typed field to the oneof would
139       // be a more breaking change than it needs to be.
140       R"rs(
141       #[non_exhaustive]
142       #[derive(Debug, Clone, Copy)]
143       #[allow(dead_code)]
144       #[repr(isize)]
145       pub enum $view_enum_name$<'msg> {
146         $view_fields$
147 
148         #[allow(non_camel_case_types)]
149         not_set(std::marker::PhantomData<&'msg ()>) = 0
150       }
151       )rs");
152 
153   // Note: This enum is used as the Thunk return type for getting which case is
154   // used: it exactly matches the generate case enum that both cpp and upb use.
155   ctx.Emit({{"case_enum_name", OneofCaseEnumRsName(oneof)},
156             {"cases",
157              [&] {
158                for (int i = 0; i < oneof.field_count(); ++i) {
159                  auto& field = *oneof.field(i);
160                  ctx.Emit({{"name", OneofCaseRsName(field)},
161                            {"number", std::to_string(field.number())}},
162                           R"rs($name$ = $number$,
163                           )rs");
164                }
165              }},
166             {"try_from_cases",
167              [&] {
168                for (int i = 0; i < oneof.field_count(); ++i) {
169                  auto& field = *oneof.field(i);
170                  ctx.Emit({{"name", OneofCaseRsName(field)},
171                            {"number", std::to_string(field.number())}},
172                           R"rs($number$ => Some($case_enum_name$::$name$),
173                           )rs");
174                }
175              }}},
176            R"rs(
177       #[repr(C)]
178       #[derive(Debug, Copy, Clone, PartialEq, Eq)]
179       #[allow(dead_code)]
180       pub enum $case_enum_name$ {
181         $cases$
182 
183         #[allow(non_camel_case_types)]
184         not_set = 0
185       }
186 
187       impl $case_enum_name$ {
188         //~ This try_from is not a TryFrom impl so that it isn't
189         //~ committed to as part of our public api.
190         #[allow(dead_code)]
191         pub(crate) fn try_from(v: u32) -> $Option$<$case_enum_name$> {
192           match v {
193             0 => Some($case_enum_name$::not_set),
194             $try_from_cases$
195             _ => None
196           }
197         }
198       }
199 
200       )rs");
201 }
202 
GenerateOneofAccessors(Context & ctx,const OneofDescriptor & oneof,AccessorCase accessor_case)203 void GenerateOneofAccessors(Context& ctx, const OneofDescriptor& oneof,
204                             AccessorCase accessor_case) {
205   ctx.Emit(
206       {{"oneof_name", RsSafeName(oneof.name())},
207        {"view_lifetime", ViewLifetime(accessor_case)},
208        {"self", ViewReceiver(accessor_case)},
209        {"oneof_enum_module",
210         absl::StrCat("crate::", RustModuleForContainingType(
211                                     ctx, oneof.containing_type()))},
212        {"view_enum_name", OneofViewEnumRsName(oneof)},
213        {"case_enum_name", OneofCaseEnumRsName(oneof)},
214        {"view_cases",
215         [&] {
216           for (int i = 0; i < oneof.field_count(); ++i) {
217             auto& field = *oneof.field(i);
218             std::string rs_type = RsTypeNameView(ctx, field);
219             if (rs_type.empty()) {
220               continue;
221             }
222             std::string field_name = FieldNameWithCollisionAvoidance(field);
223             ctx.Emit(
224                 {
225                     {"case", OneofCaseRsName(field)},
226                     {"rs_getter", RsSafeName(field_name)},
227                     {"type", rs_type},
228                 },
229                 R"rs(
230                 $oneof_enum_module$$case_enum_name$::$case$ =>
231                     $oneof_enum_module$$view_enum_name$::$case$(self.$rs_getter$()),
232                 )rs");
233           }
234         }},
235        {"oneof_case_body",
236         [&] {
237           if (ctx.is_cpp()) {
238             ctx.Emit({{"case_thunk", ThunkName(ctx, oneof, "case")}},
239                      "unsafe { $case_thunk$(self.raw_msg()) }");
240           } else {
241             ctx.Emit(
242                 // The field index for an arbitrary field that in the oneof.
243                 {{"upb_mt_field_index",
244                   UpbMiniTableFieldIndex(*oneof.field(0))}},
245                 R"rs(
246                 let field_num = unsafe {
247                   let f = $pbr$::upb_MiniTable_GetFieldByIndex(
248                       <Self as $pbr$::AssociatedMiniTable>::mini_table(),
249                       $upb_mt_field_index$);
250                   $pbr$::upb_Message_WhichOneofFieldNumber(
251                         self.raw_msg(), f)
252                 };
253                 unsafe {
254                   $oneof_enum_module$$case_enum_name$::try_from(field_num).unwrap_unchecked()
255                 }
256               )rs");
257           }
258         }}},
259       R"rs(
260         pub fn $oneof_name$($self$) -> $oneof_enum_module$$view_enum_name$<$view_lifetime$> {
261           match $self$.$oneof_name$_case() {
262             $view_cases$
263             _ => $oneof_enum_module$$view_enum_name$::not_set(std::marker::PhantomData)
264           }
265         }
266 
267         pub fn $oneof_name$_case($self$) -> $oneof_enum_module$$case_enum_name$ {
268           $oneof_case_body$
269         }
270       )rs");
271 }
272 
GenerateOneofExternC(Context & ctx,const OneofDescriptor & oneof)273 void GenerateOneofExternC(Context& ctx, const OneofDescriptor& oneof) {
274   ABSL_CHECK(ctx.is_cpp());
275 
276   ctx.Emit(
277       {
278           {"oneof_enum_module",
279            absl::StrCat("crate::", RustModuleForContainingType(
280                                        ctx, oneof.containing_type()))},
281           {"case_enum_rs_name", OneofCaseEnumRsName(oneof)},
282           {"case_thunk", ThunkName(ctx, oneof, "case")},
283       },
284       R"rs(
285         fn $case_thunk$(raw_msg: $pbr$::RawMessage) -> $oneof_enum_module$$case_enum_rs_name$;
286       )rs");
287 }
288 
GenerateOneofThunkCc(Context & ctx,const OneofDescriptor & oneof)289 void GenerateOneofThunkCc(Context& ctx, const OneofDescriptor& oneof) {
290   ABSL_CHECK(ctx.is_cpp());
291 
292   ctx.Emit(
293       {
294           {"oneof_name", oneof.name()},
295           {"case_enum_name", OneofCaseEnumRsName(oneof)},
296           {"case_thunk", ThunkName(ctx, oneof, "case")},
297           {"QualifiedMsg", cpp::QualifiedClassName(oneof.containing_type())},
298       },
299       R"cc(
300         $QualifiedMsg$::$case_enum_name$ $case_thunk$($QualifiedMsg$* msg) {
301           return msg->$oneof_name$_case();
302         }
303       )cc");
304 }
305 
306 }  // namespace rust
307 }  // namespace compiler
308 }  // namespace protobuf
309 }  // namespace google
310