• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
2 
3 use std::io::{Error, ErrorKind, Read};
4 use std::path::Path;
5 use std::{env, fs, io, process::Command, str};
6 
7 use derive_new::new;
8 use prost::Message;
9 use prost_build::{protoc, protoc_include, Config, Method, Service, ServiceGenerator};
10 use prost_types::FileDescriptorSet;
11 
12 use crate::util::{fq_grpc, to_snake_case, MethodType};
13 
14 /// Returns the names of all packages compiled.
compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>> where P: AsRef<Path>,15 pub fn compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>>
16 where
17     P: AsRef<Path>,
18 {
19     let mut prost_config = Config::new();
20     prost_config.service_generator(Box::new(Generator));
21     prost_config.out_dir(out_dir);
22 
23     // Create a file descriptor set for the protocol files.
24     let tmp = tempfile::Builder::new().prefix("prost-build").tempdir()?;
25     let descriptor_set = tmp.path().join("prost-descriptor-set");
26 
27     let mut cmd = Command::new(protoc());
28     cmd.arg("--include_imports")
29         .arg("--include_source_info")
30         .arg("-o")
31         .arg(&descriptor_set);
32 
33     for include in includes {
34         cmd.arg("-I").arg(include.as_ref());
35     }
36 
37     // Set the protoc include after the user includes in case the user wants to
38     // override one of the built-in .protos.
39     cmd.arg("-I").arg(protoc_include());
40 
41     for proto in protos {
42         cmd.arg(proto.as_ref());
43     }
44 
45     let output = cmd.output()?;
46     if !output.status.success() {
47         return Err(Error::new(
48             ErrorKind::Other,
49             format!("protoc failed: {}", String::from_utf8_lossy(&output.stderr)),
50         ));
51     }
52 
53     let mut buf = Vec::new();
54     fs::File::open(descriptor_set)?.read_to_end(&mut buf)?;
55     let descriptor_set = FileDescriptorSet::decode(buf.as_slice())?;
56 
57     // Get the package names from the descriptor set.
58     let mut packages: Vec<_> = descriptor_set
59         .file
60         .iter()
61         .filter_map(|f| f.package.clone())
62         .collect();
63     packages.sort();
64     packages.dedup();
65 
66     // FIXME(https://github.com/danburkert/prost/pull/155)
67     // Unfortunately we have to forget the above work and use `compile_protos` to
68     // actually generate the Rust code.
69     prost_config.compile_protos(protos, includes)?;
70 
71     Ok(packages)
72 }
73 
74 struct Generator;
75 
76 impl ServiceGenerator for Generator {
generate(&mut self, service: Service, buf: &mut String)77     fn generate(&mut self, service: Service, buf: &mut String) {
78         generate_methods(&service, buf);
79         generate_client(&service, buf);
80         generate_server(&service, buf);
81     }
82 }
83 
generate_methods(service: &Service, buf: &mut String)84 fn generate_methods(service: &Service, buf: &mut String) {
85     let service_path = if service.package.is_empty() {
86         format!("/{}", service.proto_name)
87     } else {
88         format!("/{}.{}", service.package, service.proto_name)
89     };
90 
91     for method in &service.methods {
92         generate_method(&service.name, &service_path, method, buf);
93     }
94 }
95 
const_method_name(service_name: &str, method: &Method) -> String96 fn const_method_name(service_name: &str, method: &Method) -> String {
97     format!(
98         "METHOD_{}_{}",
99         to_snake_case(service_name).to_uppercase(),
100         method.name.to_uppercase()
101     )
102 }
103 
generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String)104 fn generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String) {
105     let name = const_method_name(service_name, method);
106     let ty = format!(
107         "{}<{}, {}>",
108         fq_grpc("Method"),
109         method.input_type,
110         method.output_type
111     );
112 
113     buf.push_str("const ");
114     buf.push_str(&name);
115     buf.push_str(": ");
116     buf.push_str(&ty);
117     buf.push_str(" = ");
118     generate_method_body(service_path, method, buf);
119 }
120 
generate_method_body(service_path: &str, method: &Method, buf: &mut String)121 fn generate_method_body(service_path: &str, method: &Method, buf: &mut String) {
122     let ty = fq_grpc(&MethodType::from_method(method).to_string());
123     let pr_mar = format!(
124         "{} {{ ser: {}, de: {} }}",
125         fq_grpc("Marshaller"),
126         fq_grpc("pr_ser"),
127         fq_grpc("pr_de")
128     );
129 
130     buf.push_str(&fq_grpc("Method"));
131     buf.push('{');
132     generate_field_init("ty", &ty, buf);
133     generate_field_init(
134         "name",
135         &format!("\"{}/{}\"", service_path, method.proto_name),
136         buf,
137     );
138     generate_field_init("req_mar", &pr_mar, buf);
139     generate_field_init("resp_mar", &pr_mar, buf);
140     buf.push_str("};\n");
141 }
142 
143 // TODO share this code with protobuf codegen
144 impl MethodType {
from_method(method: &Method) -> MethodType145     fn from_method(method: &Method) -> MethodType {
146         match (method.client_streaming, method.server_streaming) {
147             (false, false) => MethodType::Unary,
148             (true, false) => MethodType::ClientStreaming,
149             (false, true) => MethodType::ServerStreaming,
150             (true, true) => MethodType::Duplex,
151         }
152     }
153 }
154 
/// Appends a single struct-literal field initializer, `name: value, `.
///
/// The trailing comma and space are always written.
fn generate_field_init(name: &str, value: &str, buf: &mut String) {
    for piece in &[name, ": ", value, ", "] {
        buf.push_str(piece);
    }
}
161 
generate_client(service: &Service, buf: &mut String)162 fn generate_client(service: &Service, buf: &mut String) {
163     let client_name = format!("{}Client", service.name);
164     buf.push_str("#[derive(Clone)]\n");
165     buf.push_str("pub struct ");
166     buf.push_str(&client_name);
167     buf.push_str(" { client: ::grpcio::Client }\n");
168 
169     buf.push_str("impl ");
170     buf.push_str(&client_name);
171     buf.push_str(" {\n");
172     generate_ctor(&client_name, buf);
173     generate_client_methods(service, buf);
174     generate_spawn(buf);
175     buf.push_str("}\n")
176 }
177 
/// Appends `pub fn new(channel)` for the client named `client_name`.
///
/// The generated constructor wraps `channel` in a fresh `::grpcio::Client`.
fn generate_ctor(client_name: &str, buf: &mut String) {
    let ctor = format!(
        "pub fn new(channel: ::grpcio::Channel) -> Self {{ {} {{ client: ::grpcio::Client::new(channel) }}}}\n",
        client_name
    );
    buf.push_str(&ctor);
}
184 
generate_client_methods(service: &Service, buf: &mut String)185 fn generate_client_methods(service: &Service, buf: &mut String) {
186     for method in &service.methods {
187         generate_client_method(&service.name, method, buf);
188     }
189 }
190 
generate_client_method(service_name: &str, method: &Method, buf: &mut String)191 fn generate_client_method(service_name: &str, method: &Method, buf: &mut String) {
192     let name = &format!(
193         "METHOD_{}_{}",
194         to_snake_case(service_name).to_uppercase(),
195         method.name.to_uppercase()
196     );
197     match MethodType::from_method(method) {
198         MethodType::Unary => {
199             ClientMethod::new(
200                 &method.name,
201                 true,
202                 Some(&method.input_type),
203                 false,
204                 vec![&method.output_type],
205                 "unary_call",
206                 name,
207             )
208             .generate(buf);
209             ClientMethod::new(
210                 &method.name,
211                 false,
212                 Some(&method.input_type),
213                 false,
214                 vec![&method.output_type],
215                 "unary_call",
216                 name,
217             )
218             .generate(buf);
219             ClientMethod::new(
220                 &method.name,
221                 true,
222                 Some(&method.input_type),
223                 true,
224                 vec![&format!(
225                     "{}<{}>",
226                     fq_grpc("ClientUnaryReceiver"),
227                     method.output_type
228                 )],
229                 "unary_call",
230                 name,
231             )
232             .generate(buf);
233             ClientMethod::new(
234                 &method.name,
235                 false,
236                 Some(&method.input_type),
237                 true,
238                 vec![&format!(
239                     "{}<{}>",
240                     fq_grpc("ClientUnaryReceiver"),
241                     method.output_type
242                 )],
243                 "unary_call",
244                 name,
245             )
246             .generate(buf);
247         }
248         MethodType::ClientStreaming => {
249             ClientMethod::new(
250                 &method.name,
251                 true,
252                 None,
253                 false,
254                 vec![
255                     &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
256                     &format!(
257                         "{}<{}>",
258                         fq_grpc("ClientCStreamReceiver"),
259                         method.output_type
260                     ),
261                 ],
262                 "client_streaming",
263                 name,
264             )
265             .generate(buf);
266             ClientMethod::new(
267                 &method.name,
268                 false,
269                 None,
270                 false,
271                 vec![
272                     &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
273                     &format!(
274                         "{}<{}>",
275                         fq_grpc("ClientCStreamReceiver"),
276                         method.output_type
277                     ),
278                 ],
279                 "client_streaming",
280                 name,
281             )
282             .generate(buf);
283         }
284         MethodType::ServerStreaming => {
285             ClientMethod::new(
286                 &method.name,
287                 true,
288                 Some(&method.input_type),
289                 false,
290                 vec![&format!(
291                     "{}<{}>",
292                     fq_grpc("ClientSStreamReceiver"),
293                     method.output_type
294                 )],
295                 "server_streaming",
296                 name,
297             )
298             .generate(buf);
299             ClientMethod::new(
300                 &method.name,
301                 false,
302                 Some(&method.input_type),
303                 false,
304                 vec![&format!(
305                     "{}<{}>",
306                     fq_grpc("ClientSStreamReceiver"),
307                     method.output_type
308                 )],
309                 "server_streaming",
310                 name,
311             )
312             .generate(buf);
313         }
314         MethodType::Duplex => {
315             ClientMethod::new(
316                 &method.name,
317                 true,
318                 None,
319                 false,
320                 vec![
321                     &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
322                     &format!(
323                         "{}<{}>",
324                         fq_grpc("ClientDuplexReceiver"),
325                         method.output_type
326                     ),
327                 ],
328                 "duplex_streaming",
329                 name,
330             )
331             .generate(buf);
332             ClientMethod::new(
333                 &method.name,
334                 false,
335                 None,
336                 false,
337                 vec![
338                     &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
339                     &format!(
340                         "{}<{}>",
341                         fq_grpc("ClientDuplexReceiver"),
342                         method.output_type
343                     ),
344                 ],
345                 "duplex_streaming",
346                 name,
347             )
348             .generate(buf);
349         }
350     }
351 }
352 
353 #[derive(new)]
354 struct ClientMethod<'a> {
355     method_name: &'a str,
356     opt: bool,
357     request: Option<&'a str>,
358     r#async: bool,
359     result_types: Vec<&'a str>,
360     inner_method_name: &'a str,
361     data_name: &'a str,
362 }
363 
364 impl<'a> ClientMethod<'a> {
generate(&self, buf: &mut String)365     fn generate(&self, buf: &mut String) {
366         buf.push_str("pub fn ");
367 
368         buf.push_str(self.method_name);
369         if self.r#async {
370             buf.push_str("_async");
371         }
372         if self.opt {
373             buf.push_str("_opt");
374         }
375 
376         buf.push_str("(&self");
377         if let Some(req) = self.request {
378             buf.push_str(", req: &");
379             buf.push_str(req);
380         }
381         if self.opt {
382             buf.push_str(", opt: ");
383             buf.push_str(&fq_grpc("CallOption"));
384         }
385         buf.push_str(") -> ");
386 
387         buf.push_str(&fq_grpc("Result"));
388         buf.push('<');
389         if self.result_types.len() != 1 {
390             buf.push('(');
391         }
392         for rt in &self.result_types {
393             buf.push_str(rt);
394             buf.push(',');
395         }
396         if self.result_types.len() != 1 {
397             buf.push(')');
398         }
399         buf.push_str("> { ");
400         if self.opt {
401             self.generate_inner_body(buf);
402         } else {
403             self.generate_opt_body(buf);
404         }
405         buf.push_str(" }\n");
406     }
407 
408     // Method delegates to the `_opt` version of the method.
generate_opt_body(&self, buf: &mut String)409     fn generate_opt_body(&self, buf: &mut String) {
410         buf.push_str("self.");
411         buf.push_str(self.method_name);
412         if self.r#async {
413             buf.push_str("_async");
414         }
415         buf.push_str("_opt(");
416         if self.request.is_some() {
417             buf.push_str("req, ");
418         }
419         buf.push_str(&fq_grpc("CallOption::default()"));
420         buf.push(')');
421     }
422 
423     // Method delegates to the inner client.
generate_inner_body(&self, buf: &mut String)424     fn generate_inner_body(&self, buf: &mut String) {
425         buf.push_str("self.client.");
426         buf.push_str(self.inner_method_name);
427         if self.r#async {
428             buf.push_str("_async");
429         }
430         buf.push_str("(&");
431         buf.push_str(self.data_name);
432         if self.request.is_some() {
433             buf.push_str(", req");
434         }
435         buf.push_str(", opt)");
436     }
437 }
438 
/// Appends a `spawn` method that hands a `Future<Output = ()> + Send + 'static` to the
/// wrapped client's `spawn`.
fn generate_spawn(buf: &mut String) {
    const SPAWN_FN: &str = concat!(
        "pub fn spawn<F>(&self, f: F) ",
        "where F: ::futures::Future<Output = ()> + Send + 'static {",
        "self.client.spawn(f)",
        "}\n",
    );
    buf.push_str(SPAWN_FN);
}
447 
generate_server(service: &Service, buf: &mut String)448 fn generate_server(service: &Service, buf: &mut String) {
449     buf.push_str("pub trait ");
450     buf.push_str(&service.name);
451     buf.push_str(" {\n");
452     generate_server_methods(service, buf);
453     buf.push_str("}\n");
454 
455     buf.push_str("pub fn create_");
456     buf.push_str(&to_snake_case(&service.name));
457     buf.push_str("<S: ");
458     buf.push_str(&service.name);
459     buf.push_str(" + Send + Clone + 'static>(s: S) -> ");
460     buf.push_str(&fq_grpc("Service"));
461     buf.push_str(" {\n");
462     buf.push_str("let mut builder = ::grpcio::ServiceBuilder::new();\n");
463 
464     for method in &service.methods[0..service.methods.len() - 1] {
465         buf.push_str("let mut instance = s.clone();\n");
466         generate_method_bind(&service.name, method, buf);
467     }
468 
469     buf.push_str("let mut instance = s;\n");
470     generate_method_bind(
471         &service.name,
472         &service.methods[service.methods.len() - 1],
473         buf,
474     );
475 
476     buf.push_str("builder.build()\n");
477     buf.push_str("}\n");
478 }
479 
generate_server_methods(service: &Service, buf: &mut String)480 fn generate_server_methods(service: &Service, buf: &mut String) {
481     for method in &service.methods {
482         let method_type = MethodType::from_method(method);
483         let request_arg = match method_type {
484             MethodType::Unary | MethodType::ServerStreaming => {
485                 format!("req: {}", method.input_type)
486             }
487             MethodType::ClientStreaming | MethodType::Duplex => format!(
488                 "stream: {}<{}>",
489                 fq_grpc("RequestStream"),
490                 method.input_type
491             ),
492         };
493         let response_type = match method_type {
494             MethodType::Unary => "UnarySink",
495             MethodType::ClientStreaming => "ClientStreamingSink",
496             MethodType::ServerStreaming => "ServerStreamingSink",
497             MethodType::Duplex => "DuplexSink",
498         };
499         generate_server_method(method, &request_arg, response_type, buf);
500     }
501 }
502 
generate_server_method( method: &Method, request_arg: &str, response_type: &str, buf: &mut String, )503 fn generate_server_method(
504     method: &Method,
505     request_arg: &str,
506     response_type: &str,
507     buf: &mut String,
508 ) {
509     buf.push_str("fn ");
510     buf.push_str(&method.name);
511     buf.push_str("(&mut self, ctx: ");
512     buf.push_str(&fq_grpc("RpcContext"));
513     buf.push_str(", _");
514     buf.push_str(request_arg);
515     buf.push_str(", sink: ");
516     buf.push_str(&fq_grpc(response_type));
517     buf.push('<');
518     buf.push_str(&method.output_type);
519     buf.push('>');
520     buf.push_str(") { grpcio::unimplemented_call!(ctx, sink) }\n");
521 }
522 
generate_method_bind(service_name: &str, method: &Method, buf: &mut String)523 fn generate_method_bind(service_name: &str, method: &Method, buf: &mut String) {
524     let add_name = match MethodType::from_method(method) {
525         MethodType::Unary => "add_unary_handler",
526         MethodType::ClientStreaming => "add_client_streaming_handler",
527         MethodType::ServerStreaming => "add_server_streaming_handler",
528         MethodType::Duplex => "add_duplex_streaming_handler",
529     };
530 
531     buf.push_str("builder = builder.");
532     buf.push_str(add_name);
533     buf.push_str("(&");
534     buf.push_str(&const_method_name(service_name, method));
535     buf.push_str(", move |ctx, req, resp| instance.");
536     buf.push_str(&method.name);
537     buf.push_str("(ctx, req, resp));\n");
538 }
539 
protoc_gen_grpc_rust_main()540 pub fn protoc_gen_grpc_rust_main() {
541     let mut args = env::args();
542     args.next();
543     let (mut protos, mut includes, mut out_dir): (Vec<_>, Vec<_>, _) = Default::default();
544     for arg in args {
545         if let Some(value) = arg.strip_prefix("--protos=") {
546             protos.extend(value.split(",").map(|s| s.to_string()));
547         } else if let Some(value) = arg.strip_prefix("--includes=") {
548             includes.extend(value.split(",").map(|s| s.to_string()));
549         } else if let Some(value) = arg.strip_prefix("--out-dir=") {
550             out_dir = value.to_string();
551         }
552     }
553     if protos.is_empty() {
554         panic!("should at least specify protos to generate");
555     }
556     compile_protos(&protos, &includes, &out_dir).unwrap();
557 }
558