1 //! A process wrapper for running a Protobuf compiler configured for Prost or Tonic output in a Bazel rule.
2
3 use std::collections::BTreeMap;
4 use std::collections::BTreeSet;
5 use std::fmt::{Display, Formatter, Write};
6 use std::fs;
7 use std::io::BufRead;
8 use std::path::Path;
9 use std::path::PathBuf;
10 use std::process;
11 use std::{env, fmt};
12
13 use heck::{ToSnakeCase, ToUpperCamelCase};
14 use prost::Message;
15 use prost_types::{
16 DescriptorProto, EnumDescriptorProto, FileDescriptorProto, FileDescriptorSet,
17 OneofDescriptorProto,
18 };
19
/// Recursively collect every Rust source file that protoc wrote under `out_dir`.
///
/// Files with an `rs` extension are returned as-is. A file literally named `_` (what the
/// generator emits when the proto package is empty) is renamed in place to `_.rs` and the
/// renamed path is returned instead.
///
/// # Panics
///
/// Panics if a directory or directory entry cannot be read, or if `_` cannot be renamed.
fn find_generated_rust_files(out_dir: &Path) -> BTreeSet<PathBuf> {
    let mut found = BTreeSet::new();

    let listing = fs::read_dir(out_dir).expect("Failed to read directory");
    for dir_entry in listing {
        let path = dir_entry.expect("Failed to read entry").path();

        if path.is_dir() {
            found.extend(find_generated_rust_files(&path));
            continue;
        }

        match path.extension() {
            Some(extension) => {
                if extension == "rs" {
                    found.insert(path);
                }
            }
            None => {
                // The filename is set to `_` when the package name is empty.
                let is_placeholder = path.file_name().map_or(false, |file| file == "_");
                if is_placeholder {
                    let rs_name = path.parent().expect("Failed to get parent").join("_.rs");
                    fs::rename(&path, &rs_name).unwrap_or_else(|err| {
                        panic!("Failed to rename file: {err:?}: {path:?} -> {rs_name:?}")
                    });
                    found.insert(rs_name);
                }
            }
        }
    }

    found
}
48
snake_cased_package_name(package: &str) -> String49 fn snake_cased_package_name(package: &str) -> String {
50 if package == "_" {
51 return package.to_owned();
52 }
53
54 package
55 .split('.')
56 .map(|s| s.to_snake_case())
57 .collect::<Vec<_>>()
58 .join(".")
59 }
60
/// One node of the module tree that is rendered into the generated `lib.rs`.
#[derive(Debug, Default)]
struct Module {
    /// Identifier used for this node's `pub mod` declaration; `_` means "emit the
    /// contents without a surrounding module".
    name: String,

    /// Generated Rust source placed inside this module (empty for modules that only exist
    /// as parents of other packages).
    contents: String,

