1 // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
2
3 use std::io::{Error, ErrorKind, Read};
4 use std::path::Path;
5 use std::{env, fs, io, process::Command, str};
6
7 use derive_new::new;
8 use prost::Message;
9 use prost_build::{Config, Method, Service, ServiceGenerator};
10 use prost_types::FileDescriptorSet;
11
12 use crate::util::{fq_grpc, to_snake_case, MethodType};
13
14 /// Returns the names of all packages compiled.
compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>> where P: AsRef<Path>,15 pub fn compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>>
16 where
17 P: AsRef<Path>,
18 {
19 let mut prost_config = Config::new();
20 prost_config.service_generator(Box::new(Generator));
21 prost_config.out_dir(out_dir);
22
23 // Create a file descriptor set for the protocol files.
24 let tmp = tempfile::Builder::new().prefix("prost-build").tempdir()?;
25 std::fs::create_dir_all(tmp.path())?;
26 let descriptor_set = tmp.path().join("prost-descriptor-set");
27
28 let mut cmd = Command::new(prost_build::protoc_from_env());
29 cmd.arg("--include_imports")
30 .arg("--include_source_info")
31 .arg("-o")
32 .arg(&descriptor_set);
33
34 for include in includes {
35 cmd.arg("-I").arg(include.as_ref());
36 }
37
38 // Set the protoc include after the user includes in case the user wants to
39 // override one of the built-in .protos.
40 if let Some(inc) = prost_build::protoc_include_from_env() {
41 cmd.arg("-I").arg(inc);
42 }
43
44 for proto in protos {
45 cmd.arg(proto.as_ref());
46 }
47
48 let output = cmd.output()?;
49 if !output.status.success() {
50 return Err(Error::new(
51 ErrorKind::Other,
52 format!("protoc failed: {}", String::from_utf8_lossy(&output.stderr)),
53 ));
54 }
55
56 let mut buf = Vec::new();
57 fs::File::open(descriptor_set)?.read_to_end(&mut buf)?;
58 let descriptor_set = FileDescriptorSet::decode(buf.as_slice())?;
59
60 // Get the package names from the descriptor set.
61 let mut packages: Vec<_> = descriptor_set
62 .file
63 .iter()
64 .filter_map(|f| f.package.clone())
65 .collect();
66 packages.sort();
67 packages.dedup();
68
69 // FIXME(https://github.com/danburkert/prost/pull/155)
70 // Unfortunately we have to forget the above work and use `compile_protos` to
71 // actually generate the Rust code.
72 prost_config.compile_protos(protos, includes)?;
73
74 Ok(packages)
75 }
76
/// [`ServiceGenerator`](prost_build::ServiceGenerator) for generating grpcio services.
///
/// Can be used when there is a need to deviate from the common use case of
/// [`compile_protos()`]. One can provide a `Generator` instance to
/// [`prost_build::Config::service_generator()`].
///
/// For every service it emits the `METHOD_*` method descriptor constants, a
/// `<Service>Client` struct, a `<Service>` server trait and a `create_<service>`
/// function that builds a grpcio `Service` from an implementation of that trait.
///
/// ```rust,no_run
/// use prost_build::Config;
/// use grpcio_compiler::prost_codegen::Generator;
///
/// fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let mut config = Config::new();
///     config.service_generator(Box::new(Generator));
///     // Modify config as needed
///     config.compile_protos(&["src/frontend.proto", "src/backend.proto"], &["src"])?;
///     Ok(())
/// }
/// ```
pub struct Generator;
97
98 impl ServiceGenerator for Generator {
generate(&mut self, service: Service, buf: &mut String)99 fn generate(&mut self, service: Service, buf: &mut String) {
100 generate_methods(&service, buf);
101 generate_client(&service, buf);
102 generate_server(&service, buf);
103 }
104 }
105
generate_methods(service: &Service, buf: &mut String)106 fn generate_methods(service: &Service, buf: &mut String) {
107 let service_path = if service.package.is_empty() {
108 format!("/{}", service.proto_name)
109 } else {
110 format!("/{}.{}", service.package, service.proto_name)
111 };
112
113 for method in &service.methods {
114 generate_method(&service.name, &service_path, method, buf);
115 }
116 }
117
const_method_name(service_name: &str, method: &Method) -> String118 fn const_method_name(service_name: &str, method: &Method) -> String {
119 format!(
120 "METHOD_{}_{}",
121 to_snake_case(service_name).to_uppercase(),
122 method.name.to_uppercase()
123 )
124 }
125
generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String)126 fn generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String) {
127 let name = const_method_name(service_name, method);
128 let ty = format!(
129 "{}<{}, {}>",
130 fq_grpc("Method"),
131 method.input_type,
132 method.output_type
133 );
134
135 buf.push_str("const ");
136 buf.push_str(&name);
137 buf.push_str(": ");
138 buf.push_str(&ty);
139 buf.push_str(" = ");
140 generate_method_body(service_path, method, buf);
141 }
142
generate_method_body(service_path: &str, method: &Method, buf: &mut String)143 fn generate_method_body(service_path: &str, method: &Method, buf: &mut String) {
144 let ty = fq_grpc(&MethodType::from_method(method).to_string());
145 let pr_mar = format!(
146 "{} {{ ser: {}, de: {} }}",
147 fq_grpc("Marshaller"),
148 fq_grpc("pr_ser"),
149 fq_grpc("pr_de")
150 );
151
152 buf.push_str(&fq_grpc("Method"));
153 buf.push('{');
154 generate_field_init("ty", &ty, buf);
155 generate_field_init(
156 "name",
157 &format!("\"{}/{}\"", service_path, method.proto_name),
158 buf,
159 );
160 generate_field_init("req_mar", &pr_mar, buf);
161 generate_field_init("resp_mar", &pr_mar, buf);
162 buf.push_str("};\n");
163 }
164
165 // TODO share this code with protobuf codegen
166 impl MethodType {
from_method(method: &Method) -> MethodType167 fn from_method(method: &Method) -> MethodType {
168 match (method.client_streaming, method.server_streaming) {
169 (false, false) => MethodType::Unary,
170 (true, false) => MethodType::ClientStreaming,
171 (false, true) => MethodType::ServerStreaming,
172 (true, true) => MethodType::Duplex,
173 }
174 }
175 }
176
/// Appends one struct-literal field initializer, `name: value, `, to `buf`.
fn generate_field_init(name: &str, value: &str, buf: &mut String) {
    for piece in &[name, ": ", value, ", "] {
        buf.push_str(piece);
    }
}
183
generate_client(service: &Service, buf: &mut String)184 fn generate_client(service: &Service, buf: &mut String) {
185 let client_name = format!("{}Client", service.name);
186 buf.push_str("#[derive(Clone)]\n");
187 buf.push_str("pub struct ");
188 buf.push_str(&client_name);
189 buf.push_str(" { pub client: ::grpcio::Client }\n");
190
191 buf.push_str("impl ");
192 buf.push_str(&client_name);
193 buf.push_str(" {\n");
194 generate_ctor(&client_name, buf);
195 generate_client_methods(service, buf);
196 generate_spawn(buf);
197 buf.push_str("}\n")
198 }
199
/// Emits `pub fn new(channel)`, which wraps a channel in a `::grpcio::Client`.
fn generate_ctor(client_name: &str, buf: &mut String) {
    let ctor = format!(
        "pub fn new(channel: ::grpcio::Channel) -> Self {{ {} {{ client: ::grpcio::Client::new(channel) }}}}\n",
        client_name
    );
    buf.push_str(&ctor);
}
206
generate_client_methods(service: &Service, buf: &mut String)207 fn generate_client_methods(service: &Service, buf: &mut String) {
208 for method in &service.methods {
209 generate_client_method(&service.name, method, buf);
210 }
211 }
212
generate_client_method(service_name: &str, method: &Method, buf: &mut String)213 fn generate_client_method(service_name: &str, method: &Method, buf: &mut String) {
214 let name = &format!(
215 "METHOD_{}_{}",
216 to_snake_case(service_name).to_uppercase(),
217 method.name.to_uppercase()
218 );
219 match MethodType::from_method(method) {
220 MethodType::Unary => {
221 ClientMethod::new(
222 &method.name,
223 true,
224 Some(&method.input_type),
225 false,
226 vec![&method.output_type],
227 "unary_call",
228 name,
229 )
230 .generate(buf);
231 ClientMethod::new(
232 &method.name,
233 false,
234 Some(&method.input_type),
235 false,
236 vec![&method.output_type],
237 "unary_call",
238 name,
239 )
240 .generate(buf);
241 ClientMethod::new(
242 &method.name,
243 true,
244 Some(&method.input_type),
245 true,
246 vec![&format!(
247 "{}<{}>",
248 fq_grpc("ClientUnaryReceiver"),
249 method.output_type
250 )],
251 "unary_call",
252 name,
253 )
254 .generate(buf);
255 ClientMethod::new(
256 &method.name,
257 false,
258 Some(&method.input_type),
259 true,
260 vec![&format!(
261 "{}<{}>",
262 fq_grpc("ClientUnaryReceiver"),
263 method.output_type
264 )],
265 "unary_call",
266 name,
267 )
268 .generate(buf);
269 }
270 MethodType::ClientStreaming => {
271 ClientMethod::new(
272 &method.name,
273 true,
274 None,
275 false,
276 vec![
277 &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
278 &format!(
279 "{}<{}>",
280 fq_grpc("ClientCStreamReceiver"),
281 method.output_type
282 ),
283 ],
284 "client_streaming",
285 name,
286 )
287 .generate(buf);
288 ClientMethod::new(
289 &method.name,
290 false,
291 None,
292 false,
293 vec![
294 &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
295 &format!(
296 "{}<{}>",
297 fq_grpc("ClientCStreamReceiver"),
298 method.output_type
299 ),
300 ],
301 "client_streaming",
302 name,
303 )
304 .generate(buf);
305 }
306 MethodType::ServerStreaming => {
307 ClientMethod::new(
308 &method.name,
309 true,
310 Some(&method.input_type),
311 false,
312 vec![&format!(
313 "{}<{}>",
314 fq_grpc("ClientSStreamReceiver"),
315 method.output_type
316 )],
317 "server_streaming",
318 name,
319 )
320 .generate(buf);
321 ClientMethod::new(
322 &method.name,
323 false,
324 Some(&method.input_type),
325 false,
326 vec![&format!(
327 "{}<{}>",
328 fq_grpc("ClientSStreamReceiver"),
329 method.output_type
330 )],
331 "server_streaming",
332 name,
333 )
334 .generate(buf);
335 }
336 MethodType::Duplex => {
337 ClientMethod::new(
338 &method.name,
339 true,
340 None,
341 false,
342 vec![
343 &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
344 &format!(
345 "{}<{}>",
346 fq_grpc("ClientDuplexReceiver"),
347 method.output_type
348 ),
349 ],
350 "duplex_streaming",
351 name,
352 )
353 .generate(buf);
354 ClientMethod::new(
355 &method.name,
356 false,
357 None,
358 false,
359 vec![
360 &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
361 &format!(
362 "{}<{}>",
363 fq_grpc("ClientDuplexReceiver"),
364 method.output_type
365 ),
366 ],
367 "duplex_streaming",
368 name,
369 )
370 .generate(buf);
371 }
372 }
373 }
374
375 #[derive(new)]
376 struct ClientMethod<'a> {
377 method_name: &'a str,
378 opt: bool,
379 request: Option<&'a str>,
380 r#async: bool,
381 result_types: Vec<&'a str>,
382 inner_method_name: &'a str,
383 data_name: &'a str,
384 }
385
386 impl<'a> ClientMethod<'a> {
generate(&self, buf: &mut String)387 fn generate(&self, buf: &mut String) {
388 buf.push_str("pub fn ");
389
390 buf.push_str(self.method_name);
391 if self.r#async {
392 buf.push_str("_async");
393 }
394 if self.opt {
395 buf.push_str("_opt");
396 }
397
398 buf.push_str("(&self");
399 if let Some(req) = self.request {
400 buf.push_str(", req: &");
401 buf.push_str(req);
402 }
403 if self.opt {
404 buf.push_str(", opt: ");
405 buf.push_str(&fq_grpc("CallOption"));
406 }
407 buf.push_str(") -> ");
408
409 buf.push_str(&fq_grpc("Result"));
410 buf.push('<');
411 if self.result_types.len() != 1 {
412 buf.push('(');
413 }
414 for rt in &self.result_types {
415 buf.push_str(rt);
416 buf.push(',');
417 }
418 if self.result_types.len() != 1 {
419 buf.push(')');
420 }
421 buf.push_str("> { ");
422 if self.opt {
423 self.generate_inner_body(buf);
424 } else {
425 self.generate_opt_body(buf);
426 }
427 buf.push_str(" }\n");
428 }
429
430 // Method delegates to the `_opt` version of the method.
generate_opt_body(&self, buf: &mut String)431 fn generate_opt_body(&self, buf: &mut String) {
432 buf.push_str("self.");
433 buf.push_str(self.method_name);
434 if self.r#async {
435 buf.push_str("_async");
436 }
437 buf.push_str("_opt(");
438 if self.request.is_some() {
439 buf.push_str("req, ");
440 }
441 buf.push_str(&fq_grpc("CallOption::default()"));
442 buf.push(')');
443 }
444
445 // Method delegates to the inner client.
generate_inner_body(&self, buf: &mut String)446 fn generate_inner_body(&self, buf: &mut String) {
447 buf.push_str("self.client.");
448 buf.push_str(self.inner_method_name);
449 if self.r#async {
450 buf.push_str("_async");
451 }
452 buf.push_str("(&");
453 buf.push_str(self.data_name);
454 if self.request.is_some() {
455 buf.push_str(", req");
456 }
457 buf.push_str(", opt)");
458 }
459 }
460
/// Emits a `spawn` helper that runs a `'static` future on the client's executor.
fn generate_spawn(buf: &mut String) {
    const SPAWN_FN: &str = concat!(
        "pub fn spawn<F>(&self, f: F) ",
        "where F: ::std::future::Future<Output = ()> + Send + 'static {",
        "self.client.spawn(f)",
        "}\n",
    );
    buf.push_str(SPAWN_FN);
}
469
generate_server(service: &Service, buf: &mut String)470 fn generate_server(service: &Service, buf: &mut String) {
471 buf.push_str("pub trait ");
472 buf.push_str(&service.name);
473 buf.push_str(" {\n");
474 generate_server_methods(service, buf);
475 buf.push_str("}\n");
476
477 buf.push_str("pub fn create_");
478 buf.push_str(&to_snake_case(&service.name));
479 buf.push_str("<S: ");
480 buf.push_str(&service.name);
481 buf.push_str(" + Send + Clone + 'static>(s: S) -> ");
482 buf.push_str(&fq_grpc("Service"));
483 buf.push_str(" {\n");
484 buf.push_str("let mut builder = ::grpcio::ServiceBuilder::new();\n");
485
486 for method in &service.methods[0..service.methods.len() - 1] {
487 buf.push_str("let mut instance = s.clone();\n");
488 generate_method_bind(&service.name, method, buf);
489 }
490
491 buf.push_str("let mut instance = s;\n");
492 generate_method_bind(
493 &service.name,
494 &service.methods[service.methods.len() - 1],
495 buf,
496 );
497
498 buf.push_str("builder.build()\n");
499 buf.push_str("}\n");
500 }
501
generate_server_methods(service: &Service, buf: &mut String)502 fn generate_server_methods(service: &Service, buf: &mut String) {
503 for method in &service.methods {
504 let method_type = MethodType::from_method(method);
505 let request_arg = match method_type {
506 MethodType::Unary | MethodType::ServerStreaming => {
507 format!("req: {}", method.input_type)
508 }
509 MethodType::ClientStreaming | MethodType::Duplex => format!(
510 "stream: {}<{}>",
511 fq_grpc("RequestStream"),
512 method.input_type
513 ),
514 };
515 let response_type = match method_type {
516 MethodType::Unary => "UnarySink",
517 MethodType::ClientStreaming => "ClientStreamingSink",
518 MethodType::ServerStreaming => "ServerStreamingSink",
519 MethodType::Duplex => "DuplexSink",
520 };
521 generate_server_method(method, &request_arg, response_type, buf);
522 }
523 }
524
generate_server_method( method: &Method, request_arg: &str, response_type: &str, buf: &mut String, )525 fn generate_server_method(
526 method: &Method,
527 request_arg: &str,
528 response_type: &str,
529 buf: &mut String,
530 ) {
531 buf.push_str("fn ");
532 buf.push_str(&method.name);
533 buf.push_str("(&mut self, ctx: ");
534 buf.push_str(&fq_grpc("RpcContext"));
535 buf.push_str(", _");
536 buf.push_str(request_arg);
537 buf.push_str(", sink: ");
538 buf.push_str(&fq_grpc(response_type));
539 buf.push('<');
540 buf.push_str(&method.output_type);
541 buf.push('>');
542 buf.push_str(") { grpcio::unimplemented_call!(ctx, sink) }\n");
543 }
544
generate_method_bind(service_name: &str, method: &Method, buf: &mut String)545 fn generate_method_bind(service_name: &str, method: &Method, buf: &mut String) {
546 let add_name = match MethodType::from_method(method) {
547 MethodType::Unary => "add_unary_handler",
548 MethodType::ClientStreaming => "add_client_streaming_handler",
549 MethodType::ServerStreaming => "add_server_streaming_handler",
550 MethodType::Duplex => "add_duplex_streaming_handler",
551 };
552
553 buf.push_str("builder = builder.");
554 buf.push_str(add_name);
555 buf.push_str("(&");
556 buf.push_str(&const_method_name(service_name, method));
557 buf.push_str(", move |ctx, req, resp| instance.");
558 buf.push_str(&method.name);
559 buf.push_str("(ctx, req, resp));\n");
560 }
561
protoc_gen_grpc_rust_main()562 pub fn protoc_gen_grpc_rust_main() {
563 let mut args = env::args();
564 args.next();
565 let (mut protos, mut includes, mut out_dir): (Vec<_>, Vec<_>, _) = Default::default();
566 for arg in args {
567 if let Some(value) = arg.strip_prefix("--protos=") {
568 protos.extend(value.split(",").map(|s| s.to_string()));
569 } else if let Some(value) = arg.strip_prefix("--includes=") {
570 includes.extend(value.split(",").map(|s| s.to_string()));
571 } else if let Some(value) = arg.strip_prefix("--out-dir=") {
572 out_dir = value.to_string();
573 }
574 }
575 if protos.is_empty() {
576 panic!("should at least specify protos to generate");
577 }
578 compile_protos(&protos, &includes, &out_dir).unwrap();
579 }
580