• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2023 Google LLC.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 #include "google/protobuf/compiler/rust/enum.h"
9 
10 #include <cstddef>
11 #include <cstdint>
12 #include <initializer_list>
13 #include <string>
14 #include <utility>
15 #include <vector>
16 
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/log/absl_check.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/types/span.h"
24 #include "google/protobuf/compiler/cpp/names.h"
25 #include "google/protobuf/compiler/rust/context.h"
26 #include "google/protobuf/compiler/rust/naming.h"
27 #include "google/protobuf/descriptor.h"
28 
29 namespace google {
30 namespace protobuf {
31 namespace compiler {
32 namespace rust {
33 
34 namespace {
35 // Constructs input for `EnumValues` from an enum descriptor.
EnumValuesInput(const EnumDescriptor & desc)36 std::vector<std::pair<absl::string_view, int32_t>> EnumValuesInput(
37     const EnumDescriptor& desc) {
38   std::vector<std::pair<absl::string_view, int32_t>> result;
39   result.reserve(static_cast<size_t>(desc.value_count()));
40 
41   for (int i = 0; i < desc.value_count(); ++i) {
42     result.emplace_back(desc.value(i)->name(), desc.value(i)->number());
43   }
44 
45   return result;
46 }
47 
// Emits the kernel-specific Rust impls that let this enum be stored as a map
// value.
//
// - C++ kernel: emits one `ProxiedInMapValue<$key_t$>` impl per entry of
//   `kMapKeyTypes`. Each impl forwards to FFI thunks whose names come from
//   `RawMapThunk(ctx, desc, <key thunk ident>, <op>)`.
// - upb kernel: emits a single `UpbTypeConversions` impl that moves the enum
//   through the `int32_val` member of `upb_MessageValue`.
//
// NOTE(review): `$name$` (the Rust enum type) is not bound by any `Emit` call
// in this function. It presumably resolves from the caller's enclosing `Emit`
// frame (`GenerateEnumDefinition` binds "name"); confirm before calling this
// from anywhere else.
void EnumProxiedInMapValue(Context& ctx, const EnumDescriptor& desc) {
  switch (ctx.opts().kernel) {
    case Kernel::kCpp:
      // One impl per supported map key type; `t` describes the key's Rust
      // type, its FFI type, and the expressions converting between them.
      for (const auto& t : kMapKeyTypes) {
        // NOTE(review): `WithSuffix("")` on the callback subs appears intended
        // to stop the printer from consuming any text that follows these
        // substitutions (e.g. `>` or `,`); confirm against `io::Printer::Sub`.
        ctx.Emit(
            {{"map_new_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "new")},
             {"map_free_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "free")},
             {"map_clear_thunk",
              RawMapThunk(ctx, desc, t.thunk_ident, "clear")},
             {"map_size_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "size")},
             {"map_insert_thunk",
              RawMapThunk(ctx, desc, t.thunk_ident, "insert")},
             {"map_get_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "get")},
             {"map_remove_thunk",
              RawMapThunk(ctx, desc, t.thunk_ident, "remove")},
             {"map_iter_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "iter")},
             {"map_iter_get_thunk",
              RawMapThunk(ctx, desc, t.thunk_ident, "iter_get")},
             {"to_ffi_key_expr", t.rs_to_ffi_key_expr},
             io::Printer::Sub("ffi_key_t", [&] { ctx.Emit(t.rs_ffi_key_t); })
                 .WithSuffix(""),
             io::Printer::Sub("key_t", [&] { ctx.Emit(t.rs_key_t); })
                 .WithSuffix(""),
             io::Printer::Sub("from_ffi_key_expr",
                              [&] { ctx.Emit(t.rs_from_ffi_key_expr); })
                 .WithSuffix("")},
            R"rs(
      impl $pb$::ProxiedInMapValue<$key_t$> for $name$ {
        fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> {
            unsafe {
                $pb$::Map::from_inner(
                    $pbi$::Private,
                    $pbr$::InnerMap::new($pbr$::$map_new_thunk$())
                )
            }
        }

        unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) {
            unsafe { $pbr$::$map_free_thunk$(map.as_raw($pbi$::Private)); }
        }

        fn map_clear(mut map: $pb$::MapMut<$key_t$, Self>) {
            unsafe { $pbr$::$map_clear_thunk$(map.as_raw($pbi$::Private)); }
        }

        fn map_len(map: $pb$::MapView<$key_t$, Self>) -> usize {
            unsafe { $pbr$::$map_size_thunk$(map.as_raw($pbi$::Private)) }
        }

        fn map_insert(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>, value: impl $pb$::IntoProxied<Self>) -> bool {
            unsafe { $pbr$::$map_insert_thunk$(map.as_raw($pbi$::Private), $to_ffi_key_expr$, value.into_proxied($pbi$::Private).0) }
        }

        fn map_get<'a>(map: $pb$::MapView<'a, $key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> $Option$<$pb$::View<'a, Self>> {
            let key = $to_ffi_key_expr$;
            let mut value = $std$::mem::MaybeUninit::uninit();
            let found = unsafe { $pbr$::$map_get_thunk$(map.as_raw($pbi$::Private), key, value.as_mut_ptr()) };
            if !found {
                return None;
            }
            Some(unsafe { $name$(value.assume_init()) })
        }

        fn map_remove(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> bool {
            let mut value = $std$::mem::MaybeUninit::uninit();
            unsafe { $pbr$::$map_remove_thunk$(map.as_raw($pbi$::Private), $to_ffi_key_expr$, value.as_mut_ptr()) }
        }

        fn map_iter(map: $pb$::MapView<$key_t$, Self>) -> $pb$::MapIter<$key_t$, Self> {
            // SAFETY:
            // - The backing map for `map.as_raw` is valid for at least '_.
            // - A View that is live for '_ guarantees the backing map is unmodified for '_.
            // - The `iter` function produces an iterator that is valid for the key
            //   and value types, and live for at least '_.
            unsafe {
                $pb$::MapIter::from_raw(
                    $pbi$::Private,
                    $pbr$::$map_iter_thunk$(map.as_raw($pbi$::Private))
                )
            }
        }

        fn map_iter_next<'a>(iter: &mut $pb$::MapIter<'a, $key_t$, Self>) -> $Option$<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> {
            // SAFETY:
            // - The `MapIter` API forbids the backing map from being mutated for 'a,
            //   and guarantees that it's the correct key and value types.
            // - The thunk is safe to call as long as the iterator isn't at the end.
            // - The thunk always writes to key and value fields and does not read.
            // - The thunk does not increment the iterator.
            unsafe {
                iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>(
                    |iter, key, value| { $pbr$::$map_iter_get_thunk$(iter, key, value) },
                    |ffi_key| $from_ffi_key_expr$,
                    |value| $name$(value),
                )
            }
        }
      }
      )rs");
      }
      return;
    case Kernel::kUpb:
      // upb stores enum map values as raw int32s inside `upb_MessageValue`,
      // so conversion is just wrapping/unwrapping the newtype's `i32`.
      ctx.Emit(R"rs(
            impl $pbr$::UpbTypeConversions for $name$ {
                fn upb_type() -> $pbr$::CType {
                    $pbr$::CType::Enum
                }

                fn to_message_value(
                    val: $pb$::View<'_, Self>) -> $pbr$::upb_MessageValue {
                    $pbr$::upb_MessageValue { int32_val: val.0 }
                }

                unsafe fn into_message_value_fuse_if_required(
                  raw_parent_arena: $pbr$::RawArena,
                  val: Self) -> $pbr$::upb_MessageValue {
                    $pbr$::upb_MessageValue { int32_val: val.0 }
                }

                unsafe fn from_message_value<'msg>(val: $pbr$::upb_MessageValue)
                    -> $pb$::View<'msg, Self> {
                  $name$(unsafe { val.int32_val })
                }
            }
            )rs");
      return;
  }
}
176 
177 }  // namespace
178 
EnumValues(absl::string_view enum_name,absl::Span<const std::pair<absl::string_view,int32_t>> values)179 std::vector<RustEnumValue> EnumValues(
180     absl::string_view enum_name,
181     absl::Span<const std::pair<absl::string_view, int32_t>> values) {
182   MultiCasePrefixStripper stripper(enum_name);
183 
184   absl::flat_hash_set<std::string> seen_by_name;
185   absl::flat_hash_map<int32_t, RustEnumValue*> seen_by_number;
186   std::vector<RustEnumValue> result;
187   // The below code depends on pointer stability of elements in `result`;
188   // this reserve must not be too low.
189   result.reserve(values.size());
190   seen_by_name.reserve(values.size());
191   seen_by_number.reserve(values.size());
192 
193   for (const auto& name_and_number : values) {
194     int32_t number = name_and_number.second;
195     std::string rust_value_name =
196         EnumValueRsName(stripper, name_and_number.first);
197 
198     if (seen_by_name.contains(rust_value_name)) {
199       // Don't add an alias with the same normalized name.
200       continue;
201     }
202 
203     auto it_and_inserted = seen_by_number.try_emplace(number);
204     if (it_and_inserted.second) {
205       // This is the first value with this number; this name is canonical.
206       result.push_back(RustEnumValue{rust_value_name, number});
207       it_and_inserted.first->second = &result.back();
208     } else {
209       // This number has been seen before; this name is an alias.
210       it_and_inserted.first->second->aliases.push_back(rust_value_name);
211     }
212 
213     seen_by_name.insert(std::move(rust_value_name));
214   }
215   return result;
216 }
217 
GenerateEnumDefinition(Context & ctx,const EnumDescriptor & desc)218 void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc) {
219   std::string name = EnumRsName(desc);
220   ABSL_CHECK(desc.value_count() > 0);
221   std::vector<RustEnumValue> values =
222       EnumValues(desc.name(), EnumValuesInput(desc));
223   ABSL_CHECK(!values.empty());
224 
225   ctx.Emit(
226       {
227           {"name", name},
228           {"variants",
229            [&] {
230              for (const auto& value : values) {
231                std::string number_str = absl::StrCat(value.number);
232                // TODO: Replace with open enum variants when stable
233                ctx.Emit({{"variant_name", value.name}, {"number", number_str}},
234                         R"rs(
235                     pub const $variant_name$: $name$ = $name$($number$);
236                     )rs");
237                for (const auto& alias : value.aliases) {
238                  ctx.Emit({{"alias_name", alias}, {"number", number_str}},
239                           R"rs(
240                             pub const $alias_name$: $name$ = $name$($number$);
241                             )rs");
242                }
243              }
244            }},
245           // The default value of an enum is the first listed value.
246           // The compiler checks that this is equal to 0 for open enums.
247           {"default_int_value", absl::StrCat(desc.value(0)->number())},
248           {"known_values_pattern",
249            // TODO: Check validity in UPB/C++.
250            absl::StrJoin(values, "|",
251                          [](std::string* o, const RustEnumValue& val) {
252                            absl::StrAppend(o, val.number);
253                          })},
254           {"impl_from_i32",
255            [&] {
256              if (desc.is_closed()) {
257                ctx.Emit(R"rs(
258               impl $std$::convert::TryFrom<i32> for $name$ {
259                 type Error = $pb$::UnknownEnumValue<Self>;
260 
261                 fn try_from(val: i32) -> $Result$<$name$, Self::Error> {
262                   if <Self as $pbi$::Enum>::is_known(val) {
263                     Ok(Self(val))
264                   } else {
265                     Err($pb$::UnknownEnumValue::new($pbi$::Private, val))
266                   }
267                 }
268               }
269             )rs");
270              } else {
271                ctx.Emit(R"rs(
272               impl $std$::convert::From<i32> for $name$ {
273                 fn from(val: i32) -> $name$ {
274                   Self(val)
275                 }
276               }
277             )rs");
278              }
279            }},
280           {"impl_proxied_in_map", [&] { EnumProxiedInMapValue(ctx, desc); }},
281       },
282       R"rs(
283       #[repr(transparent)]
284       #[derive(Clone, Copy, PartialEq, Eq)]
285       pub struct $name$(i32);
286 
287       #[allow(non_upper_case_globals)]
288       impl $name$ {
289         $variants$
290       }
291 
292       impl $std$::convert::From<$name$> for i32 {
293         fn from(val: $name$) -> i32 {
294           val.0
295         }
296       }
297 
298       $impl_from_i32$
299 
300       impl $std$::default::Default for $name$ {
301         fn default() -> Self {
302           Self($default_int_value$)
303         }
304       }
305 
306       impl $std$::fmt::Debug for $name$ {
307         fn fmt(&self, f: &mut $std$::fmt::Formatter<'_>) -> $std$::fmt::Result {
308           f.debug_tuple(stringify!($name$)).field(&self.0).finish()
309         }
310       }
311 
312       impl $pb$::IntoProxied<i32> for $name$ {
313         fn into_proxied(self, _: $pbi$::Private) -> i32 {
314           self.0
315         }
316       }
317 
318       impl $pbi$::SealedInternal for $name$ {}
319 
320       impl $pb$::Proxied for $name$ {
321         type View<'a> = $name$;
322       }
323 
324       impl $pb$::Proxy<'_> for $name$ {}
325       impl $pb$::ViewProxy<'_> for $name$ {}
326 
327       impl $pb$::AsView for $name$ {
328         type Proxied = $name$;
329 
330         fn as_view(&self) -> $name$ {
331           *self
332         }
333       }
334 
335       impl<'msg> $pb$::IntoView<'msg> for $name$ {
336         fn into_view<'shorter>(self) -> $name$ where 'msg: 'shorter {
337           self
338         }
339       }
340 
341       unsafe impl $pb$::ProxiedInRepeated for $name$ {
342         fn repeated_new(_private: $pbi$::Private) -> $pb$::Repeated<Self> {
343           $pbr$::new_enum_repeated()
344         }
345 
346         unsafe fn repeated_free(_private: $pbi$::Private, f: &mut $pb$::Repeated<Self>) {
347           $pbr$::free_enum_repeated(f)
348         }
349 
350         fn repeated_len(r: $pb$::View<$pb$::Repeated<Self>>) -> usize {
351           $pbr$::cast_enum_repeated_view(r).len()
352         }
353 
354         fn repeated_push(r: $pb$::Mut<$pb$::Repeated<Self>>, val: impl $pb$::IntoProxied<$name$>) {
355           $pbr$::cast_enum_repeated_mut(r).push(val.into_proxied($pbi$::Private))
356         }
357 
358         fn repeated_clear(r: $pb$::Mut<$pb$::Repeated<Self>>) {
359           $pbr$::cast_enum_repeated_mut(r).clear()
360         }
361 
362         unsafe fn repeated_get_unchecked(
363             r: $pb$::View<$pb$::Repeated<Self>>,
364             index: usize,
365         ) -> $pb$::View<$name$> {
366           // SAFETY: In-bounds as promised by the caller.
367           unsafe {
368             $pbr$::cast_enum_repeated_view(r)
369               .get_unchecked(index)
370               .try_into()
371               .unwrap_unchecked()
372           }
373         }
374 
375         unsafe fn repeated_set_unchecked(
376             r: $pb$::Mut<$pb$::Repeated<Self>>,
377             index: usize,
378             val: impl $pb$::IntoProxied<$name$>,
379         ) {
380           // SAFETY: In-bounds as promised by the caller.
381           unsafe {
382             $pbr$::cast_enum_repeated_mut(r)
383               .set_unchecked(index, val.into_proxied($pbi$::Private))
384           }
385         }
386 
387         fn repeated_copy_from(
388             src: $pb$::View<$pb$::Repeated<Self>>,
389             dest: $pb$::Mut<$pb$::Repeated<Self>>,
390         ) {
391           $pbr$::cast_enum_repeated_mut(dest)
392             .copy_from($pbr$::cast_enum_repeated_view(src))
393         }
394 
395         fn repeated_reserve(
396             r: $pb$::Mut<$pb$::Repeated<Self>>,
397             additional: usize,
398         ) {
399             // SAFETY:
400             // - `f.as_raw()` is valid.
401             $pbr$::reserve_enum_repeated_mut(r, additional);
402         }
403       }
404 
405       // SAFETY: this is an enum type
406       unsafe impl $pbi$::Enum for $name$ {
407         const NAME: &'static str = "$name$";
408 
409         fn is_known(value: i32) -> bool {
410           matches!(value, $known_values_pattern$)
411         }
412       }
413 
414       $impl_proxied_in_map$
415       )rs");
416 }
417 
418 }  // namespace rust
419 }  // namespace compiler
420 }  // namespace protobuf
421 }  // namespace google
422