    /// Last path component of each direct child module.
    submodules: BTreeSet<String>,
}
73
74 /// Generate a lib.rs file with all prost/tonic outputs embeeded in modules which
75 /// mirror the proto packages. For the example proto file we would expect to see
76 /// the Rust output that follows it.
77 ///
78 /// ```proto
79 /// syntax = "proto3";
80 /// package examples.prost.helloworld;
81 ///
82 /// message HelloRequest {
83 /// // Request message contains the name to be greeted
84 /// string name = 1;
85 /// }
86 //
87 /// message HelloReply {
88 /// // Reply contains the greeting message
89 /// string message = 1;
90 /// }
91 /// ```
92 ///
93 /// This is expected to render out to something like the following. Note that
94 /// formatting is not applied so indentation may be missing in the actual output.
95 ///
96 /// ```ignore
97 /// pub mod examples {
98 /// pub mod prost {
99 /// pub mod helloworld {
100 /// // @generated
101 /// #[allow(clippy::derive_partial_eq_without_eq)]
102 /// #[derive(Clone, PartialEq, ::prost::Message)]
103 /// pub struct HelloRequest {
104 /// /// Request message contains the name to be greeted
105 /// #[prost(string, tag = "1")]
106 /// pub name: ::prost::alloc::string::String,
107 /// }
108 /// #[allow(clippy::derive_partial_eq_without_eq)]
109 /// #[derive(Clone, PartialEq, ::prost::Message)]
110 /// pub struct HelloReply {
111 /// /// Reply contains the greeting message
112 /// #[prost(string, tag = "1")]
113 /// pub message: ::prost::alloc::string::String,
114 /// }
115 /// // @protoc_insertion_point(module)
116 /// }
117 /// }
118 /// }
119 /// ```
generate_lib_rs(prost_outputs: &BTreeSet<PathBuf>, is_tonic: bool) -> String120 fn generate_lib_rs(prost_outputs: &BTreeSet<PathBuf>, is_tonic: bool) -> String {
121 let mut module_info = BTreeMap::new();
122
123 for path in prost_outputs.iter() {
124 let mut package = path
125 .file_stem()
126 .expect("Failed to get file stem")
127 .to_str()
128 .expect("Failed to convert to str")
129 .to_string();
130
131 if is_tonic {
132 package = package
133 .strip_suffix(".tonic")
134 .expect("Failed to strip suffix")
135 .to_string()
136 };
137
138 if package.is_empty() {
139 continue;
140 }
141
142 let name = if package == "_" {
143 package.clone()
144 } else if package.contains('.') {
145 package
146 .rsplit_once('.')
147 .expect("Failed to split on '.'")
148 .1
149 .to_snake_case()
150 .to_string()
151 } else {
152 package.to_snake_case()
153 };
154
155 // Avoid a stack overflow by skipping a known bad package name
156 let module_name = snake_cased_package_name(&package);
157
158 module_info.insert(
159 module_name.clone(),
160 Module {
161 name,
162 contents: fs::read_to_string(path).expect("Failed to read file"),
163 submodules: BTreeSet::new(),
164 },
165 );
166
167 let module_parts = module_name.split('.').collect::<Vec<_>>();
168 for parent_module_index in 0..module_parts.len() {
169 let child_module_index = parent_module_index + 1;
170 if child_module_index >= module_parts.len() {
171 break;
172 }
173 let full_parent_module_name = module_parts[0..parent_module_index + 1].join(".");
174 let parent_module_name = module_parts[parent_module_index];
175 let child_module_name = module_parts[child_module_index];
176
177 module_info
178 .entry(full_parent_module_name.clone())
179 .and_modify(|parent_module| {
180 parent_module
181 .submodules
182 .insert(child_module_name.to_string());
183 })
184 .or_insert(Module {
185 name: parent_module_name.to_string(),
186 contents: "".to_string(),
187 submodules: [child_module_name.to_string()].iter().cloned().collect(),
188 });
189 }
190 }
191
192 let mut content = "// @generated\n\n".to_string();
193 write_module(&mut content, &module_info, "", 0);
194 content
195 }
196
197 /// Write out a rust module and all of its submodules.
write_module( content: &mut String, module_info: &BTreeMap<String, Module>, module_name: &str, depth: usize, )198 fn write_module(
199 content: &mut String,
200 module_info: &BTreeMap<String, Module>,
201 module_name: &str,
202 depth: usize,
203 ) {
204 if module_name.is_empty() {
205 for submodule_name in module_info.keys() {
206 write_module(content, module_info, submodule_name, depth + 1);
207 }
208 return;
209 }
210 let module = module_info.get(module_name).expect("Failed to get module");
211 let indent = " ".repeat(depth);
212 let is_rust_module = module.name != "_";
213
214 if is_rust_module {
215 let rust_module_name = escape_keyword(module.name.clone());
216 content
217 .write_str(&format!("{}pub mod {} {{\n", indent, rust_module_name))
218 .expect("Failed to write string");
219 }
220
221 content
222 .write_str(&module.contents)
223 .expect("Failed to write string");
224
225 for submodule_name in module.submodules.iter() {
226 write_module(
227 content,
228 module_info,
229 [module_name, submodule_name].join(".").as_str(),
230 depth + 1,
231 );
232 }
233
234 if is_rust_module {
235 content
236 .write_str(&format!("{}}}\n", indent))
237 .expect("Failed to write string");
238 }
239 }
240
/// A dotted protobuf path naming a message, enum, or oneof.
///
/// Example: `helloworld.Greeter.HelloRequest`
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct ProtoPath(String);

impl ProtoPath {
    /// Append `component` using `.` as the separator.
    ///
    /// If either side is empty, the other side is returned without a separator.
    fn join(&self, component: &str) -> ProtoPath {
        let joined = match (self.0.is_empty(), component.is_empty()) {
            (true, _) => component.to_string(),
            (false, true) => self.0.clone(),
            (false, false) => format!("{}.{}", self.0, component),
        };
        ProtoPath(joined)
    }
}

