1 // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
2
3 use std::io::{Error, ErrorKind, Read};
4 use std::path::Path;
5 use std::{env, fs, io, process::Command, str};
6
7 use derive_new::new;
8 use prost::Message;
9 use prost_build::{protoc, protoc_include, Config, Method, Service, ServiceGenerator};
10 use prost_types::FileDescriptorSet;
11
12 use crate::util::{fq_grpc, to_snake_case, MethodType};
13
14 /// Returns the names of all packages compiled.
compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>> where P: AsRef<Path>,15 pub fn compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>>
16 where
17 P: AsRef<Path>,
18 {
19 let mut prost_config = Config::new();
20 prost_config.service_generator(Box::new(Generator));
21 prost_config.out_dir(out_dir);
22
23 // Create a file descriptor set for the protocol files.
24 let tmp = tempfile::Builder::new().prefix("prost-build").tempdir()?;
25 let descriptor_set = tmp.path().join("prost-descriptor-set");
26
27 let mut cmd = Command::new(protoc());
28 cmd.arg("--include_imports")
29 .arg("--include_source_info")
30 .arg("-o")
31 .arg(&descriptor_set);
32
33 for include in includes {
34 cmd.arg("-I").arg(include.as_ref());
35 }
36
37 // Set the protoc include after the user includes in case the user wants to
38 // override one of the built-in .protos.
39 cmd.arg("-I").arg(protoc_include());
40
41 for proto in protos {
42 cmd.arg(proto.as_ref());
43 }
44
45 let output = cmd.output()?;
46 if !output.status.success() {
47 return Err(Error::new(
48 ErrorKind::Other,
49 format!("protoc failed: {}", String::from_utf8_lossy(&output.stderr)),
50 ));
51 }
52
53 let mut buf = Vec::new();
54 fs::File::open(descriptor_set)?.read_to_end(&mut buf)?;
55 let descriptor_set = FileDescriptorSet::decode(buf.as_slice())?;
56
57 // Get the package names from the descriptor set.
58 let mut packages: Vec<_> = descriptor_set
59 .file
60 .iter()
61 .filter_map(|f| f.package.clone())
62 .collect();
63 packages.sort();
64 packages.dedup();
65
66 // FIXME(https://github.com/danburkert/prost/pull/155)
67 // Unfortunately we have to forget the above work and use `compile_protos` to
68 // actually generate the Rust code.
69 prost_config.compile_protos(protos, includes)?;
70
71 Ok(packages)
72 }
73
74 struct Generator;
75
76 impl ServiceGenerator for Generator {
generate(&mut self, service: Service, buf: &mut String)77 fn generate(&mut self, service: Service, buf: &mut String) {
78 generate_methods(&service, buf);
79 generate_client(&service, buf);
80 generate_server(&service, buf);
81 }
82 }
83
generate_methods(service: &Service, buf: &mut String)84 fn generate_methods(service: &Service, buf: &mut String) {
85 let service_path = if service.package.is_empty() {
86 format!("/{}", service.proto_name)
87 } else {
88 format!("/{}.{}", service.package, service.proto_name)
89 };
90
91 for method in &service.methods {
92 generate_method(&service.name, &service_path, method, buf);
93 }
94 }
95
const_method_name(service_name: &str, method: &Method) -> String96 fn const_method_name(service_name: &str, method: &Method) -> String {
97 format!(
98 "METHOD_{}_{}",
99 to_snake_case(service_name).to_uppercase(),
100 method.name.to_uppercase()
101 )
102 }
103
generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String)104 fn generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String) {
105 let name = const_method_name(service_name, method);
106 let ty = format!(
107 "{}<{}, {}>",
108 fq_grpc("Method"),
109 method.input_type,
110 method.output_type
111 );
112
113 buf.push_str("const ");
114 buf.push_str(&name);
115 buf.push_str(": ");
116 buf.push_str(&ty);
117 buf.push_str(" = ");
118 generate_method_body(service_path, method, buf);
119 }
120
generate_method_body(service_path: &str, method: &Method, buf: &mut String)121 fn generate_method_body(service_path: &str, method: &Method, buf: &mut String) {
122 let ty = fq_grpc(&MethodType::from_method(method).to_string());
123 let pr_mar = format!(
124 "{} {{ ser: {}, de: {} }}",
125 fq_grpc("Marshaller"),
126 fq_grpc("pr_ser"),
127 fq_grpc("pr_de")
128 );
129
130 buf.push_str(&fq_grpc("Method"));
131 buf.push('{');
132 generate_field_init("ty", &ty, buf);
133 generate_field_init(
134 "name",
135 &format!("\"{}/{}\"", service_path, method.proto_name),
136 buf,
137 );
138 generate_field_init("req_mar", &pr_mar, buf);
139 generate_field_init("resp_mar", &pr_mar, buf);
140 buf.push_str("};\n");
141 }
142
143 // TODO share this code with protobuf codegen
144 impl MethodType {
from_method(method: &Method) -> MethodType145 fn from_method(method: &Method) -> MethodType {
146 match (method.client_streaming, method.server_streaming) {
147 (false, false) => MethodType::Unary,
148 (true, false) => MethodType::ClientStreaming,
149 (false, true) => MethodType::ServerStreaming,
150 (true, true) => MethodType::Duplex,
151 }
152 }
153 }
154
/// Appends one struct-literal field initializer, `name: value, `.
///
/// The trailing comma and space are always written.
fn generate_field_init(name: &str, value: &str, buf: &mut String) {
    for piece in &[name, ": ", value, ", "] {
        buf.push_str(piece);
    }
}
161
generate_client(service: &Service, buf: &mut String)162 fn generate_client(service: &Service, buf: &mut String) {
163 let client_name = format!("{}Client", service.name);
164 buf.push_str("#[derive(Clone)]\n");
165 buf.push_str("pub struct ");
166 buf.push_str(&client_name);
167 buf.push_str(" { client: ::grpcio::Client }\n");
168
169 buf.push_str("impl ");
170 buf.push_str(&client_name);
171 buf.push_str(" {\n");
172 generate_ctor(&client_name, buf);
173 generate_client_methods(service, buf);
174 generate_spawn(buf);
175 buf.push_str("}\n")
176 }
177
/// Appends `pub fn new(channel)`, which builds the client around a new `::grpcio::Client`.
fn generate_ctor(client_name: &str, buf: &mut String) {
    let ctor = format!(
        "pub fn new(channel: ::grpcio::Channel) -> Self {{ {} {{ client: ::grpcio::Client::new(channel) }}}}\n",
        client_name
    );
    buf.push_str(&ctor);
}
184
generate_client_methods(service: &Service, buf: &mut String)185 fn generate_client_methods(service: &Service, buf: &mut String) {
186 for method in &service.methods {
187 generate_client_method(&service.name, method, buf);
188 }
189 }
190
generate_client_method(service_name: &str, method: &Method, buf: &mut String)191 fn generate_client_method(service_name: &str, method: &Method, buf: &mut String) {
192 let name = &format!(
193 "METHOD_{}_{}",
194 to_snake_case(service_name).to_uppercase(),
195 method.name.to_uppercase()
196 );
197 match MethodType::from_method(method) {
198 MethodType::Unary => {
199 ClientMethod::new(
200 &method.name,
201 true,
202 Some(&method.input_type),
203 false,
204 vec![&method.output_type],
205 "unary_call",
206 name,
207 )
208 .generate(buf);
209 ClientMethod::new(
210 &method.name,
211 false,
212 Some(&method.input_type),
213 false,
214 vec![&method.output_type],
215 "unary_call",
216 name,
217 )
218 .generate(buf);
219 ClientMethod::new(
220 &method.name,
221 true,
222 Some(&method.input_type),
223 true,
224 vec![&format!(
225 "{}<{}>",
226 fq_grpc("ClientUnaryReceiver"),
227 method.output_type
228 )],
229 "unary_call",
230 name,
231 )
232 .generate(buf);
233 ClientMethod::new(
234 &method.name,
235 false,
236 Some(&method.input_type),
237 true,
238 vec![&format!(
239 "{}<{}>",
240 fq_grpc("ClientUnaryReceiver"),
241 method.output_type
242 )],
243 "unary_call",
244 name,
245 )
246 .generate(buf);
247 }
248 MethodType::ClientStreaming => {
249 ClientMethod::new(
250 &method.name,
251 true,
252 None,
253 false,
254 vec![
255 &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
256 &format!(
257 "{}<{}>",
258 fq_grpc("ClientCStreamReceiver"),
259 method.output_type
260 ),
261 ],
262 "client_streaming",
263 name,
264 )
265 .generate(buf);
266 ClientMethod::new(
267 &method.name,
268 false,
269 None,
270 false,
271 vec![
272 &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
273 &format!(
274 "{}<{}>",
275 fq_grpc("ClientCStreamReceiver"),
276 method.output_type
277 ),
278 ],
279 "client_streaming",
280 name,
281 )
282 .generate(buf);
283 }
284 MethodType::ServerStreaming => {
285 ClientMethod::new(
286 &method.name,
287 true,
288 Some(&method.input_type),
289 false,
290 vec![&format!(
291 "{}<{}>",
292 fq_grpc("ClientSStreamReceiver"),
293 method.output_type
294 )],
295 "server_streaming",
296 name,
297 )
298 .generate(buf);
299 ClientMethod::new(
300 &method.name,
301 false,
302 Some(&method.input_type),
303 false,
304 vec![&format!(
305 "{}<{}>",
306 fq_grpc("ClientSStreamReceiver"),
307 method.output_type
308 )],
309 "server_streaming",
310 name,
311 )
312 .generate(buf);
313 }
314 MethodType::Duplex => {
315 ClientMethod::new(
316 &method.name,
317 true,
318 None,
319 false,
320 vec![
321 &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
322 &format!(
323 "{}<{}>",
324 fq_grpc("ClientDuplexReceiver"),
325 method.output_type
326 ),
327 ],
328 "duplex_streaming",
329 name,
330 )
331 .generate(buf);
332 ClientMethod::new(
333 &method.name,
334 false,
335 None,
336 false,
337 vec![
338 &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
339 &format!(
340 "{}<{}>",
341 fq_grpc("ClientDuplexReceiver"),
342 method.output_type
343 ),
344 ],
345 "duplex_streaming",
346 name,
347 )
348 .generate(buf);
349 }
350 }
351 }
352
/// Description of one generated client wrapper method.
///
/// Built through the derived positional `new` and rendered by `generate`.
#[derive(new)]
struct ClientMethod<'a> {
    /// Base name of the generated method; `_async` / `_opt` suffixes are appended to it.
    method_name: &'a str,
    /// When true, emit the `_opt` variant, which takes an explicit `opt` argument and
    /// calls the inner client. When false, the method forwards to its `_opt` twin.
    opt: bool,
    /// Request type, if the call takes a request message (passed as `req: &<type>`).
    request: Option<&'a str>,
    /// When true, append `_async` to both the generated method and the inner client call.
    r#async: bool,
    /// Types inside `Result<...>`. A single type is emitted as-is; any other count
    /// is wrapped in a tuple.
    result_types: Vec<&'a str>,
    /// Name of the method called on `self.client` (e.g. `unary_call`).
    inner_method_name: &'a str,
    /// Name of the method-descriptor constant passed by reference to the inner client.
    data_name: &'a str,
}
363
364 impl<'a> ClientMethod<'a> {
generate(&self, buf: &mut String)365 fn generate(&self, buf: &mut String) {
366 buf.push_str("pub fn ");
367
368 buf.push_str(self.method_name);
369 if self.r#async {
370 buf.push_str("_async");
371 }
372 if self.opt {
373 buf.push_str("_opt");
374 }
375
376 buf.push_str("(&self");
377 if let Some(req) = self.request {
378 buf.push_str(", req: &");
379 buf.push_str(req);
380 }
381 if self.opt {
382 buf.push_str(", opt: ");
383 buf.push_str(&fq_grpc("CallOption"));
384 }
385 buf.push_str(") -> ");
386
387 buf.push_str(&fq_grpc("Result"));
388 buf.push('<');
389 if self.result_types.len() != 1 {
390 buf.push('(');
391 }
392 for rt in &self.result_types {
393 buf.push_str(rt);
394 buf.push(',');
395 }
396 if self.result_types.len() != 1 {
397 buf.push(')');
398 }
399 buf.push_str("> { ");
400 if self.opt {
401 self.generate_inner_body(buf);
402 } else {
403 self.generate_opt_body(buf);
404 }
405 buf.push_str(" }\n");
406 }
407
408 // Method delegates to the `_opt` version of the method.
generate_opt_body(&self, buf: &mut String)409 fn generate_opt_body(&self, buf: &mut String) {
410 buf.push_str("self.");
411 buf.push_str(self.method_name);
412 if self.r#async {
413 buf.push_str("_async");
414 }
415 buf.push_str("_opt(");
416 if self.request.is_some() {
417 buf.push_str("req, ");
418 }
419 buf.push_str(&fq_grpc("CallOption::default()"));
420 buf.push(')');
421 }
422
423 // Method delegates to the inner client.
generate_inner_body(&self, buf: &mut String)424 fn generate_inner_body(&self, buf: &mut String) {
425 buf.push_str("self.client.");
426 buf.push_str(self.inner_method_name);
427 if self.r#async {
428 buf.push_str("_async");
429 }
430 buf.push_str("(&");
431 buf.push_str(self.data_name);
432 if self.request.is_some() {
433 buf.push_str(", req");
434 }
435 buf.push_str(", opt)");
436 }
437 }
438
/// Appends a `spawn` method that hands a `Send + 'static` unit future to `self.client.spawn`.
fn generate_spawn(buf: &mut String) {
    const SPAWN_FN: &str = concat!(
        "pub fn spawn<F>(&self, f: F) ",
        "where F: ::futures::Future<Output = ()> + Send + 'static {",
        "self.client.spawn(f)",
        "}\n",
    );
    buf.push_str(SPAWN_FN);
}
447
generate_server(service: &Service, buf: &mut String)448 fn generate_server(service: &Service, buf: &mut String) {
449 buf.push_str("pub trait ");
450 buf.push_str(&service.name);
451 buf.push_str(" {\n");
452 generate_server_methods(service, buf);
453 buf.push_str("}\n");
454
455 buf.push_str("pub fn create_");
456 buf.push_str(&to_snake_case(&service.name));
457 buf.push_str("<S: ");
458 buf.push_str(&service.name);
459 buf.push_str(" + Send + Clone + 'static>(s: S) -> ");
460 buf.push_str(&fq_grpc("Service"));
461 buf.push_str(" {\n");
462 buf.push_str("let mut builder = ::grpcio::ServiceBuilder::new();\n");
463
464 for method in &service.methods[0..service.methods.len() - 1] {
465 buf.push_str("let mut instance = s.clone();\n");
466 generate_method_bind(&service.name, method, buf);
467 }
468
469 buf.push_str("let mut instance = s;\n");
470 generate_method_bind(
471 &service.name,
472 &service.methods[service.methods.len() - 1],
473 buf,
474 );
475
476 buf.push_str("builder.build()\n");
477 buf.push_str("}\n");
478 }
479
generate_server_methods(service: &Service, buf: &mut String)480 fn generate_server_methods(service: &Service, buf: &mut String) {
481 for method in &service.methods {
482 let method_type = MethodType::from_method(method);
483 let request_arg = match method_type {
484 MethodType::Unary | MethodType::ServerStreaming => {
485 format!("req: {}", method.input_type)
486 }
487 MethodType::ClientStreaming | MethodType::Duplex => format!(
488 "stream: {}<{}>",
489 fq_grpc("RequestStream"),
490 method.input_type
491 ),
492 };
493 let response_type = match method_type {
494 MethodType::Unary => "UnarySink",
495 MethodType::ClientStreaming => "ClientStreamingSink",
496 MethodType::ServerStreaming => "ServerStreamingSink",
497 MethodType::Duplex => "DuplexSink",
498 };
499 generate_server_method(method, &request_arg, response_type, buf);
500 }
501 }
502
generate_server_method( method: &Method, request_arg: &str, response_type: &str, buf: &mut String, )503 fn generate_server_method(
504 method: &Method,
505 request_arg: &str,
506 response_type: &str,
507 buf: &mut String,
508 ) {
509 buf.push_str("fn ");
510 buf.push_str(&method.name);
511 buf.push_str("(&mut self, ctx: ");
512 buf.push_str(&fq_grpc("RpcContext"));
513 buf.push_str(", _");
514 buf.push_str(request_arg);
515 buf.push_str(", sink: ");
516 buf.push_str(&fq_grpc(response_type));
517 buf.push('<');
518 buf.push_str(&method.output_type);
519 buf.push('>');
520 buf.push_str(") { grpcio::unimplemented_call!(ctx, sink) }\n");
521 }
522
generate_method_bind(service_name: &str, method: &Method, buf: &mut String)523 fn generate_method_bind(service_name: &str, method: &Method, buf: &mut String) {
524 let add_name = match MethodType::from_method(method) {
525 MethodType::Unary => "add_unary_handler",
526 MethodType::ClientStreaming => "add_client_streaming_handler",
527 MethodType::ServerStreaming => "add_server_streaming_handler",
528 MethodType::Duplex => "add_duplex_streaming_handler",
529 };
530
531 buf.push_str("builder = builder.");
532 buf.push_str(add_name);
533 buf.push_str("(&");
534 buf.push_str(&const_method_name(service_name, method));
535 buf.push_str(", move |ctx, req, resp| instance.");
536 buf.push_str(&method.name);
537 buf.push_str("(ctx, req, resp));\n");
538 }
539
protoc_gen_grpc_rust_main()540 pub fn protoc_gen_grpc_rust_main() {
541 let mut args = env::args();
542 args.next();
543 let (mut protos, mut includes, mut out_dir): (Vec<_>, Vec<_>, _) = Default::default();
544 for arg in args {
545 if let Some(value) = arg.strip_prefix("--protos=") {
546 protos.extend(value.split(",").map(|s| s.to_string()));
547 } else if let Some(value) = arg.strip_prefix("--includes=") {
548 includes.extend(value.split(",").map(|s| s.to_string()));
549 } else if let Some(value) = arg.strip_prefix("--out-dir=") {
550 out_dir = value.to_string();
551 }
552 }
553 if protos.is_empty() {
554 panic!("should at least specify protos to generate");
555 }
556 compile_protos(&protos, &includes, &out_dir).unwrap();
557 }
558