1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2023 Google LLC. All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7
8 #include "google/protobuf/compiler/rust/naming.h"
9
10 #include <algorithm>
11 #include <string>
12 #include <vector>
13
14 #include "absl/log/absl_check.h"
15 #include "absl/log/absl_log.h"
16 #include "absl/strings/ascii.h"
17 #include "absl/strings/match.h"
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_format.h"
20 #include "absl/strings/str_join.h"
21 #include "absl/strings/str_replace.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/strings/strip.h"
24 #include "absl/strings/substitute.h"
25 #include "google/protobuf/compiler/code_generator.h"
26 #include "google/protobuf/compiler/cpp/helpers.h"
27 #include "google/protobuf/compiler/rust/context.h"
28 #include "google/protobuf/compiler/rust/rust_field_type.h"
29 #include "google/protobuf/compiler/rust/rust_keywords.h"
30 #include "google/protobuf/descriptor.h"
31 #include "google/protobuf/port.h"
32
33 // Must be included last.
34 #include "google/protobuf/port_def.inc"
35
36 namespace google {
37 namespace protobuf {
38 namespace compiler {
39 namespace rust {
40
GetCrateName(Context & ctx,const FileDescriptor & dep)41 std::string GetCrateName(Context& ctx, const FileDescriptor& dep) {
42 return RsSafeName(ctx.ImportPathToCrateName(dep.name()));
43 }
44
GetRsFile(Context & ctx,const FileDescriptor & file)45 std::string GetRsFile(Context& ctx, const FileDescriptor& file) {
46 auto basename = StripProto(file.name());
47 switch (auto k = ctx.opts().kernel) {
48 case Kernel::kUpb:
49 return absl::StrCat(basename, ".u.pb.rs");
50 case Kernel::kCpp:
51 return absl::StrCat(basename, ".c.pb.rs");
52 default:
53 ABSL_LOG(FATAL) << "Unknown kernel type: " << static_cast<int>(k);
54 return "";
55 }
56 }
57
GetThunkCcFile(Context & ctx,const FileDescriptor & file)58 std::string GetThunkCcFile(Context& ctx, const FileDescriptor& file) {
59 auto basename = StripProto(file.name());
60 return absl::StrCat(basename, ".pb.thunks.cc");
61 }
62
GetHeaderFile(Context & ctx,const FileDescriptor & file)63 std::string GetHeaderFile(Context& ctx, const FileDescriptor& file) {
64 auto basename = StripProto(file.name());
65 constexpr absl::string_view kCcGencodeExt = ".pb.h";
66
67 return absl::StrCat(basename, kCcGencodeExt);
68 }
69
RawMapThunk(Context & ctx,const Descriptor & msg,absl::string_view key_t,absl::string_view op)70 std::string RawMapThunk(Context& ctx, const Descriptor& msg,
71 absl::string_view key_t, absl::string_view op) {
72 return absl::StrCat("proto2_rust_thunk_Map_", key_t, "_",
73 GetUnderscoreDelimitedFullName(ctx, *&msg), "_", op);
74 }
75
RawMapThunk(Context & ctx,const EnumDescriptor & desc,absl::string_view key_t,absl::string_view op)76 std::string RawMapThunk(Context& ctx, const EnumDescriptor& desc,
77 absl::string_view key_t, absl::string_view op) {
78 // Enums are always 32 bits.
79 return absl::StrCat("proto2_rust_thunk_Map_", key_t, "_i32_", op);
80 }
81
ThunkName(Context & ctx,const FieldDescriptor & field,absl::string_view op)82 std::string ThunkName(Context& ctx, const FieldDescriptor& field,
83 absl::string_view op) {
84 ABSL_CHECK(ctx.is_cpp());
85 return absl::StrCat("proto2_rust_thunk_",
86 UnderscoreDelimitFullName(ctx, field.full_name()), "_",
87 op);
88 }
89
ThunkName(Context & ctx,const OneofDescriptor & field,absl::string_view op)90 std::string ThunkName(Context& ctx, const OneofDescriptor& field,
91 absl::string_view op) {
92 ABSL_CHECK(ctx.is_cpp());
93 return absl::StrCat("proto2_rust_thunk_",
94 UnderscoreDelimitFullName(ctx, field.full_name()), "_",
95 op);
96 }
97
ThunkName(Context & ctx,const Descriptor & msg,absl::string_view op)98 std::string ThunkName(Context& ctx, const Descriptor& msg,
99 absl::string_view op) {
100 absl::string_view prefix = ctx.is_cpp() ? "proto2_rust_thunk_Message_" : "";
101 return absl::StrCat(prefix, GetUnderscoreDelimitedFullName(ctx, msg), "_",
102 op);
103 }
104
105 template <typename Desc>
GetFullyQualifiedPath(Context & ctx,const Desc & desc)106 std::string GetFullyQualifiedPath(Context& ctx, const Desc& desc) {
107 auto rel_path = GetCrateRelativeQualifiedPath(ctx, desc);
108 if (IsInCurrentlyGeneratingCrate(ctx, desc)) {
109 return absl::StrCat("crate::", rel_path);
110 }
111 return absl::StrCat(GetCrateName(ctx, *desc.file()), "::", rel_path);
112 }
113
114 template <typename Desc>
GetUnderscoreDelimitedFullName(Context & ctx,const Desc & desc)115 std::string GetUnderscoreDelimitedFullName(Context& ctx, const Desc& desc) {
116 return UnderscoreDelimitFullName(ctx, desc.full_name());
117 }
118
UnderscoreDelimitFullName(Context & ctx,absl::string_view full_name)119 std::string UnderscoreDelimitFullName(Context& ctx,
120 absl::string_view full_name) {
121 std::string result = std::string(full_name);
122 absl::StrReplaceAll({{".", "_"}}, &result);
123 return result;
124 }
125
RsTypePath(Context & ctx,const FieldDescriptor & field)126 std::string RsTypePath(Context& ctx, const FieldDescriptor& field) {
127 switch (GetRustFieldType(field)) {
128 case RustFieldType::BOOL:
129 return "bool";
130 case RustFieldType::INT32:
131 return "i32";
132 case RustFieldType::INT64:
133 return "i64";
134 case RustFieldType::UINT32:
135 return "u32";
136 case RustFieldType::UINT64:
137 return "u64";
138 case RustFieldType::FLOAT:
139 return "f32";
140 case RustFieldType::DOUBLE:
141 return "f64";
142 case RustFieldType::BYTES:
143 return "::__pb::ProtoBytes";
144 case RustFieldType::STRING:
145 return "::__pb::ProtoString";
146 case RustFieldType::MESSAGE:
147 return GetFullyQualifiedPath(ctx, *field.message_type());
148 case RustFieldType::ENUM:
149 return GetFullyQualifiedPath(ctx, *field.enum_type());
150 }
151 ABSL_LOG(ERROR) << "Unknown field type: " << field.type_name();
152 internal::Unreachable();
153 }
154
RsViewType(Context & ctx,const FieldDescriptor & field,absl::string_view lifetime)155 std::string RsViewType(Context& ctx, const FieldDescriptor& field,
156 absl::string_view lifetime) {
157 switch (GetRustFieldType(field)) {
158 case RustFieldType::BOOL:
159 case RustFieldType::INT32:
160 case RustFieldType::INT64:
161 case RustFieldType::UINT32:
162 case RustFieldType::UINT64:
163 case RustFieldType::FLOAT:
164 case RustFieldType::DOUBLE:
165 case RustFieldType::ENUM:
166 // The View type of all scalars and enums can be spelled as the type
167 // itself.
168 return RsTypePath(ctx, field);
169 case RustFieldType::BYTES:
170 return absl::StrFormat("&%s [u8]", lifetime);
171 case RustFieldType::STRING:
172 return absl::StrFormat("&%s ::__pb::ProtoStr", lifetime);
173 case RustFieldType::MESSAGE:
174 if (lifetime.empty()) {
175 return absl::StrFormat(
176 "%sView", GetFullyQualifiedPath(ctx, *field.message_type()));
177 } else {
178 return absl::StrFormat(
179 "%sView<%s>", GetFullyQualifiedPath(ctx, *field.message_type()),
180 lifetime);
181 }
182 }
183 ABSL_LOG(FATAL) << "Unsupported field type: " << field.type_name();
184 internal::Unreachable();
185 }
186
RustModuleForContainingType(Context & ctx,const Descriptor * containing_type)187 std::string RustModuleForContainingType(Context& ctx,
188 const Descriptor* containing_type) {
189 std::vector<std::string> modules;
190
191 // Innermost to outermost order.
192 const Descriptor* parent = containing_type;
193 while (parent != nullptr) {
194 modules.push_back(RsSafeName(CamelToSnakeCase(parent->name())));
195 parent = parent->containing_type();
196 }
197
198 // Reverse the vector to get submodules in outer-to-inner order).
199 std::reverse(modules.begin(), modules.end());
200
201 // If there are any modules at all, push an empty string on the end so that
202 // we get the trailing ::
203 if (!modules.empty()) {
204 modules.push_back("");
205 }
206
207 return absl::StrJoin(modules, "::");
208 }
209
RustModule(Context & ctx,const Descriptor & msg)210 std::string RustModule(Context& ctx, const Descriptor& msg) {
211 return RustModuleForContainingType(ctx, msg.containing_type());
212 }
213
RustModule(Context & ctx,const EnumDescriptor & enum_)214 std::string RustModule(Context& ctx, const EnumDescriptor& enum_) {
215 return RustModuleForContainingType(ctx, enum_.containing_type());
216 }
217
RustModule(Context & ctx,const OneofDescriptor & oneof)218 std::string RustModule(Context& ctx, const OneofDescriptor& oneof) {
219 return RustModuleForContainingType(ctx, oneof.containing_type());
220 }
221
RustInternalModuleName(Context & ctx,const FileDescriptor & file)222 std::string RustInternalModuleName(Context& ctx, const FileDescriptor& file) {
223 return RsSafeName(
224 absl::StrReplaceAll(StripProto(file.name()), {{"_", "__"}, {"/", "_s"}}));
225 }
226
GetCrateRelativeQualifiedPath(Context & ctx,const Descriptor & msg)227 std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg) {
228 return absl::StrCat(RustModule(ctx, msg), RsSafeName(msg.name()));
229 }
230
GetCrateRelativeQualifiedPath(Context & ctx,const EnumDescriptor & enum_)231 std::string GetCrateRelativeQualifiedPath(Context& ctx,
232 const EnumDescriptor& enum_) {
233 return absl::StrCat(RustModule(ctx, enum_), EnumRsName(enum_));
234 }
235
FieldInfoComment(Context & ctx,const FieldDescriptor & field)236 std::string FieldInfoComment(Context& ctx, const FieldDescriptor& field) {
237 absl::string_view label = field.is_repeated() ? "repeated" : "optional";
238 std::string comment = absl::StrCat(field.name(), ": ", label, " ",
239 FieldDescriptor::TypeName(field.type()));
240
241 if (auto* m = field.message_type()) {
242 absl::StrAppend(&comment, " ", m->full_name());
243 }
244 if (auto* m = field.enum_type()) {
245 absl::StrAppend(&comment, " ", m->full_name());
246 }
247
248 return comment;
249 }
250
251 static constexpr absl::string_view kAccessorPrefixes[] = {"clear_", "has_",
252 "set_"};
253
254 static constexpr absl::string_view kAccessorSuffixes[] = {"_mut", "_opt"};
255
FieldNameWithCollisionAvoidance(const FieldDescriptor & field)256 std::string FieldNameWithCollisionAvoidance(const FieldDescriptor& field) {
257 absl::string_view name = field.name();
258 const Descriptor& msg = *field.containing_type();
259
260 for (absl::string_view prefix : kAccessorPrefixes) {
261 if (absl::StartsWith(name, prefix)) {
262 absl::string_view without_prefix = name;
263 without_prefix.remove_prefix(prefix.size());
264
265 if (msg.FindFieldByName(without_prefix) != nullptr) {
266 return absl::StrCat(name, "_", field.number());
267 }
268 }
269 }
270
271 for (absl::string_view suffix : kAccessorSuffixes) {
272 if (absl::EndsWith(name, suffix)) {
273 absl::string_view without_suffix = name;
274 without_suffix.remove_suffix(suffix.size());
275
276 if (msg.FindFieldByName(without_suffix) != nullptr) {
277 return absl::StrCat(name, "_", field.number());
278 }
279 }
280 }
281
282 return std::string(name);
283 }
284
RsSafeName(absl::string_view name)285 std::string RsSafeName(absl::string_view name) {
286 if (!IsLegalRawIdentifierName(name)) {
287 return absl::StrCat(name,
288 "__mangled_because_ident_isnt_a_legal_raw_identifier");
289 }
290 if (IsRustKeyword(name)) {
291 return absl::StrCat("r#", name);
292 }
293 return std::string(name);
294 }
295
EnumRsName(const EnumDescriptor & desc)296 std::string EnumRsName(const EnumDescriptor& desc) {
297 return RsSafeName(SnakeToUpperCamelCase(desc.name()));
298 }
299
EnumValueRsName(const EnumValueDescriptor & value)300 std::string EnumValueRsName(const EnumValueDescriptor& value) {
301 MultiCasePrefixStripper stripper(value.type()->name());
302 return EnumValueRsName(stripper, value.name());
303 }
304
EnumValueRsName(const MultiCasePrefixStripper & stripper,absl::string_view value_name)305 std::string EnumValueRsName(const MultiCasePrefixStripper& stripper,
306 absl::string_view value_name) {
307 // Enum values may have a prefix of the name of the enum stripped from the
308 // value names in the gencode. This prefix is flexible:
309 // - It can be the original enum name, the name as UpperCamel, or snake_case.
310 // - The stripped prefix may also end in an underscore.
311 auto stripped = stripper.StripPrefix(value_name);
312
313 auto name = ScreamingSnakeToUpperCamelCase(stripped);
314 ABSL_CHECK(!name.empty());
315
316 // Invalid identifiers are prefixed with `_`.
317 if (absl::ascii_isdigit(name[0])) {
318 name = absl::StrCat("_", name);
319 }
320 return RsSafeName(name);
321 }
322
OneofViewEnumRsName(const OneofDescriptor & oneof)323 std::string OneofViewEnumRsName(const OneofDescriptor& oneof) {
324 return RsSafeName(SnakeToUpperCamelCase(oneof.name()));
325 }
326
OneofCaseEnumRsName(const OneofDescriptor & oneof)327 std::string OneofCaseEnumRsName(const OneofDescriptor& oneof) {
328 // Note: This is the name used for the cpp Case enum, we use it for both
329 // the Rust Case enum as well as for the cpp case enum in the cpp thunk.
330 return SnakeToUpperCamelCase(oneof.name()) + "Case";
331 }
332
OneofCaseRsName(const FieldDescriptor & oneof_field)333 std::string OneofCaseRsName(const FieldDescriptor& oneof_field) {
334 return RsSafeName(SnakeToUpperCamelCase(oneof_field.name()));
335 }
336
CamelToSnakeCase(absl::string_view input)337 std::string CamelToSnakeCase(absl::string_view input) {
338 std::string result;
339 result.reserve(input.size() + 4); // No reallocation for 4 _
340 bool is_first_character = true;
341 bool last_char_was_underscore = false;
342 for (const char c : input) {
343 if (!is_first_character && absl::ascii_isupper(c) &&
344 !last_char_was_underscore) {
345 result += '_';
346 }
347 last_char_was_underscore = c == '_';
348 result += absl::ascii_tolower(c);
349 is_first_character = false;
350 }
351 return result;
352 }
353
SnakeToUpperCamelCase(absl::string_view input)354 std::string SnakeToUpperCamelCase(absl::string_view input) {
355 return cpp::UnderscoresToCamelCase(input, /*cap first letter=*/true);
356 }
357
ScreamingSnakeToUpperCamelCase(absl::string_view input)358 std::string ScreamingSnakeToUpperCamelCase(absl::string_view input) {
359 std::string result;
360 result.reserve(input.size());
361 bool cap_next_letter = true;
362 for (const char c : input) {
363 if (absl::ascii_isalpha(c)) {
364 if (cap_next_letter) {
365 result += absl::ascii_toupper(c);
366 } else {
367 result += absl::ascii_tolower(c);
368 }
369 cap_next_letter = false;
370 } else if (absl::ascii_isdigit(c)) {
371 result += c;
372 cap_next_letter = true;
373 } else {
374 cap_next_letter = true;
375 }
376 }
377 return result;
378 }
379
MultiCasePrefixStripper(absl::string_view prefix)380 MultiCasePrefixStripper::MultiCasePrefixStripper(absl::string_view prefix)
381 : prefixes_{
382 std::string(prefix),
383 ScreamingSnakeToUpperCamelCase(prefix),
384 CamelToSnakeCase(prefix),
385 } {}
386
StripPrefix(absl::string_view name) const387 absl::string_view MultiCasePrefixStripper::StripPrefix(
388 absl::string_view name) const {
389 absl::string_view start_name = name;
390 for (absl::string_view prefix : prefixes_) {
391 if (absl::StartsWithIgnoreCase(name, prefix)) {
392 name.remove_prefix(prefix.size());
393
394 // Also strip a joining underscore, if present.
395 absl::ConsumePrefix(&name, "_");
396
397 // Only strip one prefix.
398 break;
399 }
400 }
401
402 if (name.empty()) {
403 return start_name;
404 }
405 return name;
406 }
407
408 PROTOBUF_CONSTINIT const MapKeyType kMapKeyTypes[] = {
409 {/*thunk_ident=*/"i32", /*rs_key_t=*/"i32", /*rs_ffi_key_t=*/"i32",
410 /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
411 /*cc_key_t=*/"int32_t", /*cc_ffi_key_t=*/"int32_t",
412 /*cc_from_ffi_key_expr=*/"key",
413 /*cc_to_ffi_key_expr=*/"cpp_key"},
414 {/*thunk_ident=*/"u32", /*rs_key_t=*/"u32", /*rs_ffi_key_t=*/"u32",
415 /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
416 /*cc_key_t=*/"uint32_t", /*cc_ffi_key_t=*/"uint32_t",
417 /*cc_from_ffi_key_expr=*/"key",
418 /*cc_to_ffi_key_expr=*/"cpp_key"},
419 {/*thunk_ident=*/"i64", /*rs_key_t=*/"i64", /*rs_ffi_key_t=*/"i64",
420 /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
421 /*cc_key_t=*/"int64_t", /*cc_ffi_key_t=*/"int64_t",
422 /*cc_from_ffi_key_expr=*/"key",
423 /*cc_to_ffi_key_expr=*/"cpp_key"},
424 {/*thunk_ident=*/"u64", /*rs_key_t=*/"u64", /*rs_ffi_key_t=*/"u64",
425 /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
426 /*cc_key_t=*/"uint64_t", /*cc_ffi_key_t=*/"uint64_t",
427 /*cc_from_ffi_key_expr=*/"key",
428 /*cc_to_ffi_key_expr=*/"cpp_key"},
429 {/*thunk_ident=*/"bool", /*rs_key_t=*/"bool", /*rs_ffi_key_t=*/"bool",
430 /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
431 /*cc_key_t=*/"bool", /*cc_ffi_key_t=*/"bool",
432 /*cc_from_ffi_key_expr=*/"key",
433 /*cc_to_ffi_key_expr=*/"cpp_key"},
434 {/*thunk_ident=*/"ProtoString",
435 /*rs_key_t=*/"$pb$::ProtoString",
436 /*rs_ffi_key_t=*/"$pbr$::PtrAndLen",
437 /*rs_to_ffi_key_expr=*/"key.as_bytes().into()",
438 /*rs_from_ffi_key_expr=*/
439 "$pb$::ProtoStr::from_utf8_unchecked(ffi_key.as_ref())",
440 /*cc_key_t=*/"std::string",
441 /*cc_ffi_key_t=*/"google::protobuf::rust::PtrAndLen",
442 /*cc_from_ffi_key_expr=*/
443 "std::string(key.ptr, key.len)", /*cc_to_ffi_key_expr=*/
444 "google::protobuf::rust::PtrAndLen{cpp_key.data(), cpp_key.size()}"}};
445
446 } // namespace rust
447 } // namespace compiler
448 } // namespace protobuf
449 } // namespace google
450
451 #include "google/protobuf/port_undef.inc"
452