impl Display for ProtoPath {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<&str> for ProtoPath {
    fn from(path: &str) -> Self {
        Self(path.to_owned())
    }
}
272
/// A `::`-separated Rust path to a module or item.
///
/// Example: `helloworld::greeter::HelloRequest`
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct RustModulePath(String);

impl RustModulePath {
    /// Append `path` using `::` as the separator.
    ///
    /// If either side is empty, the other side is returned without a separator.
    fn join(&self, path: &str) -> RustModulePath {
        if !self.0.is_empty() && !path.is_empty() {
            RustModulePath(format!("{}::{}", self.0, path))
        } else if self.0.is_empty() {
            RustModulePath(path.to_string())
        } else {
            self.clone()
        }
    }
}

impl Display for RustModulePath {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<&str> for RustModulePath {
    fn from(path: &str) -> Self {
        Self(path.to_owned())
    }
}
304
305 /// Compute the `--extern_path` flags for a list of proto files. This is
306 /// expected to convert proto files into a BTreeMap of
307 /// `example.prost.helloworld`: `crate_name::example::prost::helloworld`.
get_extern_paths( descriptor_set: &FileDescriptorSet, crate_name: &str, ) -> Result<BTreeMap<ProtoPath, RustModulePath>, String>308 fn get_extern_paths(
309 descriptor_set: &FileDescriptorSet,
310 crate_name: &str,
311 ) -> Result<BTreeMap<ProtoPath, RustModulePath>, String> {
312 let mut extern_paths = BTreeMap::new();
313 let rust_path = RustModulePath(crate_name.to_string());
314
315 for file in descriptor_set.file.iter() {
316 descriptor_set_file_to_extern_paths(&mut extern_paths, &rust_path, file);
317 }
318
319 Ok(extern_paths)
320 }
321
322 /// Add the extern_path pairs for a file descriptor type.
descriptor_set_file_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, rust_path: &RustModulePath, file: &FileDescriptorProto, )323 fn descriptor_set_file_to_extern_paths(
324 extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
325 rust_path: &RustModulePath,
326 file: &FileDescriptorProto,
327 ) {
328 let package = file.package.clone().unwrap_or_default();
329 let rust_path = rust_path.join(&snake_cased_package_name(&package).replace('.', "::"));
330 let proto_path = ProtoPath(package);
331
332 for message_type in file.message_type.iter() {
333 message_type_to_extern_paths(extern_paths, &proto_path, &rust_path, message_type);
334 }
335
336 for enum_type in file.enum_type.iter() {
337 enum_type_to_extern_paths(extern_paths, &proto_path, &rust_path, enum_type);
338 }
339 }
340
341 /// Add the extern_path pairs for a message descriptor type.
message_type_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, proto_path: &ProtoPath, rust_path: &RustModulePath, message_type: &DescriptorProto, )342 fn message_type_to_extern_paths(
343 extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
344 proto_path: &ProtoPath,
345 rust_path: &RustModulePath,
346 message_type: &DescriptorProto,
347 ) {
348 let message_type_name = message_type
349 .name
350 .as_ref()
351 .expect("Failed to get message type name");
352
353 extern_paths.insert(
354 proto_path.join(message_type_name),
355 rust_path.join(&message_type_name.to_upper_camel_case()),
356 );
357
358 let name_lower = message_type_name.to_lowercase();
359 let proto_path = proto_path.join(&name_lower);
360 let rust_path = rust_path.join(&name_lower);
361
362 for nested_type in message_type.nested_type.iter() {
363 message_type_to_extern_paths(extern_paths, &proto_path, &rust_path, nested_type)
364 }
365
366 for enum_type in message_type.enum_type.iter() {
367 enum_type_to_extern_paths(extern_paths, &proto_path, &rust_path, enum_type);
368 }
369
370 for oneof_type in message_type.oneof_decl.iter() {
371 oneof_type_to_extern_paths(extern_paths, &proto_path, &rust_path, oneof_type);
372 }
373 }
374
375 /// Add the extern_path pairs for an enum type.
enum_type_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, proto_path: &ProtoPath, rust_path: &RustModulePath, enum_type: &EnumDescriptorProto, )376 fn enum_type_to_extern_paths(
377 extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
378 proto_path: &ProtoPath,
379 rust_path: &RustModulePath,
380 enum_type: &EnumDescriptorProto,
381 ) {
382 let enum_type_name = enum_type
383 .name
384 .as_ref()
385 .expect("Failed to get enum type name");
386 extern_paths.insert(
387 proto_path.join(enum_type_name),
388 rust_path.join(enum_type_name),
389 );
390 }
391
oneof_type_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, proto_path: &ProtoPath, rust_path: &RustModulePath, oneof_type: &OneofDescriptorProto, )392 fn oneof_type_to_extern_paths(
393 extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
394 proto_path: &ProtoPath,
395 rust_path: &RustModulePath,
396 oneof_type: &OneofDescriptorProto,
397 ) {
398 let oneof_type_name = oneof_type
399 .name
400 .as_ref()
401 .expect("Failed to get oneof type name");
402 extern_paths.insert(
403 proto_path.join(oneof_type_name),
404 rust_path.join(oneof_type_name),
405 );
406 }
407
/// Settings extracted from the wrapper's command line by `Args::parse`.
struct Args {
    // Tools.
    /// Path to the `protoc` executable to run.
    protoc: PathBuf,

    /// Optional `rustfmt` executable used to format the generated `lib.rs`.
    rustfmt: Option<PathBuf>,

    // Inputs.
    /// Proto source files to compile (every argument that is not a flag).
    proto_files: Vec<PathBuf>,

    /// Include directories passed as `-I<dir>`.
    includes: Vec<String>,

    /// Values of `--proto_path=<dir>` flags.
    proto_paths: Vec<String>,

    /// Serialized `FileDescriptorSet` (`--descriptor_set`) that is decoded to compute this
    /// crate's extern paths (presumably describing the protos being compiled; confirm).
    descriptor_set: PathBuf,

    // Outputs.
    /// Base output directory (`--prost_out`); a label-specific directory is created in it.
    out_dir: PathBuf,

    /// Where the combined `lib.rs` is written (`--out_librs`).
    out_librs: PathBuf,

    /// File receiving this crate's `.proto.path=::rust::path` mapping; the path part of
    /// `--package_info_output=<crate>=<path>`.
    package_info_file: PathBuf,

    // Identity.
    /// Crate name; the key part of `--package_info_output=<crate>=<path>`.
    crate_name: String,

    /// Bazel label of the target (`--label`), used to name the output directory.
    label: String,

    // Behavior.
    /// Whether the tonic plugin is run alongside prost (`--is_tonic`).
    is_tonic: bool,

