1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2023 Google LLC. All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7
8 #include "google/protobuf/compiler/rust/oneof.h"
9
10 #include <string>
11
12 #include "absl/log/absl_check.h"
13 #include "absl/log/absl_log.h"
14 #include "absl/strings/str_cat.h"
15 #include "absl/strings/string_view.h"
16 #include "google/protobuf/compiler/cpp/helpers.h"
17 #include "google/protobuf/compiler/rust/accessors/accessor_case.h"
18 #include "google/protobuf/compiler/rust/context.h"
19 #include "google/protobuf/compiler/rust/naming.h"
20 #include "google/protobuf/compiler/rust/rust_field_type.h"
21 #include "google/protobuf/compiler/rust/upb_helpers.h"
22 #include "google/protobuf/descriptor.h"
23
24 namespace google {
25 namespace protobuf {
26 namespace compiler {
27 namespace rust {
28
// For each oneof we emit two Rust enums with corresponding accessors:
// - An enum acting as a tagged union in which each case holds a View<> of the
//   corresponding field. Named as the oneof name in CamelCase.
// - A simple 'which oneof field is set' enum which directly maps to the
//   underlying enum used for the 'cases' accessor in C++ or upb. Named as the
//   oneof name in CamelCase with "Case" appended.
35 //
36 // Example:
37 // For this oneof:
38 // message SomeMsg {
39 // oneof some_oneof {
40 // int32 field_a = 7;
41 // SomeMsg field_b = 9;
42 // }
43 // }
44 //
45 // This will emit as the exposed API:
46 // pub mod some_msg {
47 // pub enum SomeOneof<'msg> {
48 // FieldA(i32) = 7,
49 // FieldB(View<'msg, SomeMsg>) = 9,
50 // not_set(std::marker::PhantomData<&'msg ()>) = 0
51 // }
52 //
53 // #[repr(C)]
54 // pub enum SomeOneofCase {
55 // FieldA = 7,
56 // FieldB = 9,
57 // not_set = 0
58 // }
59 // }
60 // impl SomeMsg {
61 // pub fn some_oneof(&self) -> SomeOneof {...}
62 // pub fn some_oneof_case(&self) -> SomeOneofCase {...}
63 // }
64 // impl SomeMsgMut {
65 // pub fn some_oneof(&self) -> SomeOneof {...}
66 // pub fn some_oneof_case(&self) -> SomeOneofCase {...}
67 // }
68 // impl SomeMsgView {
69 // pub fn some_oneof(self) -> SomeOneof {...}
70 // pub fn some_oneof_case(self) -> SomeOneofCase {...}
71 // }
72 //
// The "Case" enum above, whose discriminants are the fields' numbers, is also
// used across the FFI: it exactly matches the case enum that both cpp and upb
// generate.
//
// #[repr(C)] pub enum SomeOneofCase {
//   FieldA = 7,
//   FieldB = 9,
//   not_set = 0
// }
82
83 namespace {
84 // A user-friendly rust type for a view of this field with lifetime 'msg.
RsTypeNameView(Context & ctx,const FieldDescriptor & field)85 std::string RsTypeNameView(Context& ctx, const FieldDescriptor& field) {
86 if (field.options().has_ctype()) {
87 return ""; // TODO: b/308792377 - ctype fields not supported yet.
88 }
89 switch (GetRustFieldType(field.type())) {
90 case RustFieldType::INT32:
91 case RustFieldType::INT64:
92 case RustFieldType::UINT32:
93 case RustFieldType::UINT64:
94 case RustFieldType::FLOAT:
95 case RustFieldType::DOUBLE:
96 case RustFieldType::BOOL:
97 return RsTypePath(ctx, field);
98 case RustFieldType::BYTES:
99 return "&'msg [u8]";
100 case RustFieldType::STRING:
101 return "&'msg ::__pb::ProtoStr";
102 case RustFieldType::MESSAGE:
103 return absl::StrCat("::__pb::View<'msg, ", RsTypePath(ctx, field), ">");
104 case RustFieldType::ENUM:
105 return absl::StrCat("::__pb::View<'msg, ", RsTypePath(ctx, field), ">");
106 }
107
108 ABSL_LOG(FATAL) << "Unexpected field type: " << field.type_name();
109 return "";
110 }
111
112 } // namespace
113
GenerateOneofDefinition(Context & ctx,const OneofDescriptor & oneof)114 void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) {
115 ctx.Emit(
116 {
117 {"view_enum_name", OneofViewEnumRsName(oneof)},
118 {"view_fields",
119 [&] {
120 for (int i = 0; i < oneof.field_count(); ++i) {
121 auto& field = *oneof.field(i);
122 std::string rs_type = RsTypeNameView(ctx, field);
123 if (rs_type.empty()) {
124 continue;
125 }
126 ctx.Emit({{"name", OneofCaseRsName(field)},
127 {"type", rs_type},
128 {"number", std::to_string(field.number())}},
129 R"rs($name$($type$) = $number$,
130 )rs");
131 }
132 }},
133 },
134 // TODO: Revisit if isize is the optimal repr for this enum.
135 // Note: This enum deliberately has a 'msg lifetime associated with it
136 // even if all fields were scalars; we could conditionally exclude the
137 // lifetime under that case, but it would mean changing the .proto file
138 // to add an additional string or message-typed field to the oneof would
139 // be a more breaking change than it needs to be.
140 R"rs(
141 #[non_exhaustive]
142 #[derive(Debug, Clone, Copy)]
143 #[allow(dead_code)]
144 #[repr(isize)]
145 pub enum $view_enum_name$<'msg> {
146 $view_fields$
147
148 #[allow(non_camel_case_types)]
149 not_set(std::marker::PhantomData<&'msg ()>) = 0
150 }
151 )rs");
152
153 // Note: This enum is used as the Thunk return type for getting which case is
154 // used: it exactly matches the generate case enum that both cpp and upb use.
155 ctx.Emit({{"case_enum_name", OneofCaseEnumRsName(oneof)},
156 {"cases",
157 [&] {
158 for (int i = 0; i < oneof.field_count(); ++i) {
159 auto& field = *oneof.field(i);
160 ctx.Emit({{"name", OneofCaseRsName(field)},
161 {"number", std::to_string(field.number())}},
162 R"rs($name$ = $number$,
163 )rs");
164 }
165 }},
166 {"try_from_cases",
167 [&] {
168 for (int i = 0; i < oneof.field_count(); ++i) {
169 auto& field = *oneof.field(i);
170 ctx.Emit({{"name", OneofCaseRsName(field)},
171 {"number", std::to_string(field.number())}},
172 R"rs($number$ => Some($case_enum_name$::$name$),
173 )rs");
174 }
175 }}},
176 R"rs(
177 #[repr(C)]
178 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
179 #[allow(dead_code)]
180 pub enum $case_enum_name$ {
181 $cases$
182
183 #[allow(non_camel_case_types)]
184 not_set = 0
185 }
186
187 impl $case_enum_name$ {
188 //~ This try_from is not a TryFrom impl so that it isn't
189 //~ committed to as part of our public api.
190 #[allow(dead_code)]
191 pub(crate) fn try_from(v: u32) -> $Option$<$case_enum_name$> {
192 match v {
193 0 => Some($case_enum_name$::not_set),
194 $try_from_cases$
195 _ => None
196 }
197 }
198 }
199
200 )rs");
201 }
202
GenerateOneofAccessors(Context & ctx,const OneofDescriptor & oneof,AccessorCase accessor_case)203 void GenerateOneofAccessors(Context& ctx, const OneofDescriptor& oneof,
204 AccessorCase accessor_case) {
205 ctx.Emit(
206 {{"oneof_name", RsSafeName(oneof.name())},
207 {"view_lifetime", ViewLifetime(accessor_case)},
208 {"self", ViewReceiver(accessor_case)},
209 {"oneof_enum_module",
210 absl::StrCat("crate::", RustModuleForContainingType(
211 ctx, oneof.containing_type()))},
212 {"view_enum_name", OneofViewEnumRsName(oneof)},
213 {"case_enum_name", OneofCaseEnumRsName(oneof)},
214 {"view_cases",
215 [&] {
216 for (int i = 0; i < oneof.field_count(); ++i) {
217 auto& field = *oneof.field(i);
218 std::string rs_type = RsTypeNameView(ctx, field);
219 if (rs_type.empty()) {
220 continue;
221 }
222 std::string field_name = FieldNameWithCollisionAvoidance(field);
223 ctx.Emit(
224 {
225 {"case", OneofCaseRsName(field)},
226 {"rs_getter", RsSafeName(field_name)},
227 {"type", rs_type},
228 },
229 R"rs(
230 $oneof_enum_module$$case_enum_name$::$case$ =>
231 $oneof_enum_module$$view_enum_name$::$case$(self.$rs_getter$()),
232 )rs");
233 }
234 }},
235 {"oneof_case_body",
236 [&] {
237 if (ctx.is_cpp()) {
238 ctx.Emit({{"case_thunk", ThunkName(ctx, oneof, "case")}},
239 "unsafe { $case_thunk$(self.raw_msg()) }");
240 } else {
241 ctx.Emit(
242 // The field index for an arbitrary field that in the oneof.
243 {{"upb_mt_field_index",
244 UpbMiniTableFieldIndex(*oneof.field(0))}},
245 R"rs(
246 let field_num = unsafe {
247 let f = $pbr$::upb_MiniTable_GetFieldByIndex(
248 <Self as $pbr$::AssociatedMiniTable>::mini_table(),
249 $upb_mt_field_index$);
250 $pbr$::upb_Message_WhichOneofFieldNumber(
251 self.raw_msg(), f)
252 };
253 unsafe {
254 $oneof_enum_module$$case_enum_name$::try_from(field_num).unwrap_unchecked()
255 }
256 )rs");
257 }
258 }}},
259 R"rs(
260 pub fn $oneof_name$($self$) -> $oneof_enum_module$$view_enum_name$<$view_lifetime$> {
261 match $self$.$oneof_name$_case() {
262 $view_cases$
263 _ => $oneof_enum_module$$view_enum_name$::not_set(std::marker::PhantomData)
264 }
265 }
266
267 pub fn $oneof_name$_case($self$) -> $oneof_enum_module$$case_enum_name$ {
268 $oneof_case_body$
269 }
270 )rs");
271 }
272
GenerateOneofExternC(Context & ctx,const OneofDescriptor & oneof)273 void GenerateOneofExternC(Context& ctx, const OneofDescriptor& oneof) {
274 ABSL_CHECK(ctx.is_cpp());
275
276 ctx.Emit(
277 {
278 {"oneof_enum_module",
279 absl::StrCat("crate::", RustModuleForContainingType(
280 ctx, oneof.containing_type()))},
281 {"case_enum_rs_name", OneofCaseEnumRsName(oneof)},
282 {"case_thunk", ThunkName(ctx, oneof, "case")},
283 },
284 R"rs(
285 fn $case_thunk$(raw_msg: $pbr$::RawMessage) -> $oneof_enum_module$$case_enum_rs_name$;
286 )rs");
287 }
288
GenerateOneofThunkCc(Context & ctx,const OneofDescriptor & oneof)289 void GenerateOneofThunkCc(Context& ctx, const OneofDescriptor& oneof) {
290 ABSL_CHECK(ctx.is_cpp());
291
292 ctx.Emit(
293 {
294 {"oneof_name", oneof.name()},
295 {"case_enum_name", OneofCaseEnumRsName(oneof)},
296 {"case_thunk", ThunkName(ctx, oneof, "case")},
297 {"QualifiedMsg", cpp::QualifiedClassName(oneof.containing_type())},
298 },
299 R"cc(
300 $QualifiedMsg$::$case_enum_name$ $case_thunk$($QualifiedMsg$* msg) {
301 return msg->$oneof_name$_case();
302 }
303 )cc");
304 }
305
306 } // namespace rust
307 } // namespace compiler
308 } // namespace protobuf
309 } // namespace google
310