• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! A process wrapper for running a Protobuf compiler configured for Prost or Tonic output in a Bazel rule.
2 
3 use std::collections::BTreeMap;
4 use std::collections::BTreeSet;
5 use std::fmt::{Display, Formatter, Write};
6 use std::fs;
7 use std::io::BufRead;
8 use std::path::Path;
9 use std::path::PathBuf;
10 use std::process;
11 use std::{env, fmt};
12 
13 use heck::{ToSnakeCase, ToUpperCamelCase};
14 use prost::Message;
15 use prost_types::{
16     DescriptorProto, EnumDescriptorProto, FileDescriptorProto, FileDescriptorSet,
17     OneofDescriptorProto,
18 };
19 
/// Collect every generated Rust source under `out_dir`, descending into subdirectories.
///
/// Files whose extension is `rs` are collected as-is. A file named exactly `_`
/// (no extension) is what protoc writes for a proto with an empty package; it is
/// renamed in place to `_.rs` and the new path is collected instead. Every other
/// file is ignored.
///
/// # Panics
///
/// Panics if a directory or one of its entries cannot be read, or if renaming a
/// `_` file fails.
fn find_generated_rust_files(out_dir: &Path) -> BTreeSet<PathBuf> {
    let mut collected = BTreeSet::new();

    for dir_entry in fs::read_dir(out_dir).expect("Failed to read directory") {
        let entry_path = dir_entry.expect("Failed to read entry").path();

        if entry_path.is_dir() {
            collected.extend(find_generated_rust_files(&entry_path));
            continue;
        }

        match entry_path.extension() {
            Some(extension) => {
                if extension == "rs" {
                    collected.insert(entry_path);
                }
            }
            None => {
                // The filename is set to `_` when the package name is empty.
                let is_placeholder = match entry_path.file_name() {
                    Some(file_name) => file_name == "_",
                    None => false,
                };
                if is_placeholder {
                    let renamed = entry_path
                        .parent()
                        .expect("Failed to get parent")
                        .join("_.rs");
                    fs::rename(&entry_path, &renamed).unwrap_or_else(|err| {
                        panic!("Failed to rename file: {err:?}: {entry_path:?} -> {renamed:?}")
                    });
                    collected.insert(renamed);
                }
            }
        }
    }

    collected
}
48 
snake_cased_package_name(package: &str) -> String49 fn snake_cased_package_name(package: &str) -> String {
50     if package == "_" {
51         return package.to_owned();
52     }
53 
54     package
55         .split('.')
56         .map(|s| s.to_snake_case())
57         .collect::<Vec<_>>()
58         .join(".")
59 }
60 
/// One Rust module to be emitted into the generated `lib.rs`.
#[derive(Default, Debug)]
struct Module {
    /// Identifier used in the `pub mod <name>` header. The special value `_`
    /// means the contents are written inline, without a module wrapper.
    name: String,

    /// Source text placed directly inside the module.
    contents: String,

    /// Names of this module's direct children. These are single path
    /// components, not full dotted keys.
    submodules: BTreeSet<String>,
}
73 
74 /// Generate a lib.rs file with all prost/tonic outputs embeeded in modules which
75 /// mirror the proto packages. For the example proto file we would expect to see
76 /// the Rust output that follows it.
77 ///
78 /// ```proto
79 /// syntax = "proto3";
80 /// package examples.prost.helloworld;
81 ///
82 /// message HelloRequest {
83 ///     // Request message contains the name to be greeted
84 ///     string name = 1;
85 /// }
86 //
87 /// message HelloReply {
88 ///     // Reply contains the greeting message
89 ///     string message = 1;
90 /// }
91 /// ```
92 ///
93 /// This is expected to render out to something like the following. Note that
94 /// formatting is not applied so indentation may be missing in the actual output.
95 ///
96 /// ```ignore
97 /// pub mod examples {
98 ///     pub mod prost {
99 ///         pub mod helloworld {
100 ///             // @generated
101 ///             #[allow(clippy::derive_partial_eq_without_eq)]
102 ///             #[derive(Clone, PartialEq, ::prost::Message)]
103 ///             pub struct HelloRequest {
104 ///                 /// Request message contains the name to be greeted
105 ///                 #[prost(string, tag = "1")]
106 ///                 pub name: ::prost::alloc::string::String,
107 ///             }
108 ///             #[allow(clippy::derive_partial_eq_without_eq)]
109 ///             #[derive(Clone, PartialEq, ::prost::Message)]
110 ///             pub struct HelloReply {
111 ///                 /// Reply contains the greeting message
112 ///                 #[prost(string, tag = "1")]
113 ///                 pub message: ::prost::alloc::string::String,
114 ///             }
115 ///             // @protoc_insertion_point(module)
116 ///         }
117 ///     }
118 /// }
119 /// ```
generate_lib_rs(prost_outputs: &BTreeSet<PathBuf>, is_tonic: bool) -> String120 fn generate_lib_rs(prost_outputs: &BTreeSet<PathBuf>, is_tonic: bool) -> String {
121     let mut module_info = BTreeMap::new();
122 
123     for path in prost_outputs.iter() {
124         let mut package = path
125             .file_stem()
126             .expect("Failed to get file stem")
127             .to_str()
128             .expect("Failed to convert to str")
129             .to_string();
130 
131         if is_tonic {
132             package = package
133                 .strip_suffix(".tonic")
134                 .expect("Failed to strip suffix")
135                 .to_string()
136         };
137 
138         if package.is_empty() {
139             continue;
140         }
141 
142         let name = if package == "_" {
143             package.clone()
144         } else if package.contains('.') {
145             package
146                 .rsplit_once('.')
147                 .expect("Failed to split on '.'")
148                 .1
149                 .to_snake_case()
150                 .to_string()
151         } else {
152             package.to_snake_case()
153         };
154 
155         // Avoid a stack overflow by skipping a known bad package name
156         let module_name = snake_cased_package_name(&package);
157 
158         module_info.insert(
159             module_name.clone(),
160             Module {
161                 name,
162                 contents: fs::read_to_string(path).expect("Failed to read file"),
163                 submodules: BTreeSet::new(),
164             },
165         );
166 
167         let module_parts = module_name.split('.').collect::<Vec<_>>();
168         for parent_module_index in 0..module_parts.len() {
169             let child_module_index = parent_module_index + 1;
170             if child_module_index >= module_parts.len() {
171                 break;
172             }
173             let full_parent_module_name = module_parts[0..parent_module_index + 1].join(".");
174             let parent_module_name = module_parts[parent_module_index];
175             let child_module_name = module_parts[child_module_index];
176 
177             module_info
178                 .entry(full_parent_module_name.clone())
179                 .and_modify(|parent_module| {
180                     parent_module
181                         .submodules
182                         .insert(child_module_name.to_string());
183                 })
184                 .or_insert(Module {
185                     name: parent_module_name.to_string(),
186                     contents: "".to_string(),
187                     submodules: [child_module_name.to_string()].iter().cloned().collect(),
188                 });
189         }
190     }
191 
192     let mut content = "// @generated\n\n".to_string();
193     write_module(&mut content, &module_info, "", 0);
194     content
195 }
196 
197 /// Write out a rust module and all of its submodules.
write_module( content: &mut String, module_info: &BTreeMap<String, Module>, module_name: &str, depth: usize, )198 fn write_module(
199     content: &mut String,
200     module_info: &BTreeMap<String, Module>,
201     module_name: &str,
202     depth: usize,
203 ) {
204     if module_name.is_empty() {
205         for submodule_name in module_info.keys() {
206             write_module(content, module_info, submodule_name, depth + 1);
207         }
208         return;
209     }
210     let module = module_info.get(module_name).expect("Failed to get module");
211     let indent = "  ".repeat(depth);
212     let is_rust_module = module.name != "_";
213 
214     if is_rust_module {
215         let rust_module_name = escape_keyword(module.name.clone());
216         content
217             .write_str(&format!("{}pub mod {} {{\n", indent, rust_module_name))
218             .expect("Failed to write string");
219     }
220 
221     content
222         .write_str(&module.contents)
223         .expect("Failed to write string");
224 
225     for submodule_name in module.submodules.iter() {
226         write_module(
227             content,
228             module_info,
229             [module_name, submodule_name].join(".").as_str(),
230             depth + 1,
231         );
232     }
233 
234     if is_rust_module {
235         content
236             .write_str(&format!("{}}}\n", indent))
237             .expect("Failed to write string");
238     }
239 }
240 
/// A dot-separated, fully-qualified protobuf name (message, enum, or oneof).
///
/// Example: `helloworld.Greeter.HelloRequest`
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct ProtoPath(String);

