• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2023 Google LLC.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 #include "google/protobuf/compiler/rust/naming.h"
9 
10 #include <algorithm>
11 #include <string>
12 #include <vector>
13 
14 #include "absl/log/absl_check.h"
15 #include "absl/log/absl_log.h"
16 #include "absl/strings/ascii.h"
17 #include "absl/strings/match.h"
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_format.h"
20 #include "absl/strings/str_join.h"
21 #include "absl/strings/str_replace.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/strings/strip.h"
24 #include "absl/strings/substitute.h"
25 #include "google/protobuf/compiler/code_generator.h"
26 #include "google/protobuf/compiler/cpp/helpers.h"
27 #include "google/protobuf/compiler/rust/context.h"
28 #include "google/protobuf/compiler/rust/rust_field_type.h"
29 #include "google/protobuf/compiler/rust/rust_keywords.h"
30 #include "google/protobuf/descriptor.h"
31 #include "google/protobuf/port.h"
32 
33 // Must be included last.
34 #include "google/protobuf/port_def.inc"
35 
36 namespace google {
37 namespace protobuf {
38 namespace compiler {
39 namespace rust {
40 
GetCrateName(Context & ctx,const FileDescriptor & dep)41 std::string GetCrateName(Context& ctx, const FileDescriptor& dep) {
42   return RsSafeName(ctx.ImportPathToCrateName(dep.name()));
43 }
44 
GetRsFile(Context & ctx,const FileDescriptor & file)45 std::string GetRsFile(Context& ctx, const FileDescriptor& file) {
46   auto basename = StripProto(file.name());
47   switch (auto k = ctx.opts().kernel) {
48     case Kernel::kUpb:
49       return absl::StrCat(basename, ".u.pb.rs");
50     case Kernel::kCpp:
51       return absl::StrCat(basename, ".c.pb.rs");
52     default:
53       ABSL_LOG(FATAL) << "Unknown kernel type: " << static_cast<int>(k);
54       return "";
55   }
56 }
57 
GetThunkCcFile(Context & ctx,const FileDescriptor & file)58 std::string GetThunkCcFile(Context& ctx, const FileDescriptor& file) {
59   auto basename = StripProto(file.name());
60   return absl::StrCat(basename, ".pb.thunks.cc");
61 }
62 
GetHeaderFile(Context & ctx,const FileDescriptor & file)63 std::string GetHeaderFile(Context& ctx, const FileDescriptor& file) {
64   auto basename = StripProto(file.name());
65   constexpr absl::string_view kCcGencodeExt = ".pb.h";
66 
67   return absl::StrCat(basename, kCcGencodeExt);
68 }
69 
RawMapThunk(Context & ctx,const Descriptor & msg,absl::string_view key_t,absl::string_view op)70 std::string RawMapThunk(Context& ctx, const Descriptor& msg,
71                         absl::string_view key_t, absl::string_view op) {
72   return absl::StrCat("proto2_rust_thunk_Map_", key_t, "_",
73                       GetUnderscoreDelimitedFullName(ctx, *&msg), "_", op);
74 }
75 
RawMapThunk(Context & ctx,const EnumDescriptor & desc,absl::string_view key_t,absl::string_view op)76 std::string RawMapThunk(Context& ctx, const EnumDescriptor& desc,
77                         absl::string_view key_t, absl::string_view op) {
78   // Enums are always 32 bits.
79   return absl::StrCat("proto2_rust_thunk_Map_", key_t, "_i32_", op);
80 }
81 
ThunkName(Context & ctx,const FieldDescriptor & field,absl::string_view op)82 std::string ThunkName(Context& ctx, const FieldDescriptor& field,
83                       absl::string_view op) {
84   ABSL_CHECK(ctx.is_cpp());
85   return absl::StrCat("proto2_rust_thunk_",
86                       UnderscoreDelimitFullName(ctx, field.full_name()), "_",
87                       op);
88 }
89 
ThunkName(Context & ctx,const OneofDescriptor & field,absl::string_view op)90 std::string ThunkName(Context& ctx, const OneofDescriptor& field,
91                       absl::string_view op) {
92   ABSL_CHECK(ctx.is_cpp());
93   return absl::StrCat("proto2_rust_thunk_",
94                       UnderscoreDelimitFullName(ctx, field.full_name()), "_",
95                       op);
96 }
97 
ThunkName(Context & ctx,const Descriptor & msg,absl::string_view op)98 std::string ThunkName(Context& ctx, const Descriptor& msg,
99                       absl::string_view op) {
100   absl::string_view prefix = ctx.is_cpp() ? "proto2_rust_thunk_Message_" : "";
101   return absl::StrCat(prefix, GetUnderscoreDelimitedFullName(ctx, msg), "_",
102                       op);
103 }
104 
105 template <typename Desc>
GetFullyQualifiedPath(Context & ctx,const Desc & desc)106 std::string GetFullyQualifiedPath(Context& ctx, const Desc& desc) {
107   auto rel_path = GetCrateRelativeQualifiedPath(ctx, desc);
108   if (IsInCurrentlyGeneratingCrate(ctx, desc)) {
109     return absl::StrCat("crate::", rel_path);
110   }
111   return absl::StrCat(GetCrateName(ctx, *desc.file()), "::", rel_path);
112 }
113 
114 template <typename Desc>
GetUnderscoreDelimitedFullName(Context & ctx,const Desc & desc)115 std::string GetUnderscoreDelimitedFullName(Context& ctx, const Desc& desc) {
116   return UnderscoreDelimitFullName(ctx, desc.full_name());
117 }
118 
UnderscoreDelimitFullName(Context & ctx,absl::string_view full_name)119 std::string UnderscoreDelimitFullName(Context& ctx,
120                                       absl::string_view full_name) {
121   std::string result = std::string(full_name);
122   absl::StrReplaceAll({{".", "_"}}, &result);
123   return result;
124 }
125 
RsTypePath(Context & ctx,const FieldDescriptor & field)126 std::string RsTypePath(Context& ctx, const FieldDescriptor& field) {
127   switch (GetRustFieldType(field)) {
128     case RustFieldType::BOOL:
129       return "bool";
130     case RustFieldType::INT32:
131       return "i32";
132     case RustFieldType::INT64:
133       return "i64";
134     case RustFieldType::UINT32:
135       return "u32";
136     case RustFieldType::UINT64:
137       return "u64";
138     case RustFieldType::FLOAT:
139       return "f32";
140     case RustFieldType::DOUBLE:
141       return "f64";
142     case RustFieldType::BYTES:
143       return "::__pb::ProtoBytes";
144     case RustFieldType::STRING:
145       return "::__pb::ProtoString";
146     case RustFieldType::MESSAGE:
147       return GetFullyQualifiedPath(ctx, *field.message_type());
148     case RustFieldType::ENUM:
149       return GetFullyQualifiedPath(ctx, *field.enum_type());
150   }
151   ABSL_LOG(ERROR) << "Unknown field type: " << field.type_name();
152   internal::Unreachable();
153 }
154 
RsViewType(Context & ctx,const FieldDescriptor & field,absl::string_view lifetime)155 std::string RsViewType(Context& ctx, const FieldDescriptor& field,
156                        absl::string_view lifetime) {
157   switch (GetRustFieldType(field)) {
158     case RustFieldType::BOOL:
159     case RustFieldType::INT32:
160     case RustFieldType::INT64:
161     case RustFieldType::UINT32:
162     case RustFieldType::UINT64:
163     case RustFieldType::FLOAT:
164     case RustFieldType::DOUBLE:
165     case RustFieldType::ENUM:
166       // The View type of all scalars and enums can be spelled as the type
167       // itself.
168       return RsTypePath(ctx, field);
169     case RustFieldType::BYTES:
170       return absl::StrFormat("&%s [u8]", lifetime);
171     case RustFieldType::STRING:
172       return absl::StrFormat("&%s ::__pb::ProtoStr", lifetime);
173     case RustFieldType::MESSAGE:
174       if (lifetime.empty()) {
175         return absl::StrFormat(
176             "%sView", GetFullyQualifiedPath(ctx, *field.message_type()));
177       } else {
178         return absl::StrFormat(
179             "%sView<%s>", GetFullyQualifiedPath(ctx, *field.message_type()),
180             lifetime);
181       }
182   }
183   ABSL_LOG(FATAL) << "Unsupported field type: " << field.type_name();
184   internal::Unreachable();
185 }
186 
RustModuleForContainingType(Context & ctx,const Descriptor * containing_type)187 std::string RustModuleForContainingType(Context& ctx,
188                                         const Descriptor* containing_type) {
189   std::vector<std::string> modules;
190 
191   // Innermost to outermost order.
192   const Descriptor* parent = containing_type;
193   while (parent != nullptr) {
194     modules.push_back(RsSafeName(CamelToSnakeCase(parent->name())));
195     parent = parent->containing_type();
196   }
197 
198   // Reverse the vector to get submodules in outer-to-inner order).
199   std::reverse(modules.begin(), modules.end());
200 
201   // If there are any modules at all, push an empty string on the end so that
202   // we get the trailing ::
203   if (!modules.empty()) {
204     modules.push_back("");
205   }
206 
207   return absl::StrJoin(modules, "::");
208 }
209 
RustModule(Context & ctx,const Descriptor & msg)210 std::string RustModule(Context& ctx, const Descriptor& msg) {
211   return RustModuleForContainingType(ctx, msg.containing_type());
212 }
213 
RustModule(Context & ctx,const EnumDescriptor & enum_)214 std::string RustModule(Context& ctx, const EnumDescriptor& enum_) {
215   return RustModuleForContainingType(ctx, enum_.containing_type());
216 }
217 
RustModule(Context & ctx,const OneofDescriptor & oneof)218 std::string RustModule(Context& ctx, const OneofDescriptor& oneof) {
219   return RustModuleForContainingType(ctx, oneof.containing_type());
220 }
221 
RustInternalModuleName(Context & ctx,const FileDescriptor & file)222 std::string RustInternalModuleName(Context& ctx, const FileDescriptor& file) {
223   return RsSafeName(
224       absl::StrReplaceAll(StripProto(file.name()), {{"_", "__"}, {"/", "_s"}}));
225 }
226 
GetCrateRelativeQualifiedPath(Context & ctx,const Descriptor & msg)227 std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg) {
228   return absl::StrCat(RustModule(ctx, msg), RsSafeName(msg.name()));
229 }
230 
GetCrateRelativeQualifiedPath(Context & ctx,const EnumDescriptor & enum_)231 std::string GetCrateRelativeQualifiedPath(Context& ctx,
232                                           const EnumDescriptor& enum_) {
233   return absl::StrCat(RustModule(ctx, enum_), EnumRsName(enum_));
234 }
235 
FieldInfoComment(Context & ctx,const FieldDescriptor & field)236 std::string FieldInfoComment(Context& ctx, const FieldDescriptor& field) {
237   absl::string_view label = field.is_repeated() ? "repeated" : "optional";
238   std::string comment = absl::StrCat(field.name(), ": ", label, " ",
239                                      FieldDescriptor::TypeName(field.type()));
240 
241   if (auto* m = field.message_type()) {
242     absl::StrAppend(&comment, " ", m->full_name());
243   }
244   if (auto* m = field.enum_type()) {
245     absl::StrAppend(&comment, " ", m->full_name());
246   }
247 
248   return comment;
249 }
250 
251 static constexpr absl::string_view kAccessorPrefixes[] = {"clear_", "has_",
252                                                           "set_"};
253 
254 static constexpr absl::string_view kAccessorSuffixes[] = {"_mut", "_opt"};
255 
FieldNameWithCollisionAvoidance(const FieldDescriptor & field)256 std::string FieldNameWithCollisionAvoidance(const FieldDescriptor& field) {
257   absl::string_view name = field.name();
258   const Descriptor& msg = *field.containing_type();
259 
260   for (absl::string_view prefix : kAccessorPrefixes) {
261     if (absl::StartsWith(name, prefix)) {
262       absl::string_view without_prefix = name;
263       without_prefix.remove_prefix(prefix.size());
264 
265       if (msg.FindFieldByName(without_prefix) != nullptr) {
266         return absl::StrCat(name, "_", field.number());
267       }
268     }
269   }
270 
271   for (absl::string_view suffix : kAccessorSuffixes) {
272     if (absl::EndsWith(name, suffix)) {
273       absl::string_view without_suffix = name;
274       without_suffix.remove_suffix(suffix.size());
275 
276       if (msg.FindFieldByName(without_suffix) != nullptr) {
277         return absl::StrCat(name, "_", field.number());
278       }
279     }
280   }
281 
282   return std::string(name);
283 }
284 
RsSafeName(absl::string_view name)285 std::string RsSafeName(absl::string_view name) {
286   if (!IsLegalRawIdentifierName(name)) {
287     return absl::StrCat(name,
288                         "__mangled_because_ident_isnt_a_legal_raw_identifier");
289   }
290   if (IsRustKeyword(name)) {
291     return absl::StrCat("r#", name);
292   }
293   return std::string(name);
294 }
295 
EnumRsName(const EnumDescriptor & desc)296 std::string EnumRsName(const EnumDescriptor& desc) {
297   return RsSafeName(SnakeToUpperCamelCase(desc.name()));
298 }
299 
EnumValueRsName(const EnumValueDescriptor & value)300 std::string EnumValueRsName(const EnumValueDescriptor& value) {
301   MultiCasePrefixStripper stripper(value.type()->name());
302   return EnumValueRsName(stripper, value.name());
303 }
304 
EnumValueRsName(const MultiCasePrefixStripper & stripper,absl::string_view value_name)305 std::string EnumValueRsName(const MultiCasePrefixStripper& stripper,
306                             absl::string_view value_name) {
307   // Enum values may have a prefix of the name of the enum stripped from the
308   // value names in the gencode. This prefix is flexible:
309   // - It can be the original enum name, the name as UpperCamel, or snake_case.
310   // - The stripped prefix may also end in an underscore.
311   auto stripped = stripper.StripPrefix(value_name);
312 
313   auto name = ScreamingSnakeToUpperCamelCase(stripped);
314   ABSL_CHECK(!name.empty());
315 
316   // Invalid identifiers are prefixed with `_`.
317   if (absl::ascii_isdigit(name[0])) {
318     name = absl::StrCat("_", name);
319   }
320   return RsSafeName(name);
321 }
322 
OneofViewEnumRsName(const OneofDescriptor & oneof)323 std::string OneofViewEnumRsName(const OneofDescriptor& oneof) {
324   return RsSafeName(SnakeToUpperCamelCase(oneof.name()));
325 }
326 
OneofCaseEnumRsName(const OneofDescriptor & oneof)327 std::string OneofCaseEnumRsName(const OneofDescriptor& oneof) {
328   // Note: This is the name used for the cpp Case enum, we use it for both
329   // the Rust Case enum as well as for the cpp case enum in the cpp thunk.
330   return SnakeToUpperCamelCase(oneof.name()) + "Case";
331 }
332 
OneofCaseRsName(const FieldDescriptor & oneof_field)333 std::string OneofCaseRsName(const FieldDescriptor& oneof_field) {
334   return RsSafeName(SnakeToUpperCamelCase(oneof_field.name()));
335 }
336 
CamelToSnakeCase(absl::string_view input)337 std::string CamelToSnakeCase(absl::string_view input) {
338   std::string result;
339   result.reserve(input.size() + 4);  // No reallocation for 4 _
340   bool is_first_character = true;
341   bool last_char_was_underscore = false;
342   for (const char c : input) {
343     if (!is_first_character && absl::ascii_isupper(c) &&
344         !last_char_was_underscore) {
345       result += '_';
346     }
347     last_char_was_underscore = c == '_';
348     result += absl::ascii_tolower(c);
349     is_first_character = false;
350   }
351   return result;
352 }
353 
SnakeToUpperCamelCase(absl::string_view input)354 std::string SnakeToUpperCamelCase(absl::string_view input) {
355   return cpp::UnderscoresToCamelCase(input, /*cap first letter=*/true);
356 }
357 
ScreamingSnakeToUpperCamelCase(absl::string_view input)358 std::string ScreamingSnakeToUpperCamelCase(absl::string_view input) {
359   std::string result;
360   result.reserve(input.size());
361   bool cap_next_letter = true;
362   for (const char c : input) {
363     if (absl::ascii_isalpha(c)) {
364       if (cap_next_letter) {
365         result += absl::ascii_toupper(c);
366       } else {
367         result += absl::ascii_tolower(c);
368       }
369       cap_next_letter = false;
370     } else if (absl::ascii_isdigit(c)) {
371       result += c;
372       cap_next_letter = true;
373     } else {
374       cap_next_letter = true;
375     }
376   }
377   return result;
378 }
379 
MultiCasePrefixStripper(absl::string_view prefix)380 MultiCasePrefixStripper::MultiCasePrefixStripper(absl::string_view prefix)
381     : prefixes_{
382           std::string(prefix),
383           ScreamingSnakeToUpperCamelCase(prefix),
384           CamelToSnakeCase(prefix),
385       } {}
386 
StripPrefix(absl::string_view name) const387 absl::string_view MultiCasePrefixStripper::StripPrefix(
388     absl::string_view name) const {
389   absl::string_view start_name = name;
390   for (absl::string_view prefix : prefixes_) {
391     if (absl::StartsWithIgnoreCase(name, prefix)) {
392       name.remove_prefix(prefix.size());
393 
394       // Also strip a joining underscore, if present.
395       absl::ConsumePrefix(&name, "_");
396 
397       // Only strip one prefix.
398       break;
399     }
400   }
401 
402   if (name.empty()) {
403     return start_name;
404   }
405   return name;
406 }
407 
408 PROTOBUF_CONSTINIT const MapKeyType kMapKeyTypes[] = {
409     {/*thunk_ident=*/"i32", /*rs_key_t=*/"i32", /*rs_ffi_key_t=*/"i32",
410      /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
411      /*cc_key_t=*/"int32_t", /*cc_ffi_key_t=*/"int32_t",
412      /*cc_from_ffi_key_expr=*/"key",
413      /*cc_to_ffi_key_expr=*/"cpp_key"},
414     {/*thunk_ident=*/"u32", /*rs_key_t=*/"u32", /*rs_ffi_key_t=*/"u32",
415      /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
416      /*cc_key_t=*/"uint32_t", /*cc_ffi_key_t=*/"uint32_t",
417      /*cc_from_ffi_key_expr=*/"key",
418      /*cc_to_ffi_key_expr=*/"cpp_key"},
419     {/*thunk_ident=*/"i64", /*rs_key_t=*/"i64", /*rs_ffi_key_t=*/"i64",
420      /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
421      /*cc_key_t=*/"int64_t", /*cc_ffi_key_t=*/"int64_t",
422      /*cc_from_ffi_key_expr=*/"key",
423      /*cc_to_ffi_key_expr=*/"cpp_key"},
424     {/*thunk_ident=*/"u64", /*rs_key_t=*/"u64", /*rs_ffi_key_t=*/"u64",
425      /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
426      /*cc_key_t=*/"uint64_t", /*cc_ffi_key_t=*/"uint64_t",
427      /*cc_from_ffi_key_expr=*/"key",
428      /*cc_to_ffi_key_expr=*/"cpp_key"},
429     {/*thunk_ident=*/"bool", /*rs_key_t=*/"bool", /*rs_ffi_key_t=*/"bool",
430      /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
431      /*cc_key_t=*/"bool", /*cc_ffi_key_t=*/"bool",
432      /*cc_from_ffi_key_expr=*/"key",
433      /*cc_to_ffi_key_expr=*/"cpp_key"},
434     {/*thunk_ident=*/"ProtoString",
435      /*rs_key_t=*/"$pb$::ProtoString",
436      /*rs_ffi_key_t=*/"$pbr$::PtrAndLen",
437      /*rs_to_ffi_key_expr=*/"key.as_bytes().into()",
438      /*rs_from_ffi_key_expr=*/
439      "$pb$::ProtoStr::from_utf8_unchecked(ffi_key.as_ref())",
440      /*cc_key_t=*/"std::string",
441      /*cc_ffi_key_t=*/"google::protobuf::rust::PtrAndLen",
442      /*cc_from_ffi_key_expr=*/
443      "std::string(key.ptr, key.len)", /*cc_to_ffi_key_expr=*/
444      "google::protobuf::rust::PtrAndLen{cpp_key.data(), cpp_key.size()}"}};
445 
446 }  // namespace rust
447 }  // namespace compiler
448 }  // namespace protobuf
449 }  // namespace google
450 
451 #include "google/protobuf/port_undef.inc"
452