    /// Remaining arguments forwarded verbatim to protoc, including the generated
    /// `--prost_opt`/`--tonic_opt` extern_path options.
    extra_args: Vec<String>,
}
449
450 impl Args {
451 /// Parse the command-line arguments.
parse() -> Result<Args, String>452 fn parse() -> Result<Args, String> {
453 let mut protoc: Option<PathBuf> = None;
454 let mut out_dir: Option<PathBuf> = None;
455 let mut crate_name: Option<String> = None;
456 let mut package_info_file: Option<PathBuf> = None;
457 let mut proto_files: Vec<PathBuf> = Vec::new();
458 let mut includes = Vec::new();
459 let mut descriptor_set = None;
460 let mut out_librs: Option<PathBuf> = None;
461 let mut rustfmt: Option<PathBuf> = None;
462 let mut proto_paths = Vec::new();
463 let mut label: Option<String> = None;
464 let mut tonic_or_prost_opts = Vec::new();
465 let mut is_tonic = false;
466
467 let mut extra_args = Vec::new();
468
469 let mut handle_arg = |arg: String| {
470 if !arg.starts_with('-') {
471 proto_files.push(PathBuf::from(arg));
472 return;
473 }
474
475 if arg.starts_with("-I") {
476 includes.push(
477 arg.strip_prefix("-I")
478 .expect("Failed to strip -I")
479 .to_string(),
480 );
481 return;
482 }
483
484 if arg == "--is_tonic" {
485 is_tonic = true;
486 return;
487 }
488
489 if !arg.contains('=') {
490 extra_args.push(arg);
491 return;
492 }
493
494 let parts = arg.split_once('=').expect("Failed to split argument on =");
495 match parts {
496 ("--protoc", value) => {
497 protoc = Some(PathBuf::from(value));
498 }
499 ("--prost_out", value) => {
500 out_dir = Some(PathBuf::from(value));
501 }
502 ("--package_info_output", value) => {
503 let (key, value) = value
504 .split_once('=')
505 .map(|(a, b)| (a.to_string(), PathBuf::from(b)))
506 .expect("Failed to parse package info output");
507 crate_name = Some(key);
508 package_info_file = Some(value);
509 }
510 ("--deps_info", value) => {
511 for line in fs::read_to_string(value)
512 .expect("Failed to read file")
513 .lines()
514 {
515 let path = PathBuf::from(line.trim());
516 for flag in fs::read_to_string(path)
517 .expect("Failed to read file")
518 .lines()
519 {
520 tonic_or_prost_opts.push(format!("extern_path={}", flag.trim()));
521 }
522 }
523 }
524 ("--descriptor_set", value) => {
525 descriptor_set = Some(PathBuf::from(value));
526 }
527 ("--out_librs", value) => {
528 out_librs = Some(PathBuf::from(value));
529 }
530 ("--rustfmt", value) => {
531 rustfmt = Some(PathBuf::from(value));
532 }
533 ("--proto_path", value) => {
534 proto_paths.push(value.to_string());
535 }
536 ("--label", value) => {
537 label = Some(value.to_string());
538 }
539 (arg, value) => {
540 extra_args.push(format!("{}={}", arg, value));
541 }
542 }
543 };
544
545 // Iterate over the given command line arguments parsing out arguments
546 // for the process runner and arguments for protoc and potentially spawn
547 // additional arguments needed by prost.
548 for arg in env::args().skip(1) {
549 if let Some(path) = arg.strip_prefix('@') {
550 // handle argfile
551 let file = std::fs::File::open(path)
552 .map_err(|_| format!("could not open argfile: {}", arg))?;
553 for line in std::io::BufReader::new(file).lines() {
554 handle_arg(line.map_err(|_| format!("could not read argfile: {}", arg))?);
555 }
556 } else {
557 handle_arg(arg);
558 }
559 }
560
561 for tonic_or_prost_opt in tonic_or_prost_opts {
562 extra_args.push(format!("--prost_opt={}", tonic_or_prost_opt));
563 if is_tonic {
564 extra_args.push(format!("--tonic_opt={}", tonic_or_prost_opt));
565 }
566 }
567
568 if protoc.is_none() {
569 return Err(
570 "No `--protoc` value was found. Unable to parse path to proto compiler."
571 .to_string(),
572 );
573 }
574 if out_dir.is_none() {
575 return Err(
576 "No `--prost_out` value was found. Unable to parse output directory.".to_string(),
577 );
578 }
579 if crate_name.is_none() {
580 return Err(
581 "No `--package_info_output` value was found. Unable to parse target crate name."
582 .to_string(),
583 );
584 }
585 if package_info_file.is_none() {
586 return Err("No `--package_info_output` value was found. Unable to parse package info output file.".to_string());
587 }
588 if out_librs.is_none() {
589 return Err("No `--out_librs` value was found. Unable to parse the output location for all combined prost outputs.".to_string());
590 }
591 if descriptor_set.is_none() {
592 return Err(
593 "No `--descriptor_set` value was found. Unable to parse descriptor set path."
594 .to_string(),
595 );
596 }
597 if label.is_none() {
598 return Err(
599 "No `--label` value was found. Unable to parse the label of the target crate."
600 .to_string(),
601 );
602 }
603
604 Ok(Args {
605 protoc: protoc.unwrap(),
606 out_dir: out_dir.unwrap(),
607 crate_name: crate_name.unwrap(),
608 package_info_file: package_info_file.unwrap(),
609 proto_files,
610 includes,
611 descriptor_set: descriptor_set.unwrap(),
612 out_librs: out_librs.unwrap(),
613 rustfmt,
614 proto_paths,
615 is_tonic,
616 label: label.unwrap(),
617 extra_args,
618 })
619 }
620 }
621
/// Get the label-specific output directory `<out_dir>/prost-build-<sanitized label>`.
///
/// The label is made filesystem-friendly by dropping `@` and replacing `//`, `/`, and `:`
/// with `_`. For example, `@repo//pkg:target` becomes `prost-build-repo_pkg_target`.
fn get_output_dir(out_dir: &Path, label: &str) -> PathBuf {
    let label_as_path = label
        .replace('@', "")
        .replace("//", "_")
        .replace(['/', ':'], "_");
    // Use `Path::join` rather than formatting `out_dir.display()`. The formatted version
    // produces an absolute `/prost-build-...` for an empty `out_dir` and mangles non-UTF-8
    // paths.
    out_dir.join(format!("prost-build-{}", label_as_path))
}
634
635 /// Get the output directory with the label suffixed, and create it if it doesn't exist.
636 ///
637 /// This will remove the directory first if it already exists.
get_and_create_output_dir(out_dir: &Path, label: &str) -> PathBuf638 fn get_and_create_output_dir(out_dir: &Path, label: &str) -> PathBuf {
639 let out_dir = get_output_dir(out_dir, label);
640 if out_dir.exists() {
641 fs::remove_dir_all(&out_dir).expect("Failed to remove old output directory");
642 }
643 fs::create_dir_all(&out_dir).expect("Failed to create output directory");
644 out_dir
645 }
646
647 /// Parse the descriptor set file into a `FileDescriptorSet`.
parse_descriptor_set_file(descriptor_set_path: &PathBuf) -> FileDescriptorSet648 fn parse_descriptor_set_file(descriptor_set_path: &PathBuf) -> FileDescriptorSet {
649 let descriptor_set_bytes =
650 fs::read(descriptor_set_path).expect("Failed to read descriptor set");
651 let descriptor_set = FileDescriptorSet::decode(descriptor_set_bytes.as_slice())
652 .expect("Failed to decode descriptor set");
653
654 descriptor_set
655 }
656
657 /// Get the package name from the descriptor set.
get_package_name(descriptor_set: &FileDescriptorSet) -> Option<String>658 fn get_package_name(descriptor_set: &FileDescriptorSet) -> Option<String> {
659 let mut package_name = None;
660
661 for file in &descriptor_set.file {
662 if let Some(package) = &file.package {
663 package_name = Some(package.clone());
664 break;
665 }
666 }
667
668 package_name
669 }
670
671 /// Whether the proto file should expect to generate a .rs file.
672 ///
673 /// If the proto file contains any messages, enums, or services, then it should generate a rust file.
674 /// If the proto file only contains extensions, then it will not generate any rust files.
expect_fs_file_to_be_generated(descriptor_set: &FileDescriptorSet) -> bool675 fn expect_fs_file_to_be_generated(descriptor_set: &FileDescriptorSet) -> bool {
676 let mut expect_rs = false;
677
678 for file in descriptor_set.file.iter() {
679 let has_messages = !file.message_type.is_empty();
680 let has_enums = !file.enum_type.is_empty();
681 let has_services = !file.service.is_empty();
682 let has_extensions = !file.extension.is_empty();
683
684 let has_definition = has_messages || has_enums || has_services;
685
686 if has_definition {
687 return true;
688 } else if !has_definition && !has_extensions {
689 expect_rs = true;
690 }
691 }
692
693 expect_rs
694 }
695
696 /// Whether the proto file should expect to generate service definitions.
has_services(descriptor_set: &FileDescriptorSet) -> bool697 fn has_services(descriptor_set: &FileDescriptorSet) -> bool {
698 descriptor_set
699 .file
700 .iter()
701 .any(|file| !file.service.is_empty())
702 }
703
main()704 fn main() {
705 // Always enable backtraces for the protoc wrapper.
706 env::set_var("RUST_BACKTRACE", "1");
707
708 let Args {
709 protoc,
710 out_dir,
711 crate_name,
712 label,
713 package_info_file,
714 proto_files,
715 includes,
716 descriptor_set,
717 out_librs,
718 rustfmt,
719 proto_paths,
720 is_tonic,
721 extra_args,
722 } = Args::parse().expect("Failed to parse args");
723
724 let out_dir = get_and_create_output_dir(&out_dir, &label);
725
726 let descriptor_set = parse_descriptor_set_file(&descriptor_set);
727 let package_name = get_package_name(&descriptor_set).unwrap_or_default();
728 let expect_rs = expect_fs_file_to_be_generated(&descriptor_set);
729 let has_services = has_services(&descriptor_set);
730
731 if has_services && !is_tonic {
732 eprintln!("Warning: Service definitions will not be generated because the prost toolchain did not define a tonic plugin.");
733 }
734
735 let mut cmd = process::Command::new(protoc);
736 cmd.arg(format!("--prost_out={}", out_dir.display()));
737 if is_tonic {
738 cmd.arg(format!("--tonic_out={}", out_dir.display()));
739 }
740 cmd.args(extra_args);
741 cmd.args(
742 proto_paths
743 .iter()
744 .map(|proto_path| format!("--proto_path={}", proto_path)),
745 );
746 cmd.args(includes.iter().map(|include| format!("-I{}", include)));
747 cmd.args(&proto_files);
748
749 let status = cmd.status().expect("Failed to spawn protoc process");
750 if !status.success() {
751 panic!(
752 "protoc failed with status: {}",
753 status.code().expect("failed to get exit code")
754 );
755 }
756
757 // Not all proto files will consistently produce `.rs` or `.tonic.rs` files. This is
758 // caused by the proto file being transpiled not having an RPC service or other protos
759 // defined (a natural and expected situation). To guarantee consistent outputs, all
760 // `.rs` files are either renamed to `.tonic.rs` if there is no `.tonic.rs` or prepended
761 // to the existing `.tonic.rs`.
762 if is_tonic {
763 let tonic_files: BTreeSet<PathBuf> = find_generated_rust_files(&out_dir);
764
765 for tonic_file in tonic_files.iter() {
766 let tonic_path_str = tonic_file.to_str().expect("Failed to convert to str");
767 let filename = tonic_file
768 .file_name()
769 .expect("Failed to get file name")
770 .to_str()
771 .expect("Failed to convert to str");
772
773 let is_tonic_file = filename.ends_with(".tonic.rs");
774
775 if is_tonic_file {
776 let rs_file_str = format!(
777 "{}.rs",
778 tonic_path_str
779 .strip_suffix(".tonic.rs")
780 .expect("Failed to strip suffix.")
781 );
782 let rs_file = PathBuf::from(&rs_file_str);
783
784 if rs_file.exists() {
785 let rs_content = fs::read_to_string(&rs_file).expect("Failed to read file.");
786 let tonic_content =
787 fs::read_to_string(tonic_file).expect("Failed to read file.");
788 fs::write(tonic_file, format!("{}\n{}", rs_content, tonic_content))
789 .expect("Failed to write file.");
790 fs::remove_file(&rs_file).unwrap_or_else(|err| {
791 panic!("Failed to remove file: {err:?}: {rs_file:?}")
792 });
793 }
794 } else {
795 let real_tonic_file = PathBuf::from(format!(
796 "{}.tonic.rs",
797 tonic_path_str
798 .strip_suffix(".rs")
799 .expect("Failed to strip suffix.")
800 ));
801 if real_tonic_file.exists() {
802 continue;
803 }
804 fs::rename(tonic_file, &real_tonic_file).unwrap_or_else(|err| {
805 panic!("Failed to rename file: {err:?}: {tonic_file:?} -> {real_tonic_file:?}");
806 });
807 }
808 }
809 }
810
811 // Locate all prost-generated outputs.
812 let mut rust_files = find_generated_rust_files(&out_dir);
813 if rust_files.is_empty() {
814 if expect_rs {
815 panic!("No .rs files were generated by prost.");
816 } else {
817 let file_stem = if package_name.is_empty() {
818 "_"
819 } else {
820 &package_name
821 };
822 let file_stem = format!("{}{}", file_stem, if is_tonic { ".tonic" } else { "" });
823 let empty_rs_file = out_dir.join(format!("{}.rs", file_stem));
824 fs::write(&empty_rs_file, "").expect("Failed to write file.");
825 rust_files.insert(empty_rs_file);
826 }
827 }
828
829 let extern_paths = get_extern_paths(&descriptor_set, &crate_name)
830 .expect("Failed to compute proto package info");
831
832 // Write outputs
833 fs::write(&out_librs, generate_lib_rs(&rust_files, is_tonic)).expect("Failed to write file.");
834 fs::write(
835 package_info_file,
836 extern_paths
837 .into_iter()
838 .map(|(proto_path, rust_path)| format!(".{}=::{}", proto_path, rust_path))
839 .collect::<Vec<_>>()
840 .join("\n"),
841 )
842 .expect("Failed to write file.");
843
844 // Finally run rustfmt on the output lib.rs file
845 if let Some(rustfmt) = rustfmt {
846 let fmt_status = process::Command::new(rustfmt)
847 .arg("--edition")
848 .arg("2021")
849 .arg("--quiet")
850 .arg(&out_librs)
851 .status()
852 .expect("Failed to spawn rustfmt process");
853 if !fmt_status.success() {
854 panic!(
855 "rustfmt failed with exit code: {}",
856 fmt_status.code().expect("Failed to get exit code")
857 );
858 }
859 }
860 }
861
/// Rust built-in keywords and reserved keywords.
const RUST_KEYWORDS: [&str; 51] = [
    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate",
    "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "if", "impl", "in",
    "let", "loop", "macro", "match", "mod", "move", "mut", "override", "priv", "pub", "ref",
    "return", "self", "Self", "static", "struct", "super", "trait", "true", "try", "type",
    "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
];

