1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2023 Google LLC. All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7
8 #include "google/protobuf/compiler/rust/enum.h"
9
10 #include <cstddef>
11 #include <cstdint>
12 #include <initializer_list>
13 #include <string>
14 #include <utility>
15 #include <vector>
16
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/log/absl_check.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/types/span.h"
24 #include "google/protobuf/compiler/cpp/names.h"
25 #include "google/protobuf/compiler/rust/context.h"
26 #include "google/protobuf/compiler/rust/naming.h"
27 #include "google/protobuf/descriptor.h"
28
29 namespace google {
30 namespace protobuf {
31 namespace compiler {
32 namespace rust {
33
34 namespace {
35 // Constructs input for `EnumValues` from an enum descriptor.
EnumValuesInput(const EnumDescriptor & desc)36 std::vector<std::pair<absl::string_view, int32_t>> EnumValuesInput(
37 const EnumDescriptor& desc) {
38 std::vector<std::pair<absl::string_view, int32_t>> result;
39 result.reserve(static_cast<size_t>(desc.value_count()));
40
41 for (int i = 0; i < desc.value_count(); ++i) {
42 result.emplace_back(desc.value(i)->name(), desc.value(i)->number());
43 }
44
45 return result;
46 }
47
// Emits the Rust trait impls that let the generated enum type be used as a
// map value.
//
// `$name$` is not bound by the `Emit` calls in this function. It must already
// be in scope from the caller's substitutions: GenerateEnumDefinition invokes
// this from inside an `Emit` that binds `name`.
//
// - C++ kernel: emits one `ProxiedInMapValue<$key_t$>` impl per entry of
//   `kMapKeyTypes`. Each map operation forwards to a `$pbr$::` thunk whose
//   identifier `RawMapThunk` derives from the enum, the key type's
//   `thunk_ident` and the operation name ("new", "free", "get", ...).
// - upb kernel: emits a single `UpbTypeConversions` impl. It reports
//   `CType::Enum` and round-trips the enum's i32 through
//   `upb_MessageValue::int32_val`.
void EnumProxiedInMapValue(Context& ctx, const EnumDescriptor& desc) {
  switch (ctx.opts().kernel) {
    case Kernel::kCpp:
      // One impl per supported map key type.
      for (const auto& t : kMapKeyTypes) {
        ctx.Emit(
            {{"map_new_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "new")},
             {"map_free_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "free")},
             {"map_clear_thunk",
              RawMapThunk(ctx, desc, t.thunk_ident, "clear")},
             {"map_size_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "size")},
             {"map_insert_thunk",
              RawMapThunk(ctx, desc, t.thunk_ident, "insert")},
             {"map_get_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "get")},
             {"map_remove_thunk",
              RawMapThunk(ctx, desc, t.thunk_ident, "remove")},
             {"map_iter_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "iter")},
             {"map_iter_get_thunk",
              RawMapThunk(ctx, desc, t.thunk_ident, "iter_get")},
             // Key-type-specific Rust snippets supplied by the key table entry.
             {"to_ffi_key_expr", t.rs_to_ffi_key_expr},
             io::Printer::Sub("ffi_key_t", [&] { ctx.Emit(t.rs_ffi_key_t); })
                 .WithSuffix(""),
             io::Printer::Sub("key_t", [&] { ctx.Emit(t.rs_key_t); })
                 .WithSuffix(""),
             io::Printer::Sub("from_ffi_key_expr",
                              [&] { ctx.Emit(t.rs_from_ffi_key_expr); })
                 .WithSuffix("")},
            R"rs(
            impl $pb$::ProxiedInMapValue<$key_t$> for $name$ {
                fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> {
                    unsafe {
                        $pb$::Map::from_inner(
                            $pbi$::Private,
                            $pbr$::InnerMap::new($pbr$::$map_new_thunk$())
                        )
                    }
                }

                unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) {
                    unsafe { $pbr$::$map_free_thunk$(map.as_raw($pbi$::Private)); }
                }

                fn map_clear(mut map: $pb$::MapMut<$key_t$, Self>) {
                    unsafe { $pbr$::$map_clear_thunk$(map.as_raw($pbi$::Private)); }
                }

                fn map_len(map: $pb$::MapView<$key_t$, Self>) -> usize {
                    unsafe { $pbr$::$map_size_thunk$(map.as_raw($pbi$::Private)) }
                }

                fn map_insert(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>, value: impl $pb$::IntoProxied<Self>) -> bool {
                    unsafe { $pbr$::$map_insert_thunk$(map.as_raw($pbi$::Private), $to_ffi_key_expr$, value.into_proxied($pbi$::Private).0) }
                }

                fn map_get<'a>(map: $pb$::MapView<'a, $key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> $Option$<$pb$::View<'a, Self>> {
                    let key = $to_ffi_key_expr$;
                    let mut value = $std$::mem::MaybeUninit::uninit();
                    let found = unsafe { $pbr$::$map_get_thunk$(map.as_raw($pbi$::Private), key, value.as_mut_ptr()) };
                    if !found {
                        return None;
                    }
                    Some(unsafe { $name$(value.assume_init()) })
                }

                fn map_remove(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> bool {
                    let mut value = $std$::mem::MaybeUninit::uninit();
                    unsafe { $pbr$::$map_remove_thunk$(map.as_raw($pbi$::Private), $to_ffi_key_expr$, value.as_mut_ptr()) }
                }

                fn map_iter(map: $pb$::MapView<$key_t$, Self>) -> $pb$::MapIter<$key_t$, Self> {
                    // SAFETY:
                    // - The backing map for `map.as_raw` is valid for at least '_.
                    // - A View that is live for '_ guarantees the backing map is unmodified for '_.
                    // - The `iter` function produces an iterator that is valid for the key
                    //   and value types, and live for at least '_.
                    unsafe {
                        $pb$::MapIter::from_raw(
                            $pbi$::Private,
                            $pbr$::$map_iter_thunk$(map.as_raw($pbi$::Private))
                        )
                    }
                }

                fn map_iter_next<'a>(iter: &mut $pb$::MapIter<'a, $key_t$, Self>) -> $Option$<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> {
                    // SAFETY:
                    // - The `MapIter` API forbids the backing map from being mutated for 'a,
                    //   and guarantees that it's the correct key and value types.
                    // - The thunk is safe to call as long as the iterator isn't at the end.
                    // - The thunk always writes to key and value fields and does not read.
                    // - The thunk does not increment the iterator.
                    unsafe {
                        iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>(
                            |iter, key, value| { $pbr$::$map_iter_get_thunk$(iter, key, value) },
                            |ffi_key| $from_ffi_key_expr$,
                            |value| $name$(value),
                        )
                    }
                }
            }
            )rs");
      }
      return;
    case Kernel::kUpb:
      // upb stores enum map values as plain int32 message values, so only a
      // type-conversion impl is needed (no per-key-type thunks).
      ctx.Emit(R"rs(
            impl $pbr$::UpbTypeConversions for $name$ {
                fn upb_type() -> $pbr$::CType {
                    $pbr$::CType::Enum
                }

                fn to_message_value(
                    val: $pb$::View<'_, Self>) -> $pbr$::upb_MessageValue {
                    $pbr$::upb_MessageValue { int32_val: val.0 }
                }

                unsafe fn into_message_value_fuse_if_required(
                    raw_parent_arena: $pbr$::RawArena,
                    val: Self) -> $pbr$::upb_MessageValue {
                    $pbr$::upb_MessageValue { int32_val: val.0 }
                }

                unsafe fn from_message_value<'msg>(val: $pbr$::upb_MessageValue)
                    -> $pb$::View<'msg, Self> {
                    $name$(unsafe { val.int32_val })
                }
            }
            )rs");
      return;
  }
}
176
177 } // namespace
178
EnumValues(absl::string_view enum_name,absl::Span<const std::pair<absl::string_view,int32_t>> values)179 std::vector<RustEnumValue> EnumValues(
180 absl::string_view enum_name,
181 absl::Span<const std::pair<absl::string_view, int32_t>> values) {
182 MultiCasePrefixStripper stripper(enum_name);
183
184 absl::flat_hash_set<std::string> seen_by_name;
185 absl::flat_hash_map<int32_t, RustEnumValue*> seen_by_number;
186 std::vector<RustEnumValue> result;
187 // The below code depends on pointer stability of elements in `result`;
188 // this reserve must not be too low.
189 result.reserve(values.size());
190 seen_by_name.reserve(values.size());
191 seen_by_number.reserve(values.size());
192
193 for (const auto& name_and_number : values) {
194 int32_t number = name_and_number.second;
195 std::string rust_value_name =
196 EnumValueRsName(stripper, name_and_number.first);
197
198 if (seen_by_name.contains(rust_value_name)) {
199 // Don't add an alias with the same normalized name.
200 continue;
201 }
202
203 auto it_and_inserted = seen_by_number.try_emplace(number);
204 if (it_and_inserted.second) {
205 // This is the first value with this number; this name is canonical.
206 result.push_back(RustEnumValue{rust_value_name, number});
207 it_and_inserted.first->second = &result.back();
208 } else {
209 // This number has been seen before; this name is an alias.
210 it_and_inserted.first->second->aliases.push_back(rust_value_name);
211 }
212
213 seen_by_name.insert(std::move(rust_value_name));
214 }
215 return result;
216 }
217
// Emits the complete Rust definition for the proto enum `desc`.
//
// The generated type is a `#[repr(transparent)]` newtype over `i32`, named by
// `EnumRsName(desc)`. The emitted code provides:
//  - one associated const per canonical value and per alias (from
//    `EnumValues`);
//  - `From<$name$> for i32`, plus `TryFrom<i32>` for closed enums or
//    `From<i32>` for open enums;
//  - `Default` (the first declared value), `Debug`, and the `Proxied` /
//    `AsView` / `IntoView` / `IntoProxied<i32>` plumbing;
//  - `ProxiedInRepeated`, implemented via the runtime's `*_enum_repeated*`
//    helpers;
//  - the internal `Enum` trait, whose `is_known` matches the generated value
//    numbers;
//  - the map-value impls emitted by `EnumProxiedInMapValue`.
void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc) {
  std::string name = EnumRsName(desc);
  // `default_int_value` below reads `desc.value(0)`, so at least one value is
  // required.
  ABSL_CHECK(desc.value_count() > 0);
  std::vector<RustEnumValue> values =
      EnumValues(desc.name(), EnumValuesInput(desc));
  // `EnumValues` never skips the first input value, so with a non-empty enum
  // this always holds.
  ABSL_CHECK(!values.empty());

  ctx.Emit(
      {
          {"name", name},
          {"variants",
           [&] {
             // Emit the canonical constant for each number, followed by its
             // aliases, which share the same number.
             for (const auto& value : values) {
               std::string number_str = absl::StrCat(value.number);
               // TODO: Replace with open enum variants when stable
               ctx.Emit({{"variant_name", value.name}, {"number", number_str}},
                        R"rs(
                  pub const $variant_name$: $name$ = $name$($number$);
                )rs");
               for (const auto& alias : value.aliases) {
                 ctx.Emit({{"alias_name", alias}, {"number", number_str}},
                          R"rs(
                    pub const $alias_name$: $name$ = $name$($number$);
                  )rs");
               }
             }
           }},
          // The default value of an enum is the first listed value.
          // The compiler checks that this is equal to 0 for open enums.
          {"default_int_value", absl::StrCat(desc.value(0)->number())},
          // Match-pattern such as `0|1|5`. Each entry of `values` has a
          // distinct number (aliases are folded into their canonical entry),
          // so the pattern lists each known number once.
          {"known_values_pattern",
           // TODO: Check validity in UPB/C++.
           absl::StrJoin(values, "|",
                         [](std::string* o, const RustEnumValue& val) {
                           absl::StrAppend(o, val.number);
                         })},
          {"impl_from_i32",
           [&] {
             if (desc.is_closed()) {
               // Closed enums reject numbers that `is_known` does not accept.
               ctx.Emit(R"rs(
              impl $std$::convert::TryFrom<i32> for $name$ {
                type Error = $pb$::UnknownEnumValue<Self>;

                fn try_from(val: i32) -> $Result$<$name$, Self::Error> {
                  if <Self as $pbi$::Enum>::is_known(val) {
                    Ok(Self(val))
                  } else {
                    Err($pb$::UnknownEnumValue::new($pbi$::Private, val))
                  }
                }
              }
            )rs");
             } else {
               // Open enums accept any i32.
               ctx.Emit(R"rs(
              impl $std$::convert::From<i32> for $name$ {
                fn from(val: i32) -> $name$ {
                  Self(val)
                }
              }
            )rs");
             }
           }},
          {"impl_proxied_in_map", [&] { EnumProxiedInMapValue(ctx, desc); }},
      },
      R"rs(
      #[repr(transparent)]
      #[derive(Clone, Copy, PartialEq, Eq)]
      pub struct $name$(i32);

      #[allow(non_upper_case_globals)]
      impl $name$ {
        $variants$
      }

      impl $std$::convert::From<$name$> for i32 {
        fn from(val: $name$) -> i32 {
          val.0
        }
      }

      $impl_from_i32$

      impl $std$::default::Default for $name$ {
        fn default() -> Self {
          Self($default_int_value$)
        }
      }

      impl $std$::fmt::Debug for $name$ {
        fn fmt(&self, f: &mut $std$::fmt::Formatter<'_>) -> $std$::fmt::Result {
          f.debug_tuple(stringify!($name$)).field(&self.0).finish()
        }
      }

      impl $pb$::IntoProxied<i32> for $name$ {
        fn into_proxied(self, _: $pbi$::Private) -> i32 {
          self.0
        }
      }

      impl $pbi$::SealedInternal for $name$ {}

      impl $pb$::Proxied for $name$ {
        type View<'a> = $name$;
      }

      impl $pb$::Proxy<'_> for $name$ {}
      impl $pb$::ViewProxy<'_> for $name$ {}

      impl $pb$::AsView for $name$ {
        type Proxied = $name$;

        fn as_view(&self) -> $name$ {
          *self
        }
      }

      impl<'msg> $pb$::IntoView<'msg> for $name$ {
        fn into_view<'shorter>(self) -> $name$ where 'msg: 'shorter {
          self
        }
      }

      unsafe impl $pb$::ProxiedInRepeated for $name$ {
        fn repeated_new(_private: $pbi$::Private) -> $pb$::Repeated<Self> {
          $pbr$::new_enum_repeated()
        }

        unsafe fn repeated_free(_private: $pbi$::Private, f: &mut $pb$::Repeated<Self>) {
          $pbr$::free_enum_repeated(f)
        }

        fn repeated_len(r: $pb$::View<$pb$::Repeated<Self>>) -> usize {
          $pbr$::cast_enum_repeated_view(r).len()
        }

        fn repeated_push(r: $pb$::Mut<$pb$::Repeated<Self>>, val: impl $pb$::IntoProxied<$name$>) {
          $pbr$::cast_enum_repeated_mut(r).push(val.into_proxied($pbi$::Private))
        }

        fn repeated_clear(r: $pb$::Mut<$pb$::Repeated<Self>>) {
          $pbr$::cast_enum_repeated_mut(r).clear()
        }

        unsafe fn repeated_get_unchecked(
            r: $pb$::View<$pb$::Repeated<Self>>,
            index: usize,
        ) -> $pb$::View<$name$> {
          // SAFETY: In-bounds as promised by the caller.
          unsafe {
            $pbr$::cast_enum_repeated_view(r)
              .get_unchecked(index)
              .try_into()
              .unwrap_unchecked()
          }
        }

        unsafe fn repeated_set_unchecked(
            r: $pb$::Mut<$pb$::Repeated<Self>>,
            index: usize,
            val: impl $pb$::IntoProxied<$name$>,
        ) {
          // SAFETY: In-bounds as promised by the caller.
          unsafe {
            $pbr$::cast_enum_repeated_mut(r)
              .set_unchecked(index, val.into_proxied($pbi$::Private))
          }
        }

        fn repeated_copy_from(
            src: $pb$::View<$pb$::Repeated<Self>>,
            dest: $pb$::Mut<$pb$::Repeated<Self>>,
        ) {
          $pbr$::cast_enum_repeated_mut(dest)
            .copy_from($pbr$::cast_enum_repeated_view(src))
        }

        fn repeated_reserve(
            r: $pb$::Mut<$pb$::Repeated<Self>>,
            additional: usize,
        ) {
            // SAFETY:
            // - `f.as_raw()` is valid.
            $pbr$::reserve_enum_repeated_mut(r, additional);
        }
      }

      // SAFETY: this is an enum type
      unsafe impl $pbi$::Enum for $name$ {
        const NAME: &'static str = "$name$";

        fn is_known(value: i32) -> bool {
          matches!(value, $known_values_pattern$)
        }
      }

      $impl_proxied_in_map$
      )rs");
}
417
418 } // namespace rust
419 } // namespace compiler
420 } // namespace protobuf
421 } // namespace google
422