impl ProtoPath {
    /// Append `component` to this path, separated by a `.`.
    ///
    /// An empty side counts as absent. Joining onto an empty path yields just
    /// `component`, and joining an empty `component` returns the path unchanged.
    fn join(&self, component: &str) -> ProtoPath {
        match (self.0.is_empty(), component.is_empty()) {
            (true, _) => ProtoPath(component.to_string()),
            (false, true) => self.clone(),
            (false, false) => ProtoPath([self.0.as_str(), component].join(".")),
        }
    }
}

impl Display for ProtoPath {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<&str> for ProtoPath {
    fn from(path: &str) -> Self {
        Self(path.to_owned())
    }
}
272 
/// A `::`-separated Rust path to a module or item.
///
/// Example: `helloworld::greeter::HelloRequest`
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct RustModulePath(String);

impl RustModulePath {
    /// Append `path` to this module path, separated by `::`.
    ///
    /// An empty side counts as absent. Joining onto an empty path yields just
    /// `path`, and joining an empty `path` returns this path unchanged.
    fn join(&self, path: &str) -> RustModulePath {
        match (self.0.is_empty(), path.is_empty()) {
            (true, _) => RustModulePath(path.to_string()),
            (false, true) => self.clone(),
            (false, false) => RustModulePath([self.0.as_str(), path].join("::")),
        }
    }
}

impl Display for RustModulePath {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<&str> for RustModulePath {
    fn from(path: &str) -> Self {
        Self(path.to_owned())
    }
}
304 
305 /// Compute the `--extern_path` flags for a list of proto files. This is
306 /// expected to convert proto files into a BTreeMap of
307 /// `example.prost.helloworld`: `crate_name::example::prost::helloworld`.
get_extern_paths( descriptor_set: &FileDescriptorSet, crate_name: &str, ) -> Result<BTreeMap<ProtoPath, RustModulePath>, String>308 fn get_extern_paths(
309     descriptor_set: &FileDescriptorSet,
310     crate_name: &str,
311 ) -> Result<BTreeMap<ProtoPath, RustModulePath>, String> {
312     let mut extern_paths = BTreeMap::new();
313     let rust_path = RustModulePath(crate_name.to_string());
314 
315     for file in descriptor_set.file.iter() {
316         descriptor_set_file_to_extern_paths(&mut extern_paths, &rust_path, file);
317     }
318 
319     Ok(extern_paths)
320 }
321 
322 /// Add the extern_path pairs for a file descriptor type.
descriptor_set_file_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, rust_path: &RustModulePath, file: &FileDescriptorProto, )323 fn descriptor_set_file_to_extern_paths(
324     extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
325     rust_path: &RustModulePath,
326     file: &FileDescriptorProto,
327 ) {
328     let package = file.package.clone().unwrap_or_default();
329     let rust_path = rust_path.join(&snake_cased_package_name(&package).replace('.', "::"));
330     let proto_path = ProtoPath(package);
331 
332     for message_type in file.message_type.iter() {
333         message_type_to_extern_paths(extern_paths, &proto_path, &rust_path, message_type);
334     }
335 
336     for enum_type in file.enum_type.iter() {
337         enum_type_to_extern_paths(extern_paths, &proto_path, &rust_path, enum_type);
338     }
339 }
340 
341 /// Add the extern_path pairs for a message descriptor type.
message_type_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, proto_path: &ProtoPath, rust_path: &RustModulePath, message_type: &DescriptorProto, )342 fn message_type_to_extern_paths(
343     extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
344     proto_path: &ProtoPath,
345     rust_path: &RustModulePath,
346     message_type: &DescriptorProto,
347 ) {
348     let message_type_name = message_type
349         .name
350         .as_ref()
351         .expect("Failed to get message type name");
352 
353     extern_paths.insert(
354         proto_path.join(message_type_name),
355         rust_path.join(&message_type_name.to_upper_camel_case()),
356     );
357 
358     let name_lower = message_type_name.to_lowercase();
359     let proto_path = proto_path.join(&name_lower);
360     let rust_path = rust_path.join(&name_lower);
361 
362     for nested_type in message_type.nested_type.iter() {
363         message_type_to_extern_paths(extern_paths, &proto_path, &rust_path, nested_type)
364     }
365 
366     for enum_type in message_type.enum_type.iter() {
367         enum_type_to_extern_paths(extern_paths, &proto_path, &rust_path, enum_type);
368     }
369 
370     for oneof_type in message_type.oneof_decl.iter() {
371         oneof_type_to_extern_paths(extern_paths, &proto_path, &rust_path, oneof_type);
372     }
373 }
374 
375 /// Add the extern_path pairs for an enum type.
enum_type_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, proto_path: &ProtoPath, rust_path: &RustModulePath, enum_type: &EnumDescriptorProto, )376 fn enum_type_to_extern_paths(
377     extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
378     proto_path: &ProtoPath,
379     rust_path: &RustModulePath,
380     enum_type: &EnumDescriptorProto,
381 ) {
382     let enum_type_name = enum_type
383         .name
384         .as_ref()
385         .expect("Failed to get enum type name");
386     extern_paths.insert(
387         proto_path.join(enum_type_name),
388         rust_path.join(enum_type_name),
389     );
390 }
391 
oneof_type_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, proto_path: &ProtoPath, rust_path: &RustModulePath, oneof_type: &OneofDescriptorProto, )392 fn oneof_type_to_extern_paths(
393     extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
394     proto_path: &ProtoPath,
395     rust_path: &RustModulePath,
396     oneof_type: &OneofDescriptorProto,
397 ) {
398     let oneof_type_name = oneof_type
399         .name
400         .as_ref()
401         .expect("Failed to get oneof type name");
402     extern_paths.insert(
403         proto_path.join(oneof_type_name),
404         rust_path.join(oneof_type_name),
405     );
406 }
407 
/// Command-line configuration for the wrapper, as produced by `Args::parse`.
struct Args {
    // --- Tools ---
    /// protoc executable (`--protoc`).
    protoc: PathBuf,