/// Path keywords that are not allowed as raw identifiers (`r#self` etc. fail to compile).
const NON_RAW_KEYWORDS: [&str; 4] = ["crate", "self", "Self", "super"];

/// Returns true if the given string is a Rust keyword.
fn is_keyword(s: &str) -> bool {
    RUST_KEYWORDS.contains(&s)
}

/// Turn `s` into a usable Rust identifier if it is a keyword.
///
/// - Most keywords become raw identifiers (`type` -> `r#type`).
/// - `crate`, `self`, `Self`, and `super` cannot be raw identifiers, so they get a trailing
///   underscore instead (`self` -> `self_`).
/// - Non-keywords are returned unchanged.
fn escape_keyword(s: String) -> String {
    if NON_RAW_KEYWORDS.contains(&s.as_str()) {
        return format!("{s}_");
    }
    if is_keyword(&s) {
        return format!("r#{s}");
    }
    s
}
883
884 #[cfg(test)]
885 mod test {
886
887 use super::*;
888
889 use prost_types::{FieldDescriptorProto, FileDescriptorProto, ServiceDescriptorProto};
890 use std::collections::BTreeMap;
891
#[test]
fn oneof_type_to_extern_paths_test() {
    let oneof_descriptor = OneofDescriptorProto {
        name: Some("Foo".to_string()),
        ..OneofDescriptorProto::default()
    };

    // (proto scope, rust scope, expected proto key, expected rust path)
    let cases = [
        ("bar", "bar", "bar.Foo", "bar::Foo"),
        ("bar.baz", "bar::baz", "bar.baz.Foo", "bar::baz::Foo"),
    ];

    for (proto_scope, rust_scope, expected_proto, expected_rust) in cases.iter() {
        let mut extern_paths = BTreeMap::new();
        oneof_type_to_extern_paths(
            &mut extern_paths,
            &ProtoPath::from(*proto_scope),
            &RustModulePath::from(*rust_scope),
            &oneof_descriptor,
        );

        assert_eq!(extern_paths.len(), 1);
        assert_eq!(
            extern_paths.get(&ProtoPath::from(*expected_proto)),
            Some(&RustModulePath::from(*expected_rust))
        );
    }
}
931
#[test]
fn enum_type_to_extern_paths_test() {
    let enum_descriptor = EnumDescriptorProto {
        name: Some("Foo".to_string()),
        ..EnumDescriptorProto::default()
    };

    // (proto scope, rust scope, expected proto key, expected rust path)
    let cases = [
        ("bar", "bar", "bar.Foo", "bar::Foo"),
        ("bar.baz", "bar::baz", "bar.baz.Foo", "bar::baz::Foo"),
    ];

    for (proto_scope, rust_scope, expected_proto, expected_rust) in cases.iter() {
        let mut extern_paths = BTreeMap::new();
        enum_type_to_extern_paths(
            &mut extern_paths,
            &ProtoPath::from(*proto_scope),
            &RustModulePath::from(*rust_scope),
            &enum_descriptor,
        );

        assert_eq!(extern_paths.len(), 1);
        assert_eq!(
            extern_paths.get(&ProtoPath::from(*expected_proto)),
            Some(&RustModulePath::from(*expected_rust))
        );
    }
}
971
#[test]
fn message_type_to_extern_paths_test() {
    let chuck = EnumDescriptorProto {
        name: Some("Chuck".to_string()),
        ..EnumDescriptorProto::default()
    };
    let baz = DescriptorProto {
        name: Some("Baz".to_string()),
        enum_type: vec![chuck],
        ..DescriptorProto::default()
    };
    let nested = DescriptorProto {
        name: Some("Nested".to_string()),
        nested_type: vec![baz],
        ..DescriptorProto::default()
    };
    let bar = DescriptorProto {
        name: Some("Bar".to_string()),
        ..DescriptorProto::default()
    };
    let qux = EnumDescriptorProto {
        name: Some("Qux".to_string()),
        ..EnumDescriptorProto::default()
    };
    let message_descriptor = DescriptorProto {
        name: Some("Foo".to_string()),
        nested_type: vec![bar, nested],
        enum_type: vec![qux],
        ..DescriptorProto::default()
    };

    // Expected (proto suffix, rust suffix) pairs relative to each scope below.
    let expectations = [
        ("Foo", "Foo"),
        ("foo.Bar", "foo::Bar"),
        ("foo.Nested", "foo::Nested"),
        ("foo.nested.Baz", "foo::nested::Baz"),
    ];

    for (proto_scope, rust_scope) in [("bar", "bar"), ("bar.bob", "bar::bob")].iter() {
        let mut extern_paths = BTreeMap::new();
        message_type_to_extern_paths(
            &mut extern_paths,
            &ProtoPath::from(*proto_scope),
            &RustModulePath::from(*rust_scope),
            &message_descriptor,
        );
        assert_eq!(extern_paths.len(), 6);

        for (proto_suffix, rust_suffix) in expectations.iter() {
            let proto_key = format!("{}.{}", proto_scope, proto_suffix);
            let rust_value = format!("{}::{}", rust_scope, rust_suffix);
            assert_eq!(
                extern_paths.get(&ProtoPath::from(proto_key.as_str())),
                Some(&RustModulePath::from(rust_value.as_str()))
            );
        }
    }
}
1055
1056 #[test]
proto_path_test()1057 fn proto_path_test() {
1058 {
1059 let proto_path = ProtoPath::from("");
1060 assert_eq!(proto_path.to_string(), "");
1061 assert_eq!(proto_path.join("foo"), ProtoPath::from("foo"));
1062 }
1063 {
1064 let proto_path = ProtoPath::from("foo");
1065 assert_eq!(proto_path.to_string(), "foo");
1066 assert_eq!(proto_path.join(""), ProtoPath::from("foo"));
1067 }
1068 {
1069 let proto_path = ProtoPath::from("foo");
1070 assert_eq!(proto_path.to_string(), "foo");
1071 assert_eq!(proto_path.join("bar"), ProtoPath::from("foo.bar"));
1072 }
1073 {
1074 let proto_path = ProtoPath::from("foo.bar");
1075 assert_eq!(proto_path.to_string(), "foo.bar");
1076 assert_eq!(proto_path.join("baz"), ProtoPath::from("foo.bar.baz"));
1077 }
1078 {
1079 let proto_path = ProtoPath::from("Foo.baR");
1080 assert_eq!(proto_path.to_string(), "Foo.baR");
1081 assert_eq!(proto_path.join("baz"), ProtoPath::from("Foo.baR.baz"));
1082 }
1083 }
1084
1085 #[test]
rust_module_path_test()1086 fn rust_module_path_test() {
1087 {
1088 let rust_module_path = RustModulePath::from("");
1089 assert_eq!(rust_module_path.to_string(), "");
1090 assert_eq!(rust_module_path.join("foo"), RustModulePath::from("foo"));
1091 }
1092 {
1093 let rust_module_path = RustModulePath::from("foo");
1094 assert_eq!(rust_module_path.to_string(), "foo");
1095 assert_eq!(rust_module_path.join(""), RustModulePath::from("foo"));
1096 }
1097 {
1098 let rust_module_path = RustModulePath::from("foo");
1099 assert_eq!(rust_module_path.to_string(), "foo");
1100 assert_eq!(
1101 rust_module_path.join("bar"),
1102 RustModulePath::from("foo::bar")
1103 );
1104 }
1105 {
1106 let rust_module_path = RustModulePath::from("foo::bar");
1107 assert_eq!(rust_module_path.to_string(), "foo::bar");
1108 assert_eq!(
1109 rust_module_path.join("baz"),
1110 RustModulePath::from("foo::bar::baz")
1111 );
1112 }
1113 }
1114
1115 #[test]
expect_fs_file_to_be_generated_test()1116 fn expect_fs_file_to_be_generated_test() {
1117 {
1118 // Empty descriptor set should create a file.
1119 let descriptor_set = FileDescriptorSet {
1120 file: vec![FileDescriptorProto {
1121 name: Some("foo.proto".to_string()),
1122 ..FileDescriptorProto::default()
1123 }],
1124 };
1125 assert!(expect_fs_file_to_be_generated(&descriptor_set));
1126 }
1127 {
1128 // Descriptor set with only message should create a file.
1129 let descriptor_set = FileDescriptorSet {
1130 file: vec![FileDescriptorProto {
1131 name: Some("foo.proto".to_string()),
1132 message_type: vec![DescriptorProto {
1133 name: Some("Foo".to_string()),
1134 ..DescriptorProto::default()
1135 }],
1136 ..FileDescriptorProto::default()
1137 }],
1138 };
1139 assert!(expect_fs_file_to_be_generated(&descriptor_set));
1140 }
1141 {
1142 // Descriptor set with only enum should create a file.
1143 let descriptor_set = FileDescriptorSet {
1144 file: vec![FileDescriptorProto {
1145 name: Some("foo.proto".to_string()),
1146 enum_type: vec![EnumDescriptorProto {
1147 name: Some("Foo".to_string()),
1148 ..EnumDescriptorProto::default()
1149 }],
1150 ..FileDescriptorProto::default()
1151 }],
1152 };
1153 assert!(expect_fs_file_to_be_generated(&descriptor_set));
1154 }
1155 {
1156 // Descriptor set with only service should create a file.
1157 let descriptor_set = FileDescriptorSet {
1158 file: vec![FileDescriptorProto {
1159 name: Some("foo.proto".to_string()),
1160 service: vec![ServiceDescriptorProto {
1161 name: Some("Foo".to_string()),
1162 ..ServiceDescriptorProto::default()
1163 }],
1164 ..FileDescriptorProto::default()
1165 }],
1166 };
1167 assert!(expect_fs_file_to_be_generated(&descriptor_set));
1168 }
1169 {
1170 // Descriptor set with only extensions should not create a file.
1171 let descriptor_set = FileDescriptorSet {
1172 file: vec![FileDescriptorProto {
1173 name: Some("foo.proto".to_string()),
1174 extension: vec![FieldDescriptorProto {
1175 name: Some("Foo".to_string()),
1176 ..FieldDescriptorProto::default()
1177 }],
1178 ..FileDescriptorProto::default()
1179 }],
1180 };
1181 assert!(!expect_fs_file_to_be_generated(&descriptor_set));
1182 }
1183 }
1184
1185 #[test]
has_services_test()1186 fn has_services_test() {
1187 {
1188 // Empty file should not have services.
1189 let descriptor_set = FileDescriptorSet {
1190 file: vec![FileDescriptorProto {
1191 name: Some("foo.proto".to_string()),
1192 ..FileDescriptorProto::default()
1193 }],
1194 };
1195 assert!(!has_services(&descriptor_set));
1196 }
1197 {
1198 // File with only message should not have services.
1199 let descriptor_set = FileDescriptorSet {
1200 file: vec![FileDescriptorProto {
1201 name: Some("foo.proto".to_string()),
1202 message_type: vec![DescriptorProto {
1203 name: Some("Foo".to_string()),
1204 ..DescriptorProto::default()
1205 }],
1206 ..FileDescriptorProto::default()
1207 }],
1208 };
1209 assert!(!has_services(&descriptor_set));
1210 }
1211 {
1212 // File with services should have services.
1213 let descriptor_set = FileDescriptorSet {
1214 file: vec![FileDescriptorProto {
1215 name: Some("foo.proto".to_string()),
1216 service: vec![ServiceDescriptorProto {
1217 name: Some("Foo".to_string()),
1218 ..ServiceDescriptorProto::default()
1219 }],
1220 ..FileDescriptorProto::default()
1221 }],
1222 };
1223 assert!(has_services(&descriptor_set));
1224 }
1225 }
1226
1227 #[test]
get_package_name_test()1228 fn get_package_name_test() {
1229 let descriptor_set = FileDescriptorSet {
1230 file: vec![FileDescriptorProto {
1231 name: Some("foo.proto".to_string()),
1232 package: Some("foo".to_string()),
1233 ..FileDescriptorProto::default()
1234 }],
1235 };
1236
1237 assert_eq!(get_package_name(&descriptor_set), Some("foo".to_string()));
1238 }
1239
1240 #[test]
is_keyword_test()1241 fn is_keyword_test() {
1242 let non_keywords = [
1243 "foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred",
1244 "plugh", "xyzzy", "thud",
1245 ];
1246 for non_keyword in &non_keywords {
1247 assert!(!is_keyword(non_keyword));
1248 }
1249
1250 for keyword in &RUST_KEYWORDS {
1251 assert!(is_keyword(keyword));
1252 }
1253 }
1254
1255 #[test]
escape_keyword_test()1256 fn escape_keyword_test() {
1257 let non_keywords = [
1258 "foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred",
1259 "plugh", "xyzzy", "thud",
1260 ];
1261 for non_keyword in &non_keywords {
1262 assert_eq!(
1263 escape_keyword(non_keyword.to_string()),
1264 non_keyword.to_owned()
1265 );
1266 }
1267
1268 for keyword in &RUST_KEYWORDS {
1269 assert_eq!(
1270 escape_keyword(keyword.to_string()),
1271 format!("r#{}", keyword)
1272 );
1273 }
1274 }
1275 }
1276