1 // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
2 
3 use std::io::{Error, ErrorKind, Read};
4 use std::path::Path;
5 use std::{env, fs, io, process::Command, str};
6 
7 use derive_new::new;
8 use prost::Message;
9 use prost_build::{Config, Method, Service, ServiceGenerator};
10 use prost_types::FileDescriptorSet;
11 
12 use crate::util::{fq_grpc, to_snake_case, MethodType};
13 
14 /// Returns the names of all packages compiled.
compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>> where P: AsRef<Path>,15 pub fn compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>>
16 where
17     P: AsRef<Path>,
18 {
19     let mut prost_config = Config::new();
20     prost_config.service_generator(Box::new(Generator));
21     prost_config.out_dir(out_dir);
22 
23     // Create a file descriptor set for the protocol files.
24     let tmp = tempfile::Builder::new().prefix("prost-build").tempdir()?;
25     std::fs::create_dir_all(tmp.path())?;
26     let descriptor_set = tmp.path().join("prost-descriptor-set");
27 
28     let mut cmd = Command::new(prost_build::protoc_from_env());
29     cmd.arg("--include_imports")
30         .arg("--include_source_info")
31         .arg("-o")
32         .arg(&descriptor_set);
33 
34     for include in includes {
35         cmd.arg("-I").arg(include.as_ref());
36     }
37 
38     // Set the protoc include after the user includes in case the user wants to
39     // override one of the built-in .protos.
40     if let Some(inc) = prost_build::protoc_include_from_env() {
41         cmd.arg("-I").arg(inc);
42     }
43 
44     for proto in protos {
45         cmd.arg(proto.as_ref());
46     }
47 
48     let output = cmd.output()?;
49     if !output.status.success() {
50         return Err(Error::new(
51             ErrorKind::Other,
52             format!("protoc failed: {}", String::from_utf8_lossy(&output.stderr)),
53         ));
54     }
55 
56     let mut buf = Vec::new();
57     fs::File::open(descriptor_set)?.read_to_end(&mut buf)?;
58     let descriptor_set = FileDescriptorSet::decode(buf.as_slice())?;
59 
60     // Get the package names from the descriptor set.
61     let mut packages: Vec<_> = descriptor_set
62         .file
63         .iter()
64         .filter_map(|f| f.package.clone())
65         .collect();
66     packages.sort();
67     packages.dedup();
68 
69     // FIXME(https://github.com/danburkert/prost/pull/155)
70     // Unfortunately we have to forget the above work and use `compile_protos` to
71     // actually generate the Rust code.
72     prost_config.compile_protos(protos, includes)?;
73 
74     Ok(packages)
75 }
76 
/// [`ServiceGenerator`](prost_build::ServiceGenerator) for generating grpcio services.
///
/// Can be used for when there is a need to deviate from the common use case of
/// [`compile_protos()`]. One can provide a `Generator` instance to
/// [`prost_build::Config::service_generator()`].
///
/// The example is compiled but not run: it refers to `.proto` files that do
/// not exist in the doctest environment.
///
/// ```rust,no_run
/// use prost_build::Config;
/// use grpcio_compiler::prost_codegen::Generator;
///
/// fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let mut config = Config::new();
///     config.service_generator(Box::new(Generator));
///     // Modify config as needed
///     config.compile_protos(&["src/frontend.proto", "src/backend.proto"], &["src"])?;
///     Ok(())
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct Generator;
97 
98 impl ServiceGenerator for Generator {
generate(&mut self, service: Service, buf: &mut String)99     fn generate(&mut self, service: Service, buf: &mut String) {
100         generate_methods(&service, buf);
101         generate_client(&service, buf);
102         generate_server(&service, buf);
103     }
104 }
105 
generate_methods(service: &Service, buf: &mut String)106 fn generate_methods(service: &Service, buf: &mut String) {
107     let service_path = if service.package.is_empty() {
108         format!("/{}", service.proto_name)
109     } else {
110         format!("/{}.{}", service.package, service.proto_name)
111     };
112 
113     for method in &service.methods {
114         generate_method(&service.name, &service_path, method, buf);
115     }
116 }
117 
const_method_name(service_name: &str, method: &Method) -> String118 fn const_method_name(service_name: &str, method: &Method) -> String {
119     format!(
120         "METHOD_{}_{}",
121         to_snake_case(service_name).to_uppercase(),
122         method.name.to_uppercase()
123     )
124 }
125 
generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String)126 fn generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String) {
127     let name = const_method_name(service_name, method);
128     let ty = format!(
129         "{}<{}, {}>",
130         fq_grpc("Method"),
131         method.input_type,
132         method.output_type
133     );
134 
135     buf.push_str("const ");
136     buf.push_str(&name);
137     buf.push_str(": ");
138     buf.push_str(&ty);
139     buf.push_str(" = ");
140     generate_method_body(service_path, method, buf);
141 }
142 
generate_method_body(service_path: &str, method: &Method, buf: &mut String)143 fn generate_method_body(service_path: &str, method: &Method, buf: &mut String) {
144     let ty = fq_grpc(&MethodType::from_method(method).to_string());
145     let pr_mar = format!(
146         "{} {{ ser: {}, de: {} }}",
147         fq_grpc("Marshaller"),
148         fq_grpc("pr_ser"),
149         fq_grpc("pr_de")
150     );
151 
152     buf.push_str(&fq_grpc("Method"));
153     buf.push('{');
154     generate_field_init("ty", &ty, buf);
155     generate_field_init(
156         "name",
157         &format!("\"{}/{}\"", service_path, method.proto_name),
158         buf,
159     );
160     generate_field_init("req_mar", &pr_mar, buf);
161     generate_field_init("resp_mar", &pr_mar, buf);
162     buf.push_str("};\n");
163 }
164 
165 // TODO share this code with protobuf codegen
166 impl MethodType {
from_method(method: &Method) -> MethodType167     fn from_method(method: &Method) -> MethodType {
168         match (method.client_streaming, method.server_streaming) {
169             (false, false) => MethodType::Unary,
170             (true, false) => MethodType::ClientStreaming,
171             (false, true) => MethodType::ServerStreaming,
172             (true, true) => MethodType::Duplex,
173         }
174     }
175 }
176 
/// Appends one struct-literal field initializer, `<name>: <value>, `, to `buf`.
fn generate_field_init(name: &str, value: &str, buf: &mut String) {
    for piece in &[name, ": ", value, ", "] {
        buf.push_str(piece);
    }
}
183 
generate_client(service: &Service, buf: &mut String)184 fn generate_client(service: &Service, buf: &mut String) {
185     let client_name = format!("{}Client", service.name);
186     buf.push_str("#[derive(Clone)]\n");
187     buf.push_str("pub struct ");
188     buf.push_str(&client_name);
189     buf.push_str(" { pub client: ::grpcio::Client }\n");
190 
191     buf.push_str("impl ");
192     buf.push_str(&client_name);
193     buf.push_str(" {\n");
194     generate_ctor(&client_name, buf);
195     generate_client_methods(service, buf);
196     generate_spawn(buf);
197     buf.push_str("}\n")
198 }
199 
/// Emits the client's `pub fn new(channel: ::grpcio::Channel) -> Self`
/// constructor, which wraps the channel in a `::grpcio::Client`.
fn generate_ctor(client_name: &str, buf: &mut String) {
    let pieces = [
        "pub fn new(channel: ::grpcio::Channel) -> Self { ",
        client_name,
        " { client: ::grpcio::Client::new(channel) }",
        "}\n",
    ];
    pieces.iter().for_each(|piece| buf.push_str(piece));
}
206 
generate_client_methods(service: &Service, buf: &mut String)207 fn generate_client_methods(service: &Service, buf: &mut String) {
208     for method in &service.methods {
209         generate_client_method(&service.name, method, buf);
210     }
211 }
212 
generate_client_method(service_name: &str, method: &Method, buf: &mut String)213 fn generate_client_method(service_name: &str, method: &Method, buf: &mut String) {
214     let name = &format!(
215         "METHOD_{}_{}",
216         to_snake_case(service_name).to_uppercase(),
217         method.name.to_uppercase()
218     );
219     match MethodType::from_method(method) {
220         MethodType::Unary => {
221             ClientMethod::new(
222                 &method.name,
223                 true,
224                 Some(&method.input_type),
225                 false,
226                 vec![&method.output_type],
227                 "unary_call",
228                 name,
229             )
230             .generate(buf);
231             ClientMethod::new(
232                 &method.name,
233                 false,
234                 Some(&method.input_type),
235                 false,
236                 vec![&method.output_type],
237                 "unary_call",
238                 name,
239             )
240             .generate(buf);
241             ClientMethod::new(
242                 &method.name,
243                 true,
244                 Some(&method.input_type),
245                 true,
246                 vec![&format!(
247                     "{}<{}>",
248                     fq_grpc("ClientUnaryReceiver"),
249                     method.output_type
250                 )],
251                 "unary_call",
252                 name,
253             )
254             .generate(buf);
255             ClientMethod::new(
256                 &method.name,
257                 false,
258                 Some(&method.input_type),
259                 true,
260                 vec![&format!(
261                     "{}<{}>",
262                     fq_grpc("ClientUnaryReceiver"),
263                     method.output_type
264                 )],
265                 "unary_call",
266                 name,
267             )
268             .generate(buf);
269         }
270         MethodType::ClientStreaming => {
271             ClientMethod::new(
272                 &method.name,
273                 true,
274                 None,
275                 false,
276                 vec![
277                     &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
278                     &format!(
279                         "{}<{}>",
280                         fq_grpc("ClientCStreamReceiver"),
281                         method.output_type
282                     ),
283                 ],
284                 "client_streaming",
285                 name,
286             )
287             .generate(buf);
288             ClientMethod::new(
289                 &method.name,
290                 false,
291                 None,
292                 false,
293                 vec![
294                     &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
295                     &format!(
296                         "{}<{}>",
297                         fq_grpc("ClientCStreamReceiver"),
298                         method.output_type
299                     ),
300                 ],
301                 "client_streaming",
302                 name,
303             )
304             .generate(buf);
305         }
306         MethodType::ServerStreaming => {
307             ClientMethod::new(
308                 &method.name,
309                 true,
310                 Some(&method.input_type),
311                 false,
312                 vec![&format!(
313                     "{}<{}>",
314                     fq_grpc("ClientSStreamReceiver"),
315                     method.output_type
316                 )],
317                 "server_streaming",
318                 name,
319             )
320             .generate(buf);
321             ClientMethod::new(
322                 &method.name,
323                 false,
324                 Some(&method.input_type),
325                 false,
326                 vec![&format!(
327                     "{}<{}>",
328                     fq_grpc("ClientSStreamReceiver"),
329                     method.output_type
330                 )],
331                 "server_streaming",
332                 name,
333             )
334             .generate(buf);
335         }
336         MethodType::Duplex => {
337             ClientMethod::new(
338                 &method.name,
339                 true,
340                 None,
341                 false,
342                 vec![
343                     &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
344                     &format!(
345                         "{}<{}>",
346                         fq_grpc("ClientDuplexReceiver"),
347                         method.output_type
348                     ),
349                 ],
350                 "duplex_streaming",
351                 name,
352             )
353             .generate(buf);
354             ClientMethod::new(
355                 &method.name,
356                 false,
357                 None,
358                 false,
359                 vec![
360                     &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
361                     &format!(
362                         "{}<{}>",
363                         fq_grpc("ClientDuplexReceiver"),
364                         method.output_type
365                     ),
366                 ],
367                 "duplex_streaming",
368                 name,
369             )
370             .generate(buf);
371         }
372     }
373 }
374 
/// Parameters of one generated client method. `ClientMethod::generate`
/// renders it as a `pub fn` into the output buffer.
#[derive(new)]
struct ClientMethod<'a> {
    /// Base name of the generated method. `_async` and then `_opt` are appended
    /// to it when `r#async` / `opt` are set.
    method_name: &'a str,
    /// `true` emits the variant that takes an `opt: CallOption` parameter and
    /// calls `self.client` directly. `false` emits the variant that forwards
    /// to the `_opt` method with a default call option.
    opt: bool,
    /// Request type. `Some` adds a `req: &<type>` parameter; `None` omits it.
    request: Option<&'a str>,
    /// When set, `_async` is appended both to the generated method name and to
    /// the name of the method it calls.
    r#async: bool,
    /// Types placed inside the generated `Result<...>`. Unless there is exactly
    /// one, they are wrapped in parentheses as a tuple.
    result_types: Vec<&'a str>,
    /// Name of the method invoked on `self.client` by the `opt` variant.
    inner_method_name: &'a str,
    /// Identifier passed as `&<data_name>` to that inner call. The call sites
    /// in this file pass the `METHOD_*` constant name.
    data_name: &'a str,
}
385 
386 impl<'a> ClientMethod<'a> {
generate(&self, buf: &mut String)387     fn generate(&self, buf: &mut String) {
388         buf.push_str("pub fn ");
389 
390         buf.push_str(self.method_name);
391         if self.r#async {
392             buf.push_str("_async");
393         }
394         if self.opt {
395             buf.push_str("_opt");
396         }
397 
398         buf.push_str("(&self");
399         if let Some(req) = self.request {
400             buf.push_str(", req: &");
401             buf.push_str(req);
402         }
403         if self.opt {
404             buf.push_str(", opt: ");
405             buf.push_str(&fq_grpc("CallOption"));
406         }
407         buf.push_str(") -> ");
408 
409         buf.push_str(&fq_grpc("Result"));
410         buf.push('<');
411         if self.result_types.len() != 1 {
412             buf.push('(');
413         }
414         for rt in &self.result_types {
415             buf.push_str(rt);
416             buf.push(',');
417         }
418         if self.result_types.len() != 1 {
419             buf.push(')');
420         }
421         buf.push_str("> { ");
422         if self.opt {
423             self.generate_inner_body(buf);
424         } else {
425             self.generate_opt_body(buf);
426         }
427         buf.push_str(" }\n");
428     }
429 
430     // Method delegates to the `_opt` version of the method.
generate_opt_body(&self, buf: &mut String)431     fn generate_opt_body(&self, buf: &mut String) {
432         buf.push_str("self.");
433         buf.push_str(self.method_name);
434         if self.r#async {
435             buf.push_str("_async");
436         }
437         buf.push_str("_opt(");
438         if self.request.is_some() {
439             buf.push_str("req, ");
440         }
441         buf.push_str(&fq_grpc("CallOption::default()"));
442         buf.push(')');
443     }
444 
445     // Method delegates to the inner client.
generate_inner_body(&self, buf: &mut String)446     fn generate_inner_body(&self, buf: &mut String) {
447         buf.push_str("self.client.");
448         buf.push_str(self.inner_method_name);
449         if self.r#async {
450             buf.push_str("_async");
451         }
452         buf.push_str("(&");
453         buf.push_str(self.data_name);
454         if self.request.is_some() {
455             buf.push_str(", req");
456         }
457         buf.push_str(", opt)");
458     }
459 }
460 
/// Emits the client's `spawn` method, which forwards a `'static + Send`
/// future to `self.client.spawn`.
fn generate_spawn(buf: &mut String) {
    const SPAWN_FN: &str = concat!(
        "pub fn spawn<F>(&self, f: F) ",
        "where F: ::std::future::Future<Output = ()> + Send + 'static {",
        "self.client.spawn(f)",
        "}\n",
    );
    buf.push_str(SPAWN_FN);
}
469 
generate_server(service: &Service, buf: &mut String)470 fn generate_server(service: &Service, buf: &mut String) {
471     buf.push_str("pub trait ");
472     buf.push_str(&service.name);
473     buf.push_str(" {\n");
474     generate_server_methods(service, buf);
475     buf.push_str("}\n");
476 
477     buf.push_str("pub fn create_");
478     buf.push_str(&to_snake_case(&service.name));
479     buf.push_str("<S: ");
480     buf.push_str(&service.name);
481     buf.push_str(" + Send + Clone + 'static>(s: S) -> ");
482     buf.push_str(&fq_grpc("Service"));
483     buf.push_str(" {\n");
484     buf.push_str("let mut builder = ::grpcio::ServiceBuilder::new();\n");
485 
486     for method in &service.methods[0..service.methods.len() - 1] {
487         buf.push_str("let mut instance = s.clone();\n");
488         generate_method_bind(&service.name, method, buf);
489     }
490 
491     buf.push_str("let mut instance = s;\n");
492     generate_method_bind(
493         &service.name,
494         &service.methods[service.methods.len() - 1],
495         buf,
496     );
497 
498     buf.push_str("builder.build()\n");
499     buf.push_str("}\n");
500 }
501 
generate_server_methods(service: &Service, buf: &mut String)502 fn generate_server_methods(service: &Service, buf: &mut String) {
503     for method in &service.methods {
504         let method_type = MethodType::from_method(method);
505         let request_arg = match method_type {
506             MethodType::Unary | MethodType::ServerStreaming => {
507                 format!("req: {}", method.input_type)
508             }
509             MethodType::ClientStreaming | MethodType::Duplex => format!(
510                 "stream: {}<{}>",
511                 fq_grpc("RequestStream"),
512                 method.input_type
513             ),
514         };
515         let response_type = match method_type {
516             MethodType::Unary => "UnarySink",
517             MethodType::ClientStreaming => "ClientStreamingSink",
518             MethodType::ServerStreaming => "ServerStreamingSink",
519             MethodType::Duplex => "DuplexSink",
520         };
521         generate_server_method(method, &request_arg, response_type, buf);
522     }
523 }
524 
generate_server_method( method: &Method, request_arg: &str, response_type: &str, buf: &mut String, )525 fn generate_server_method(
526     method: &Method,
527     request_arg: &str,
528     response_type: &str,
529     buf: &mut String,
530 ) {
531     buf.push_str("fn ");
532     buf.push_str(&method.name);
533     buf.push_str("(&mut self, ctx: ");
534     buf.push_str(&fq_grpc("RpcContext"));
535     buf.push_str(", _");
536     buf.push_str(request_arg);
537     buf.push_str(", sink: ");
538     buf.push_str(&fq_grpc(response_type));
539     buf.push('<');
540     buf.push_str(&method.output_type);
541     buf.push('>');
542     buf.push_str(") { grpcio::unimplemented_call!(ctx, sink) }\n");
543 }
544 
generate_method_bind(service_name: &str, method: &Method, buf: &mut String)545 fn generate_method_bind(service_name: &str, method: &Method, buf: &mut String) {
546     let add_name = match MethodType::from_method(method) {
547         MethodType::Unary => "add_unary_handler",
548         MethodType::ClientStreaming => "add_client_streaming_handler",
549         MethodType::ServerStreaming => "add_server_streaming_handler",
550         MethodType::Duplex => "add_duplex_streaming_handler",
551     };
552 
553     buf.push_str("builder = builder.");
554     buf.push_str(add_name);
555     buf.push_str("(&");
556     buf.push_str(&const_method_name(service_name, method));
557     buf.push_str(", move |ctx, req, resp| instance.");
558     buf.push_str(&method.name);
559     buf.push_str("(ctx, req, resp));\n");
560 }
561 
protoc_gen_grpc_rust_main()562 pub fn protoc_gen_grpc_rust_main() {
563     let mut args = env::args();
564     args.next();
565     let (mut protos, mut includes, mut out_dir): (Vec<_>, Vec<_>, _) = Default::default();
566     for arg in args {
567         if let Some(value) = arg.strip_prefix("--protos=") {
568             protos.extend(value.split(",").map(|s| s.to_string()));
569         } else if let Some(value) = arg.strip_prefix("--includes=") {
570             includes.extend(value.split(",").map(|s| s.to_string()));
571         } else if let Some(value) = arg.strip_prefix("--out-dir=") {
572             out_dir = value.to_string();
573         }
574     }
575     if protos.is_empty() {
576         panic!("should at least specify protos to generate");
577     }
578     compile_protos(&protos, &includes, &out_dir).unwrap();
579 }
580