    /// Optional rustfmt executable used to format the generated lib.rs (`--rustfmt`).
    rustfmt: Option<PathBuf>,

    // --- Target identity ---
    /// Name of the crate the generated code belongs to (`--package_info_output`).
    crate_name: String,

    /// Bazel label of the target (`--label`); used to derive the output directory.
    label: String,

    /// Whether tonic code is generated in addition to prost code (`--is_tonic`).
    is_tonic: bool,

    // --- Inputs ---
    /// Proto sources handed to protoc (every argument not starting with `-`).
    proto_files: Vec<PathBuf>,

    /// Include directories (`-I<dir>`).
    includes: Vec<String>,

    /// Proto search paths (`--proto_path`).
    proto_paths: Vec<String>,

    /// Serialized descriptor set (`--descriptor_set`), decoded by `main`.
    descriptor_set: PathBuf,

    /// Arguments forwarded verbatim to protoc, including the generated
    /// `--prost_opt` / `--tonic_opt` extern_path options.
    extra_args: Vec<String>,

    // --- Outputs ---
    /// Directory under which prost writes its outputs (`--prost_out`).
    out_dir: PathBuf,

    /// Destination of the combined lib.rs (`--out_librs`).
    out_librs: PathBuf,

    /// Destination of the extern_path package info (`--package_info_output`).
    package_info_file: PathBuf,
}
449 
450 impl Args {
451     /// Parse the command-line arguments.
parse() -> Result<Args, String>452     fn parse() -> Result<Args, String> {
453         let mut protoc: Option<PathBuf> = None;
454         let mut out_dir: Option<PathBuf> = None;
455         let mut crate_name: Option<String> = None;
456         let mut package_info_file: Option<PathBuf> = None;
457         let mut proto_files: Vec<PathBuf> = Vec::new();
458         let mut includes = Vec::new();
459         let mut descriptor_set = None;
460         let mut out_librs: Option<PathBuf> = None;
461         let mut rustfmt: Option<PathBuf> = None;
462         let mut proto_paths = Vec::new();
463         let mut label: Option<String> = None;
464         let mut tonic_or_prost_opts = Vec::new();
465         let mut is_tonic = false;
466 
467         let mut extra_args = Vec::new();
468 
469         let mut handle_arg = |arg: String| {
470             if !arg.starts_with('-') {
471                 proto_files.push(PathBuf::from(arg));
472                 return;
473             }
474 
475             if arg.starts_with("-I") {
476                 includes.push(
477                     arg.strip_prefix("-I")
478                         .expect("Failed to strip -I")
479                         .to_string(),
480                 );
481                 return;
482             }
483 
484             if arg == "--is_tonic" {
485                 is_tonic = true;
486                 return;
487             }
488 
489             if !arg.contains('=') {
490                 extra_args.push(arg);
491                 return;
492             }
493 
494             let parts = arg.split_once('=').expect("Failed to split argument on =");
495             match parts {
496                 ("--protoc", value) => {
497                     protoc = Some(PathBuf::from(value));
498                 }
499                 ("--prost_out", value) => {
500                     out_dir = Some(PathBuf::from(value));
501                 }
502                 ("--package_info_output", value) => {
503                     let (key, value) = value
504                         .split_once('=')
505                         .map(|(a, b)| (a.to_string(), PathBuf::from(b)))
506                         .expect("Failed to parse package info output");
507                     crate_name = Some(key);
508                     package_info_file = Some(value);
509                 }
510                 ("--deps_info", value) => {
511                     for line in fs::read_to_string(value)
512                         .expect("Failed to read file")
513                         .lines()
514                     {
515                         let path = PathBuf::from(line.trim());
516                         for flag in fs::read_to_string(path)
517                             .expect("Failed to read file")
518                             .lines()
519                         {
520                             tonic_or_prost_opts.push(format!("extern_path={}", flag.trim()));
521                         }
522                     }
523                 }
524                 ("--descriptor_set", value) => {
525                     descriptor_set = Some(PathBuf::from(value));
526                 }
527                 ("--out_librs", value) => {
528                     out_librs = Some(PathBuf::from(value));
529                 }
530                 ("--rustfmt", value) => {
531                     rustfmt = Some(PathBuf::from(value));
532                 }
533                 ("--proto_path", value) => {
534                     proto_paths.push(value.to_string());
535                 }
536                 ("--label", value) => {
537                     label = Some(value.to_string());
538                 }
539                 (arg, value) => {
540                     extra_args.push(format!("{}={}", arg, value));
541                 }
542             }
543         };
544 
545         // Iterate over the given command line arguments parsing out arguments
546         // for the process runner and arguments for protoc and potentially spawn
547         // additional arguments needed by prost.
548         for arg in env::args().skip(1) {
549             if let Some(path) = arg.strip_prefix('@') {
550                 // handle argfile
551                 let file = std::fs::File::open(path)
552                     .map_err(|_| format!("could not open argfile: {}", arg))?;
553                 for line in std::io::BufReader::new(file).lines() {
554                     handle_arg(line.map_err(|_| format!("could not read argfile: {}", arg))?);
555                 }
556             } else {
557                 handle_arg(arg);
558             }
559         }
560 
561         for tonic_or_prost_opt in tonic_or_prost_opts {
562             extra_args.push(format!("--prost_opt={}", tonic_or_prost_opt));
563             if is_tonic {
564                 extra_args.push(format!("--tonic_opt={}", tonic_or_prost_opt));
565             }
566         }
567 
568         if protoc.is_none() {
569             return Err(
570                 "No `--protoc` value was found. Unable to parse path to proto compiler."
571                     .to_string(),
572             );
573         }
574         if out_dir.is_none() {
575             return Err(
576                 "No `--prost_out` value was found. Unable to parse output directory.".to_string(),
577             );
578         }
579         if crate_name.is_none() {
580             return Err(
581                 "No `--package_info_output` value was found. Unable to parse target crate name."
582                     .to_string(),
583             );
584         }
585         if package_info_file.is_none() {
586             return Err("No `--package_info_output` value was found. Unable to parse package info output file.".to_string());
587         }
588         if out_librs.is_none() {
589             return Err("No `--out_librs` value was found. Unable to parse the output location for all combined prost outputs.".to_string());
590         }
591         if descriptor_set.is_none() {
592             return Err(
593                 "No `--descriptor_set` value was found. Unable to parse descriptor set path."
594                     .to_string(),
595             );
596         }
597         if label.is_none() {
598             return Err(
599                 "No `--label` value was found. Unable to parse the label of the target crate."
600                     .to_string(),
601             );
602         }
603 
604         Ok(Args {
605             protoc: protoc.unwrap(),
606             out_dir: out_dir.unwrap(),
607             crate_name: crate_name.unwrap(),
608             package_info_file: package_info_file.unwrap(),
609             proto_files,
610             includes,
611             descriptor_set: descriptor_set.unwrap(),
612             out_librs: out_librs.unwrap(),
613             rustfmt,
614             proto_paths,
615             is_tonic,
616             label: label.unwrap(),
617             extra_args,
618         })
619     }
620 }
621 
/// Get the output directory with the label suffixed.
///
/// The Bazel label is flattened into a single path component. `@` is dropped,
/// and `//`, `/` and `:` each become `_`. For example, `@repo//pkg:target`
/// under `out_dir` yields `out_dir/prost-build-repo_pkg_target`.
fn get_output_dir(out_dir: &Path, label: &str) -> PathBuf {
    let label_as_path = label
        .replace('@', "")
        .replace("//", "_")
        .replace(['/', ':'], "_");
    // Join as a path component rather than formatting through `display()`,
    // which is lossy for non-UTF-8 directories.
    out_dir.join(format!("prost-build-{}", label_as_path))
}
634 
635 /// Get the output directory with the label suffixed, and create it if it doesn't exist.
636 ///
637 /// This will remove the directory first if it already exists.
get_and_create_output_dir(out_dir: &Path, label: &str) -> PathBuf638 fn get_and_create_output_dir(out_dir: &Path, label: &str) -> PathBuf {
639     let out_dir = get_output_dir(out_dir, label);
640     if out_dir.exists() {
641         fs::remove_dir_all(&out_dir).expect("Failed to remove old output directory");
642     }
643     fs::create_dir_all(&out_dir).expect("Failed to create output directory");
644     out_dir
645 }
646 
647 /// Parse the descriptor set file into a `FileDescriptorSet`.
parse_descriptor_set_file(descriptor_set_path: &PathBuf) -> FileDescriptorSet648 fn parse_descriptor_set_file(descriptor_set_path: &PathBuf) -> FileDescriptorSet {
649     let descriptor_set_bytes =
650         fs::read(descriptor_set_path).expect("Failed to read descriptor set");
651     let descriptor_set = FileDescriptorSet::decode(descriptor_set_bytes.as_slice())
652         .expect("Failed to decode descriptor set");
653 
654     descriptor_set
655 }
656 
657 /// Get the package name from the descriptor set.
get_package_name(descriptor_set: &FileDescriptorSet) -> Option<String>658 fn get_package_name(descriptor_set: &FileDescriptorSet) -> Option<String> {
659     let mut package_name = None;
660 
661     for file in &descriptor_set.file {
662         if let Some(package) = &file.package {
663             package_name = Some(package.clone());
664             break;
665         }
666     }
667 
668     package_name
669 }
670 
671 /// Whether the proto file should expect to generate a .rs file.
672 ///
673 /// If the proto file contains any messages, enums, or services, then it should generate a rust file.
674 /// If the proto file only contains extensions, then it will not generate any rust files.
expect_fs_file_to_be_generated(descriptor_set: &FileDescriptorSet) -> bool675 fn expect_fs_file_to_be_generated(descriptor_set: &FileDescriptorSet) -> bool {
676     let mut expect_rs = false;
677 
678     for file in descriptor_set.file.iter() {
679         let has_messages = !file.message_type.is_empty();
680         let has_enums = !file.enum_type.is_empty();
681         let has_services = !file.service.is_empty();
682         let has_extensions = !file.extension.is_empty();
683 
684         let has_definition = has_messages || has_enums || has_services;
685 
686         if has_definition {
687             return true;
688         } else if !has_definition && !has_extensions {
689             expect_rs = true;
690         }
691     }
692 
693     expect_rs
694 }
695 
696 /// Whether the proto file should expect to generate service definitions.
has_services(descriptor_set: &FileDescriptorSet) -> bool697 fn has_services(descriptor_set: &FileDescriptorSet) -> bool {
698     descriptor_set
699         .file
700         .iter()
701         .any(|file| !file.service.is_empty())
702 }
703 
main()704 fn main() {
705     // Always enable backtraces for the protoc wrapper.
706     env::set_var("RUST_BACKTRACE", "1");
707 
708     let Args {
709         protoc,
710         out_dir,
711         crate_name,
712         label,
713         package_info_file,
714         proto_files,
715         includes,
716         descriptor_set,
717         out_librs,
718         rustfmt,
719         proto_paths,
720         is_tonic,
721         extra_args,
722     } = Args::parse().expect("Failed to parse args");
723 
724     let out_dir = get_and_create_output_dir(&out_dir, &label);
725 
726     let descriptor_set = parse_descriptor_set_file(&descriptor_set);
727     let package_name = get_package_name(&descriptor_set).unwrap_or_default();
728     let expect_rs = expect_fs_file_to_be_generated(&descriptor_set);
729     let has_services = has_services(&descriptor_set);
730 
731     if has_services && !is_tonic {
732         eprintln!("Warning: Service definitions will not be generated because the prost toolchain did not define a tonic plugin.");
733     }
734 
735     let mut cmd = process::Command::new(protoc);
736     cmd.arg(format!("--prost_out={}", out_dir.display()));
737     if is_tonic {
738         cmd.arg(format!("--tonic_out={}", out_dir.display()));
739     }
740     cmd.args(extra_args);
741     cmd.args(
742         proto_paths
743             .iter()
744             .map(|proto_path| format!("--proto_path={}", proto_path)),
745     );
746     cmd.args(includes.iter().map(|include| format!("-I{}", include)));
747     cmd.args(&proto_files);
748 
749     let status = cmd.status().expect("Failed to spawn protoc process");
750     if !status.success() {
751         panic!(
752             "protoc failed with status: {}",
753             status.code().expect("failed to get exit code")
754         );
755     }
756 
757     // Not all proto files will consistently produce `.rs` or `.tonic.rs` files. This is
758     // caused by the proto file being transpiled not having an RPC service or other protos
759     // defined (a natural and expected situation). To guarantee consistent outputs, all
760     // `.rs` files are either renamed to `.tonic.rs` if there is no `.tonic.rs` or prepended
761     // to the existing `.tonic.rs`.
762     if is_tonic {
763         let tonic_files: BTreeSet<PathBuf> = find_generated_rust_files(&out_dir);
764 
765         for tonic_file in tonic_files.iter() {
766             let tonic_path_str = tonic_file.to_str().expect("Failed to convert to str");
767             let filename = tonic_file
768                 .file_name()
769                 .expect("Failed to get file name")
770                 .to_str()
771                 .expect("Failed to convert to str");
772 
773             let is_tonic_file = filename.ends_with(".tonic.rs");
774 
775             if is_tonic_file {
776                 let rs_file_str = format!(
777                     "{}.rs",
778                     tonic_path_str
779                         .strip_suffix(".tonic.rs")
780                         .expect("Failed to strip suffix.")
781                 );
782                 let rs_file = PathBuf::from(&rs_file_str);
783 
784                 if rs_file.exists() {
785                     let rs_content = fs::read_to_string(&rs_file).expect("Failed to read file.");
786                     let tonic_content =
787                         fs::read_to_string(tonic_file).expect("Failed to read file.");
788                     fs::write(tonic_file, format!("{}\n{}", rs_content, tonic_content))
789                         .expect("Failed to write file.");
790                     fs::remove_file(&rs_file).unwrap_or_else(|err| {
791                         panic!("Failed to remove file: {err:?}: {rs_file:?}")
792                     });
793                 }
794             } else {
795                 let real_tonic_file = PathBuf::from(format!(
796                     "{}.tonic.rs",
797                     tonic_path_str
798                         .strip_suffix(".rs")
799                         .expect("Failed to strip suffix.")
800                 ));
801                 if real_tonic_file.exists() {
802                     continue;
803                 }
804                 fs::rename(tonic_file, &real_tonic_file).unwrap_or_else(|err| {
805                     panic!("Failed to rename file: {err:?}: {tonic_file:?} -> {real_tonic_file:?}");
806                 });
807             }
808         }
809     }
810 
811     // Locate all prost-generated outputs.
812     let mut rust_files = find_generated_rust_files(&out_dir);
813     if rust_files.is_empty() {
814         if expect_rs {
815             panic!("No .rs files were generated by prost.");
816         } else {
817             let file_stem = if package_name.is_empty() {
818                 "_"
819             } else {
820                 &package_name
821             };
822             let file_stem = format!("{}{}", file_stem, if is_tonic { ".tonic" } else { "" });
823             let empty_rs_file = out_dir.join(format!("{}.rs", file_stem));
824             fs::write(&empty_rs_file, "").expect("Failed to write file.");
825             rust_files.insert(empty_rs_file);
826         }
827     }
828 
829     let extern_paths = get_extern_paths(&descriptor_set, &crate_name)
830         .expect("Failed to compute proto package info");
831 
832     // Write outputs
833     fs::write(&out_librs, generate_lib_rs(&rust_files, is_tonic)).expect("Failed to write file.");
834     fs::write(
835         package_info_file,
836         extern_paths
837             .into_iter()
838             .map(|(proto_path, rust_path)| format!(".{}=::{}", proto_path, rust_path))
839             .collect::<Vec<_>>()
840             .join("\n"),
841     )
842     .expect("Failed to write file.");
843 
844     // Finally run rustfmt on the output lib.rs file
845     if let Some(rustfmt) = rustfmt {
846         let fmt_status = process::Command::new(rustfmt)
847             .arg("--edition")
848             .arg("2021")
849             .arg("--quiet")
850             .arg(&out_librs)
851             .status()
852             .expect("Failed to spawn rustfmt process");
853         if !fmt_status.success() {
854             panic!(
855                 "rustfmt failed with exit code: {}",
856                 fmt_status.code().expect("Failed to get exit code")
857             );
858         }
859     }
860 }
861 
/// Rust built-in keywords and reserved keywords.
///
/// Lists the strict keywords (including the 2018-edition `async`, `await`, `dyn` and `try`)
/// and the reserved-for-future-use keywords. Weak keywords such as `union` are not listed,
/// and neither is `gen`, which is only reserved starting with the 2024 edition.
///
/// Note that `crate`, `self`, `Self` and `super` appear here even though the compiler does
/// not accept them as raw identifiers (e.g. `r#self` is an error).
const RUST_KEYWORDS: [&str; 51] = [
    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate",
    "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "if", "impl", "in",
    "let", "loop", "macro", "match", "mod", "move", "mut", "override", "priv", "pub", "ref",
    "return", "self", "Self", "static", "struct", "super", "trait", "true", "try", "type",
    "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
];
870 
871 /// Returns true if the given string is a Rust keyword.
is_keyword(s: &str) -> bool872 fn is_keyword(s: &str) -> bool {
873     RUST_KEYWORDS.contains(&s)
874 }
875 
876 /// Escapes a Rust keyword by prefixing it with `r#`.
escape_keyword(s: String) -> String877 fn escape_keyword(s: String) -> String {
878     if is_keyword(&s) {
879         return format!("r#{s}");
880     }
881     s
882 }
883 
884 #[cfg(test)]
885 mod test {
886 
887     use super::*;
888 
889     use prost_types::{FieldDescriptorProto, FileDescriptorProto, ServiceDescriptorProto};
890     use std::collections::BTreeMap;
891 
892     #[test]
oneof_type_to_extern_paths_test()893     fn oneof_type_to_extern_paths_test() {
894         let oneof_descriptor = OneofDescriptorProto {
895             name: Some("Foo".to_string()),
896             ..OneofDescriptorProto::default()
897         };
898 
899         {
900             let mut extern_paths = BTreeMap::new();
901             oneof_type_to_extern_paths(
902                 &mut extern_paths,
903                 &ProtoPath::from("bar"),
904                 &RustModulePath::from("bar"),
905                 &oneof_descriptor,
906             );
907 
908             assert_eq!(extern_paths.len(), 1);
909             assert_eq!(
910                 extern_paths.get(&ProtoPath::from("bar.Foo")),
911                 Some(&RustModulePath::from("bar::Foo"))
912             );
913         }
914 
915         {
916             let mut extern_paths = BTreeMap::new();
917             oneof_type_to_extern_paths(
918                 &mut extern_paths,
919                 &ProtoPath::from("bar.baz"),
920                 &RustModulePath::from("bar::baz"),
921                 &oneof_descriptor,
922             );
923 
924             assert_eq!(extern_paths.len(), 1);
925             assert_eq!(
926                 extern_paths.get(&ProtoPath::from("bar.baz.Foo")),
927                 Some(&RustModulePath::from("bar::baz::Foo"))
928             );
929         }
930     }
931 
932     #[test]
enum_type_to_extern_paths_test()933     fn enum_type_to_extern_paths_test() {
934         let enum_descriptor = EnumDescriptorProto {
935             name: Some("Foo".to_string()),
936             ..EnumDescriptorProto::default()
937         };
938 
939         {
940             let mut extern_paths = BTreeMap::new();
941             enum_type_to_extern_paths(
942                 &mut extern_paths,
943                 &ProtoPath::from("bar"),
944                 &RustModulePath::from("bar"),
945                 &enum_descriptor,
946             );
947 
948             assert_eq!(extern_paths.len(), 1);
949             assert_eq!(
950                 extern_paths.get(&ProtoPath::from("bar.Foo")),
951                 Some(&RustModulePath::from("bar::Foo"))
952             );
953         }
954 
955         {
956             let mut extern_paths = BTreeMap::new();
957             enum_type_to_extern_paths(
958                 &mut extern_paths,
959                 &ProtoPath::from("bar.baz"),
960                 &RustModulePath::from("bar::baz"),
961                 &enum_descriptor,
962             );
963 
964             assert_eq!(extern_paths.len(), 1);
965             assert_eq!(
966                 extern_paths.get(&ProtoPath::from("bar.baz.Foo")),
967                 Some(&RustModulePath::from("bar::baz::Foo"))
968             );
969         }
970     }
971 
972     #[test]
message_type_to_extern_paths_test()973     fn message_type_to_extern_paths_test() {
974         let message_descriptor = DescriptorProto {
975             name: Some("Foo".to_string()),
976             nested_type: vec![
977                 DescriptorProto {
978                     name: Some("Bar".to_string()),
979                     ..DescriptorProto::default()
980                 },
981                 DescriptorProto {
982                     name: Some("Nested".to_string()),
983                     nested_type: vec![DescriptorProto {
984                         name: Some("Baz".to_string()),
985                         enum_type: vec![EnumDescriptorProto {
986                             name: Some("Chuck".to_string()),
987                             ..EnumDescriptorProto::default()
988                         }],
989                         ..DescriptorProto::default()
990                     }],
991                     ..DescriptorProto::default()
992                 },
993             ],
994             enum_type: vec![EnumDescriptorProto {
995                 name: Some("Qux".to_string()),
996                 ..EnumDescriptorProto::default()
997             }],
998             ..DescriptorProto::default()
999         };
1000 
1001         {
1002             let mut extern_paths = BTreeMap::new();
1003             message_type_to_extern_paths(
1004                 &mut extern_paths,
1005                 &ProtoPath::from("bar"),
1006                 &RustModulePath::from("bar"),
1007                 &message_descriptor,
1008             );
1009             assert_eq!(extern_paths.len(), 6);
1010             assert_eq!(
1011                 extern_paths.get(&ProtoPath::from("bar.Foo")),
1012                 Some(&RustModulePath::from("bar::Foo"))
1013             );
1014             assert_eq!(
1015                 extern_paths.get(&ProtoPath::from("bar.foo.Bar")),
1016                 Some(&RustModulePath::from("bar::foo::Bar"))
1017             );
1018             assert_eq!(
1019                 extern_paths.get(&ProtoPath::from("bar.foo.Nested")),
1020                 Some(&RustModulePath::from("bar::foo::Nested"))
1021             );
1022             assert_eq!(
1023                 extern_paths.get(&ProtoPath::from("bar.foo.nested.Baz")),
1024                 Some(&RustModulePath::from("bar::foo::nested::Baz"))
1025             );
1026         }
1027 
1028         {
1029             let mut extern_paths = BTreeMap::new();
1030             message_type_to_extern_paths(
1031                 &mut extern_paths,
1032                 &ProtoPath::from("bar.bob"),
1033                 &RustModulePath::from("bar::bob"),
1034                 &message_descriptor,
1035             );
1036             assert_eq!(extern_paths.len(), 6);
1037             assert_eq!(
1038                 extern_paths.get(&ProtoPath::from("bar.bob.Foo")),
1039                 Some(&RustModulePath::from("bar::bob::Foo"))
1040             );
1041             assert_eq!(
1042                 extern_paths.get(&ProtoPath::from("bar.bob.foo.Bar")),
1043                 Some(&RustModulePath::from("bar::bob::foo::Bar"))
1044             );
1045             assert_eq!(
1046                 extern_paths.get(&ProtoPath::from("bar.bob.foo.Nested")),
1047                 Some(&RustModulePath::from("bar::bob::foo::Nested"))
1048             );
1049             assert_eq!(
1050                 extern_paths.get(&ProtoPath::from("bar.bob.foo.nested.Baz")),
1051                 Some(&RustModulePath::from("bar::bob::foo::nested::Baz"))
1052             );
1053         }
1054     }
1055 
1056     #[test]
proto_path_test()1057     fn proto_path_test() {
1058         {
1059             let proto_path = ProtoPath::from("");
1060             assert_eq!(proto_path.to_string(), "");
1061             assert_eq!(proto_path.join("foo"), ProtoPath::from("foo"));
1062         }
1063         {
1064             let proto_path = ProtoPath::from("foo");
1065             assert_eq!(proto_path.to_string(), "foo");
1066             assert_eq!(proto_path.join(""), ProtoPath::from("foo"));
1067         }
1068         {
1069             let proto_path = ProtoPath::from("foo");
1070             assert_eq!(proto_path.to_string(), "foo");
1071             assert_eq!(proto_path.join("bar"), ProtoPath::from("foo.bar"));
1072         }
1073         {
1074             let proto_path = ProtoPath::from("foo.bar");
1075             assert_eq!(proto_path.to_string(), "foo.bar");
1076             assert_eq!(proto_path.join("baz"), ProtoPath::from("foo.bar.baz"));
1077         }
1078         {
1079             let proto_path = ProtoPath::from("Foo.baR");
1080             assert_eq!(proto_path.to_string(), "Foo.baR");
1081             assert_eq!(proto_path.join("baz"), ProtoPath::from("Foo.baR.baz"));
1082         }
1083     }
1084 
1085     #[test]
rust_module_path_test()1086     fn rust_module_path_test() {
1087         {
1088             let rust_module_path = RustModulePath::from("");
1089             assert_eq!(rust_module_path.to_string(), "");
1090             assert_eq!(rust_module_path.join("foo"), RustModulePath::from("foo"));
1091         }
1092         {
1093             let rust_module_path = RustModulePath::from("foo");
1094             assert_eq!(rust_module_path.to_string(), "foo");
1095             assert_eq!(rust_module_path.join(""), RustModulePath::from("foo"));
1096         }
1097         {
1098             let rust_module_path = RustModulePath::from("foo");
1099             assert_eq!(rust_module_path.to_string(), "foo");
1100             assert_eq!(
1101                 rust_module_path.join("bar"),
1102                 RustModulePath::from("foo::bar")
1103             );
1104         }
1105         {
1106             let rust_module_path = RustModulePath::from("foo::bar");
1107             assert_eq!(rust_module_path.to_string(), "foo::bar");
1108             assert_eq!(
1109                 rust_module_path.join("baz"),
1110                 RustModulePath::from("foo::bar::baz")
1111             );
1112         }
1113     }
1114 
1115     #[test]
expect_fs_file_to_be_generated_test()1116     fn expect_fs_file_to_be_generated_test() {
1117         {
1118             // Empty descriptor set should create a file.
1119             let descriptor_set = FileDescriptorSet {
1120                 file: vec![FileDescriptorProto {
1121                     name: Some("foo.proto".to_string()),
1122                     ..FileDescriptorProto::default()
1123                 }],
1124             };
1125             assert!(expect_fs_file_to_be_generated(&descriptor_set));
1126         }
1127         {
1128             // Descriptor set with only message should create a file.
1129             let descriptor_set = FileDescriptorSet {
1130                 file: vec![FileDescriptorProto {
1131                     name: Some("foo.proto".to_string()),
1132                     message_type: vec![DescriptorProto {
1133                         name: Some("Foo".to_string()),
1134                         ..DescriptorProto::default()
1135                     }],
1136                     ..FileDescriptorProto::default()
1137                 }],
1138             };
1139             assert!(expect_fs_file_to_be_generated(&descriptor_set));
1140         }
1141         {
1142             // Descriptor set with only enum should create a file.
1143             let descriptor_set = FileDescriptorSet {
1144                 file: vec![FileDescriptorProto {
1145                     name: Some("foo.proto".to_string()),
1146                     enum_type: vec![EnumDescriptorProto {
1147                         name: Some("Foo".to_string()),
1148                         ..EnumDescriptorProto::default()
1149                     }],
1150                     ..FileDescriptorProto::default()
1151                 }],
1152             };
1153             assert!(expect_fs_file_to_be_generated(&descriptor_set));
1154         }
1155         {
1156             // Descriptor set with only service should create a file.
1157             let descriptor_set = FileDescriptorSet {
1158                 file: vec![FileDescriptorProto {
1159                     name: Some("foo.proto".to_string()),
1160                     service: vec![ServiceDescriptorProto {
1161                         name: Some("Foo".to_string()),
1162                         ..ServiceDescriptorProto::default()
1163                     }],
1164                     ..FileDescriptorProto::default()
1165                 }],
1166             };
1167             assert!(expect_fs_file_to_be_generated(&descriptor_set));
1168         }
1169         {
1170             // Descriptor set with only extensions should not create a file.
1171             let descriptor_set = FileDescriptorSet {
1172                 file: vec![FileDescriptorProto {
1173                     name: Some("foo.proto".to_string()),
1174                     extension: vec![FieldDescriptorProto {
1175                         name: Some("Foo".to_string()),
1176                         ..FieldDescriptorProto::default()
1177                     }],
1178                     ..FileDescriptorProto::default()
1179                 }],
1180             };
1181             assert!(!expect_fs_file_to_be_generated(&descriptor_set));
1182         }
1183     }
1184 
1185     #[test]
has_services_test()1186     fn has_services_test() {
1187         {
1188             // Empty file should not have services.
1189             let descriptor_set = FileDescriptorSet {
1190                 file: vec![FileDescriptorProto {
1191                     name: Some("foo.proto".to_string()),
1192                     ..FileDescriptorProto::default()
1193                 }],
1194             };
1195             assert!(!has_services(&descriptor_set));
1196         }
1197         {
1198             // File with only message should not have services.
1199             let descriptor_set = FileDescriptorSet {
1200                 file: vec![FileDescriptorProto {
1201                     name: Some("foo.proto".to_string()),
1202                     message_type: vec![DescriptorProto {
1203                         name: Some("Foo".to_string()),
1204                         ..DescriptorProto::default()
1205                     }],
1206                     ..FileDescriptorProto::default()
1207                 }],
1208             };
1209             assert!(!has_services(&descriptor_set));
1210         }
1211         {
1212             // File with services should have services.
1213             let descriptor_set = FileDescriptorSet {
1214                 file: vec![FileDescriptorProto {
1215                     name: Some("foo.proto".to_string()),
1216                     service: vec![ServiceDescriptorProto {
1217                         name: Some("Foo".to_string()),
1218                         ..ServiceDescriptorProto::default()
1219                     }],
1220                     ..FileDescriptorProto::default()
1221                 }],
1222             };
1223             assert!(has_services(&descriptor_set));
1224         }
1225     }
1226 
1227     #[test]
get_package_name_test()1228     fn get_package_name_test() {
1229         let descriptor_set = FileDescriptorSet {
1230             file: vec![FileDescriptorProto {
1231                 name: Some("foo.proto".to_string()),
1232                 package: Some("foo".to_string()),
1233                 ..FileDescriptorProto::default()
1234             }],
1235         };
1236 
1237         assert_eq!(get_package_name(&descriptor_set), Some("foo".to_string()));
1238     }
1239 
1240     #[test]
is_keyword_test()1241     fn is_keyword_test() {
1242         let non_keywords = [
1243             "foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred",
1244             "plugh", "xyzzy", "thud",
1245         ];
1246         for non_keyword in &non_keywords {
1247             assert!(!is_keyword(non_keyword));
1248         }
1249 
1250         for keyword in &RUST_KEYWORDS {
1251             assert!(is_keyword(keyword));
1252         }
1253     }
1254 
1255     #[test]
escape_keyword_test()1256     fn escape_keyword_test() {
1257         let non_keywords = [
1258             "foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred",
1259             "plugh", "xyzzy", "thud",
1260         ];
1261         for non_keyword in &non_keywords {
1262             assert_eq!(
1263                 escape_keyword(non_keyword.to_string()),
1264                 non_keyword.to_owned()
1265             );
1266         }
1267 
1268         for keyword in &RUST_KEYWORDS {
1269             assert_eq!(
1270                 escape_keyword(keyword.to_string()),
1271                 format!("r#{}", keyword)
1272             );
1273         }
1274     }
1275 }
1276