#!/usr/bin/env python3

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom mmi2grpc gRPC compiler.

`protoc` plugin: reads a serialized `CodeGeneratorRequest` from stdin and
writes a serialized `CodeGeneratorResponse` to stdout. For each file to
generate it emits:

* `<name>_pb2.pyi`: typed stubs for the file's top-level enums and messages;
* `<name>_grpc.py` / `<name>_grpc_aio.py`: sync and asyncio client classes,
  servicer base classes and `add_<Service>Servicer_to_server` functions;
* `_utils.py` (once per output directory): stream helper classes used by the
  generated clients;
* insertions at the `module_scope` insertion point of `<name>_pb2.py` that
  attach the runtime oneof accessors declared in the stubs.
"""

import os
import sys

from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, CodeGeneratorResponse
from google.protobuf.descriptor import FieldDescriptor
from google.protobuf.descriptor_pb2 import (
    FileDescriptorProto,
    EnumDescriptorProto,
    DescriptorProto,
    ServiceDescriptorProto,
    MethodDescriptorProto,
    FieldDescriptorProto,
)

_REQUEST = CodeGeneratorRequest.FromString(sys.stdin.buffer.read())

_PROTO_SUFFIX = '.proto'

# Scalar proto field types rendered as Python `float`.
_FLOAT_TYPES = frozenset({
    FieldDescriptor.TYPE_FLOAT,
    FieldDescriptor.TYPE_DOUBLE,
})

# Scalar proto field types rendered as Python `int`.
_INT_TYPES = frozenset({
    FieldDescriptor.TYPE_INT64,
    FieldDescriptor.TYPE_UINT64,
    FieldDescriptor.TYPE_INT32,
    FieldDescriptor.TYPE_FIXED64,
    FieldDescriptor.TYPE_FIXED32,
    FieldDescriptor.TYPE_UINT32,
    FieldDescriptor.TYPE_SFIXED32,
    FieldDescriptor.TYPE_SFIXED64,
    FieldDescriptor.TYPE_SINT32,
    FieldDescriptor.TYPE_SINT64,
})


def _proto_stem(proto_name: str) -> str:
    """Return `proto_name` without its trailing `.proto` extension.

    Only the suffix is stripped: a `.proto` substring elsewhere in the path
    (e.g. in a directory name) is left untouched.
    """
    if proto_name.endswith(_PROTO_SUFFIX):
        return proto_name[:-len(_PROTO_SUFFIX)]
    return proto_name


def _full_service_name(file: FileDescriptorProto, service: ServiceDescriptorProto) -> str:
    """Return the fully-qualified gRPC service name: `pkg.Service`, or `Service` without a package."""
    return f'{file.package}.{service.name}' if file.package else service.name


def _join_lines(chunks: Iterable[List[str]]) -> str:
    """Flatten the line lists returned by the `generate_*` helpers into one newline-joined text."""
    return '\n'.join(line for chunk in chunks for line in chunk)


def find_type_in_file(proto_file: FileDescriptorProto, type_name: str) -> Optional[Union[DescriptorProto, EnumDescriptorProto]]:
    """Return the top-level enum or message named `type_name` in `proto_file`, or None.

    Enums are searched before messages; nested types are not searched.
    """
    for enum in proto_file.enum_type:
        if enum.name == type_name:
            return enum
    for message in proto_file.message_type:
        if message.name == type_name:
            return message
    return None


def find_type(package: str, type_name: str) -> Tuple[FileDescriptorProto, Union[DescriptorProto, EnumDescriptorProto]]:
    """Find the top-level type `type_name` in any request file declaring `package`.

    Returns:
        `(file, descriptor)` for the first matching file of the request.

    Raises:
        Exception: if no file of `package` declares `type_name`.
    """
    for file in _REQUEST.proto_file:
        if file.package == package and (desc := find_type_in_file(file, type_name)) is not None:
            return file, desc
    raise Exception(f'Type {package}.{type_name} not found')


def add_import(imports: List[str], import_str: str) -> None:
    """Append `import_str` to `imports` unless it is empty or already present."""
    # Callers pass '' when a type needs no extra import; never emit blank import lines.
    if import_str and import_str not in imports:
        imports.append(import_str)


def import_type(imports: List[str], type: str, local: Optional[FileDescriptorProto]) -> Tuple[str, Union[DescriptorProto, EnumDescriptorProto], str]:
    """Resolve a fully-qualified proto type name to a Python type expression.

    Args:
        imports: Import lines of the module being generated; the `_pb2` module
            declaring the type is added to it when the type is not local.
        type: Fully-qualified proto type name as it appears in descriptors,
            i.e. with a leading dot (`.pkg.Msg`).
        local: The file being generated, if any. Types it declares are
            referenced by bare name, without an import.

    Returns:
        `(python_type, descriptor, default_import)`. `default_import` is, for a
        non-local enum, an import line bringing its first value into scope (the
        value used as default argument); otherwise ''.

    Raises:
        Exception: if the type cannot be found (see `find_type`).
    """
    package, _, type_name = type[1:].rpartition('.')
    file, desc = find_type(package, type_name)
    if file == local:
        return type_name, desc, ''

    python_path = _proto_stem(file.name).replace('/', '.')
    module_path, _, module_base = python_path.rpartition('.')
    module_name = f'{module_base}_pb2'
    if module_path:
        add_import(imports, f'from {module_path} import {module_name}')
        qualified_module = f'{module_path}.{module_name}'
    else:
        # The .proto file sits at the root of the include path: no parent package.
        add_import(imports, f'import {module_name}')
        qualified_module = module_name

    dft_import = ''
    if isinstance(desc, EnumDescriptorProto):
        dft_import = f'from {qualified_module} import {desc.value[0].name}'
    return f'{module_name}.{type_name}', desc, dft_import


def _collect_map_type(imports: List[str], parent: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Optional[str]:
    """Return the `Dict[K, V]` hint when `field` refers to a map entry nested in `parent`, else None.

    Raises:
        Exception: if the field refers to a nested type of `parent` that is not
            a map entry (nested types are otherwise unsupported).
    """
    parts = field.type_name.split(f'.{parent.name}.', 2)
    if len(parts) != 2:
        return None
    nested_name = parts[1]
    for nested_type in parent.nested_type:
        if nested_type.name != nested_name:
            continue
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if not nested_type.options.map_entry or field.label != FieldDescriptor.LABEL_REPEATED:
            raise Exception(f'Nested type {field.type_name} is not supported (only map entries are)')
        key_type, _, _ = collect_type(imports, nested_type, nested_type.field[0], local)
        val_type, _, _ = collect_type(imports, nested_type, nested_type.field[1], local)
        add_import(imports, 'from typing import Dict')
        return f'Dict[{key_type}, {val_type}]'
    return None


def collect_type(imports: List[str], parent: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[str, str, str]:
    """Compute the Python type hint and default value of a message field.

    Args:
        imports: Import lines of the module being generated (updated in place).
        parent: Message declaring `field`; used to detect map fields.
        field: The field to describe.
        local: The file being generated, if any (see `import_type`).

    Returns:
        `(type_hint, default_expr, default_import)`. `default_import` is an import
        line that `default_expr` needs, or ''.

    Raises:
        Exception: for unsupported field types (e.g. groups) or unknown types.
    """
    dft_import = ''
    if field.type == FieldDescriptor.TYPE_BYTES:
        py_type, dft = 'bytes', "b''"
    elif field.type == FieldDescriptor.TYPE_STRING:
        py_type, dft = 'str', "''"
    elif field.type == FieldDescriptor.TYPE_BOOL:
        py_type, dft = 'bool', 'False'
    elif field.type in _FLOAT_TYPES:
        py_type, dft = 'float', '0.0'
    elif field.type in _INT_TYPES:
        py_type, dft = 'int', '0'
    elif field.type in (FieldDescriptor.TYPE_ENUM, FieldDescriptor.TYPE_MESSAGE):
        map_type = _collect_map_type(imports, parent, field, local)
        if map_type is not None:
            # Map fields are repeated on the wire but exposed as dicts.
            return map_type, '{}', ''
        py_type, desc, enum_dft_import = import_type(imports, field.type_name, local)
        if isinstance(desc, EnumDescriptorProto):
            dft_import = enum_dft_import
            dft = desc.value[0].name
        else:
            dft = f'{py_type}()'
    else:
        raise Exception(f'TODO: {field}')

    if field.label == FieldDescriptor.LABEL_REPEATED:
        add_import(imports, 'from typing import List')
        py_type = f'List[{py_type}]'
        dft = '[]'
        # The default no longer references an enum value: nothing to import.
        dft_import = ''

    return py_type, dft, dft_import


def collect_field(imports: List[str], message: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[Optional[int], str, str, str, str]:
    """Describe one field of `message`.

    Returns:
        `(oneof_index, name, type_hint, default_expr, default_import)`, where
        `oneof_index` indexes `message.oneof_decl`, or is None when the field is
        not part of a oneof.
    """
    py_type, dft, dft_import = collect_type(imports, message, field, local)
    # `oneof_index` has explicit presence in descriptor.proto, so HasField tells
    # whether the field belongs to a oneof (index 0 included).
    oneof_index = field.oneof_index if field.HasField('oneof_index') else None
    return oneof_index, field.name, py_type, dft, dft_import


def collect_message(imports: List[str], message: DescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[
    List[Tuple[str, str, str]],
    Dict[str, List[Tuple[str, str]]],
]:
    """Describe all fields of `message`.

    Returns:
        `(fields, oneof)`:
        * `fields`: `(name, type_hint, default_expr)` for every field; plain
          fields first, then oneof members typed `Optional[...]` defaulting to None.
        * `oneof`: oneof name -> `(name, type_hint)` of its members, in
          declaration order (type hints without `Optional`).
    """
    fields: List[Tuple[str, str, str]] = []
    oneof: Dict[str, List[Tuple[str, str]]] = {}

    for field in message.field:
        idx, name, py_type, dft, dft_import = collect_field(imports, message, field, local)
        if idx is None:
            add_import(imports, dft_import)
            fields.append((name, py_type, dft))
        else:
            oneof.setdefault(message.oneof_decl[idx].name, []).append((name, py_type))

    for oneof_fields in oneof.values():
        for name, py_type in oneof_fields:
            add_import(imports, 'from typing import Optional')
            fields.append((name, f'Optional[{py_type}]', 'None'))

    return fields, oneof


def generate_enum(imports: List[str], file: FileDescriptorProto, enum: EnumDescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]:
    """Return the `.pyi` stub lines of a top-level enum.

    Also appends to `res` an insertion of `class <Enum>: ...` at the
    `module_scope` insertion point of the file's `_pb2.py`.
    """
    # NOTE(review): the inserted `class <Enum>: ...` rebinds the module-level
    # name `<Enum>` in `_pb2.py`; if the stock generator bound an enum wrapper
    # to that name earlier in the module, it is shadowed — confirm intended.
    res.append(CodeGeneratorResponse.File(
        name=f'{_proto_stem(file.name)}_pb2.py',
        insertion_point='module_scope',
        content=f'class {enum.name}: ...\n\n',
    ))
    add_import(imports, 'from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper')
    return [
        f'class {enum.name}(int, EnumTypeWrapper):',
        '    pass',
        '',
        *[f'{value.name}: {enum.name}' for value in enum.value],
        '',
    ]


def _generate_oneof_helpers(message_name: str, oneof_name: str, field_names: List[str]) -> str:
    """Return `_pb2.py` code attaching `<oneof>`, `<oneof>_variant` and `<oneof>_asdict` to a message class.

    `<oneof>` (a property) returns the active member's value or None;
    `<oneof>_variant()` returns the active member's name (via `WhichOneof`);
    `<oneof>_asdict()` returns `{name: value}` for the active member, or `{}`.
    """
    prefix = f'_{message_name}_{oneof_name}'
    # Each dispatch line is followed by the body indentation of the next line.
    return_variant = ''.join(
        f"if variant == '{name}': return unwrap(self.{name})\n    "
        for name in field_names)
    return_asdict = ''.join(
        f"if variant == '{name}': return {{'{name}': unwrap(self.{name})}}  # type: ignore\n    "
        for name in field_names)
    return f"""
def {prefix}(self: {message_name}):
    variant = self.{oneof_name}_variant()
    if variant is None: return None
    {return_variant}raise Exception('Field `{oneof_name}` not found.')

def {prefix}_variant(self: {message_name}):
    return self.WhichOneof('{oneof_name}')  # type: ignore

def {prefix}_asdict(self: {message_name}):
    variant = self.{oneof_name}_variant()
    if variant is None: return {{}}
    {return_asdict}raise Exception('Field `{oneof_name}` not found.')

setattr({message_name}, '{oneof_name}', property({prefix}))
setattr({message_name}, '{oneof_name}_variant', {prefix}_variant)
setattr({message_name}, '{oneof_name}_asdict', {prefix}_asdict)
"""


def generate_message(imports: List[str], file: FileDescriptorProto, message: DescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]:
    """Return the `.pyi` stub lines of a top-level message.

    For each oneof, the stub declares the accessors and a `TypedDict` of its
    members, and an insertion into `_pb2.py` implementing the accessors is
    appended to `res`.
    """
    nested_message_lines: List[str] = []
    message_lines: List[str] = [f'class {message.name}(Message):']

    add_import(imports, 'from google.protobuf.message import Message')
    fields, oneof = collect_message(imports, message, file)

    message_lines.extend(f'    {name}: {py_type}' for name, py_type, _ in fields)

    args = ''.join(f', {name}: {py_type} = {dft}' for name, py_type, dft in fields)
    message_lines.extend([
        '',
        f'    def __init__(self{args}) -> None: ...',
        '',
    ])

    for oneof_name, oneof_fields in oneof.items():
        literals = ', '.join(f"Literal['{name}']" for name, _ in oneof_fields)
        # Sorted so the generated stub is byte-identical from run to run (set
        # iteration order of strings depends on hash randomization).
        variant_types = sorted({py_type for _, py_type in oneof_fields})
        if len(variant_types) == 1:
            oneof_type = f'Optional[{variant_types[0]}]'
        else:
            oneof_type = 'Union[' + ', '.join(variant_types + ['None']) + ']'
        dict_name = f'{message.name}_{oneof_name}_dict'

        nested_message_lines.extend([
            f'class {dict_name}(TypedDict, total=False):',
            *[f'    {name}: {py_type}' for name, py_type in oneof_fields],
            '',
        ])

        add_import(imports, 'from typing import Union')
        add_import(imports, 'from typing_extensions import TypedDict')
        add_import(imports, 'from typing_extensions import Literal')
        message_lines.extend([
            '    @property',
            f'    def {oneof_name}(self) -> {oneof_type}: ...',
            '',
            f'    def {oneof_name}_variant(self) -> Union[{literals}, None]: ...',
            '',
            f'    def {oneof_name}_asdict(self) -> {dict_name}: ...',
            '',
        ])

        res.append(CodeGeneratorResponse.File(
            name=f'{_proto_stem(file.name)}_pb2.py',
            insertion_point='module_scope',
            content=_generate_oneof_helpers(message.name, oneof_name, [name for name, _ in oneof_fields]),
        ))

    return message_lines + nested_message_lines


def generate_service_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, method: MethodDescriptorProto, sync: bool = True) -> List[str]:
    """Return the lines of one client method (unindented; the caller indents them).

    Unary-request methods take the request message fields as keyword arguments;
    client-streaming methods take an iterator (stream-unary) or return a
    `StreamStream` wrapping a `Sender` (stream-stream). Every variant forwards
    `timeout` to the channel call.
    """
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    # With `local=None` the types are always module-qualified (`x_pb2.Msg`), as
    # needed by both the annotations and the (de)serializer references.
    input_type, input_msg, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)
    rpc_path = f'/{_full_service_name(file, service)}/{method.name}'

    if output_mode == 'stream':
        if input_mode == 'stream':
            output_type_hint = f'StreamStream[{input_type}, {output_type}]'
            if sync:
                add_import(imports, 'from ._utils import Sender')
                add_import(imports, 'from ._utils import Stream')
                add_import(imports, 'from ._utils import StreamStream')
            else:
                add_import(imports, 'from ._utils import AioSender as Sender')
                add_import(imports, 'from ._utils import AioStream as Stream')
                add_import(imports, 'from ._utils import AioStreamStream as StreamStream')
        else:
            output_type_hint = f'Stream[{output_type}]'
            if sync:
                add_import(imports, 'from ._utils import Stream')
            else:
                add_import(imports, 'from ._utils import AioStream as Stream')
    else:
        output_type_hint = output_type if sync else f'Awaitable[{output_type}]'
        if not sync:
            add_import(imports, 'from typing import Awaitable')

    # Every variant takes `timeout: Optional[float]`.
    add_import(imports, 'from typing import Optional')
    call_head = [
        f"        '{rpc_path}',",
        f'        request_serializer={input_type}.SerializeToString,  # type: ignore',
        f'        response_deserializer={output_type}.FromString  # type: ignore',
    ]

    if input_mode == 'stream' and output_mode == 'stream':
        return [
            f'def {method.name}(self, timeout: Optional[float] = None) -> {output_type_hint}:',
            f'    tx: Sender[{input_type}] = Sender()',
            f'    rx: Stream[{output_type}] = self.channel.{input_mode}_{output_mode}(  # type: ignore',
            *call_head,
            '    )(tx, timeout=timeout)',
            '    return StreamStream(tx, rx)',
        ]

    if input_mode == 'stream':
        iterator_type = 'Iterator' if sync else 'AsyncIterator'
        add_import(imports, f'from typing import {iterator_type}')
        return [
            f'def {method.name}(self, iterator: {iterator_type}[{input_type}], timeout: Optional[float] = None) -> {output_type_hint}:',
            f'    return self.channel.{input_mode}_{output_mode}(  # type: ignore',
            *call_head,
            '    )(iterator, timeout=timeout)',
        ]

    # Unary request: the request message is built from keyword arguments.
    assert isinstance(input_msg, DescriptorProto)  # narrowing: RPC inputs are messages
    input_fields, _ = collect_message(imports, input_msg, None)
    args = ''.join(f', {name}: {py_type} = {dft}' for name, py_type, dft in input_fields)
    args_name = ', '.join(f'{name}={name}' for name, _, _ in input_fields)
    return [
        f'def {method.name}(self{args}, wait_for_ready: Optional[bool] = None, timeout: Optional[float] = None) -> {output_type_hint}:',
        f'    return self.channel.{input_mode}_{output_mode}(  # type: ignore',
        *call_head,
        f'    )({input_type}({args_name}), wait_for_ready=wait_for_ready, timeout=timeout)  # type: ignore',
    ]


def generate_service(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    """Return the lines of the client class for `service` (sync or asyncio channel)."""
    methods = '\n\n    '.join(
        '\n    '.join(generate_service_method(imports, file, service, method, sync))
        for method in service.method
    )
    channel_type = 'grpc.Channel' if sync else 'grpc.aio.Channel'
    return (
        f'class {service.name}:\n'
        f'    channel: {channel_type}\n'
        '\n'
        f'    def __init__(self, channel: {channel_type}) -> None:\n'
        '        self.channel = channel\n'
        '\n'
        f'    {methods}\n'
    ).split('\n')


def generate_servicer_method(imports: List[str], method: MethodDescriptorProto, sync: bool = True) -> List[str]:
    """Return the lines of one default servicer method, which reports UNIMPLEMENTED."""
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type, _, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)

    output_type_hint = output_type
    if output_mode == 'stream':
        if sync:
            output_type_hint = f'Generator[{output_type}, None, None]'
            add_import(imports, 'from typing import Generator')
        else:
            output_type_hint = f'AsyncGenerator[{output_type}, None]'
            add_import(imports, 'from typing import AsyncGenerator')

    if input_mode == 'stream':
        iterator_type = 'Iterator' if sync else 'AsyncIterator'
        add_import(imports, f'from typing import {iterator_type}')
        request_type = f'{iterator_type}[{input_type}]'
    else:
        request_type = input_type

    async_kw = '' if sync else 'async '
    lines = [
        f'{async_kw}def {method.name}(self, request: {request_type}, context: grpc.ServicerContext) -> {output_type_hint}:',
        '    context.set_code(grpc.StatusCode.UNIMPLEMENTED)  # type: ignore',
        '    context.set_details("Method not implemented!")  # type: ignore',
        '    raise NotImplementedError("Method not implemented!")',
    ]
    if output_mode == 'stream':
        # Unreachable `yield`: makes the stub a (async) generator, matching its annotation.
        lines.append(f'    yield {output_type}()  # no-op: to make the linter happy')
    return lines


def generate_servicer(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    """Return the lines of the `<Service>Servicer` base class."""
    methods = '\n\n    '.join(
        '\n    '.join(generate_servicer_method(imports, method, sync))
        for method in service.method
    )
    if not methods:
        methods = 'pass'
    return (
        f'class {service.name}Servicer:\n'
        f'    {methods}\n'
    ).split('\n')


def generate_rpc_method_handler(imports: List[str], method: MethodDescriptorProto) -> List[str]:
    """Return the lines of one `rpc_method_handlers` dict entry (last element is '')."""
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type, _, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)

    return (
        f"'{method.name}': grpc.{input_mode}_{output_mode}_rpc_method_handler(  # type: ignore\n"
        f'    servicer.{method.name},\n'
        f'    request_deserializer={input_type}.FromString,  # type: ignore\n'
        f'    response_serializer={output_type}.SerializeToString,  # type: ignore\n'
        '),\n'
    ).split('\n')


def generate_add_servicer_to_server_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    """Return the lines of `add_<Service>Servicer_to_server`, registering a generic handler."""
    handler_lines = [
        line
        for method in service.method
        for line in generate_rpc_method_handler(imports, method)
        if line  # drop the trailing '' of each handler
    ]
    method_handlers = '\n        '.join(handler_lines)
    server_type = 'grpc.Server' if sync else 'grpc.aio.Server'
    return (
        f'def add_{service.name}Servicer_to_server(servicer: {service.name}Servicer, server: {server_type}) -> None:\n'
        '    rpc_method_handlers = {\n'
        f'        {method_handlers}\n'
        '    }\n'
        '    generic_handler = grpc.method_handlers_generic_handler(  # type: ignore\n'
        f"        '{_full_service_name(file, service)}', rpc_method_handlers)\n"
        '    server.add_generic_rpc_handlers((generic_handler,))  # type: ignore\n'
    ).split('\n')


_HEADER = '''# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generated python gRPC interfaces."""
'''

_UTILS_PY = f'''{_HEADER}

import asyncio
import queue
import grpc
import sys

from typing import Any, AsyncIterable, AsyncIterator, Generic, Iterator, TypeVar


_T_co = TypeVar('_T_co', covariant=True)
_T = TypeVar('_T')


class Stream(Iterator[_T_co], grpc.RpcContext): ...


class AioStream(AsyncIterable[_T_co], grpc.RpcContext): ...


class Sender(Iterator[_T]):
    # `queue.Queue` is subscriptable at runtime only from Python 3.9 on.
    if sys.version_info >= (3, 9):
        _inner: queue.Queue[_T]
    else:
        _inner: queue.Queue

    def __init__(self) -> None:
        self._inner = queue.Queue()

    def __iter__(self) -> Iterator[_T]:
        return self

    def __next__(self) -> _T:
        return self._inner.get()

    def send(self, item: _T) -> None:
        self._inner.put(item)


class AioSender(AsyncIterator[_T]):
    # `asyncio.Queue` is subscriptable at runtime only from Python 3.9 on.
    if sys.version_info >= (3, 9):
        _inner: asyncio.Queue[_T]
    else:
        _inner: asyncio.Queue

    def __init__(self) -> None:
        self._inner = asyncio.Queue()

    def __aiter__(self) -> AsyncIterator[_T]:
        return self

    async def __anext__(self) -> _T:
        return await self._inner.get()

    async def send(self, item: _T) -> None:
        await self._inner.put(item)

    def send_nowait(self, item: _T) -> None:
        self._inner.put_nowait(item)


class StreamStream(Generic[_T, _T_co], Iterator[_T_co], grpc.RpcContext):
    _sender: Sender[_T]
    _receiver: Stream[_T_co]

    def __init__(self, sender: Sender[_T], receiver: Stream[_T_co]) -> None:
        self._sender = sender
        self._receiver = receiver

    def send(self, item: _T) -> None:
        self._sender.send(item)

    def __iter__(self) -> Iterator[_T_co]:
        return self._receiver.__iter__()

    def __next__(self) -> _T_co:
        return self._receiver.__next__()

    def is_active(self) -> bool:
        return self._receiver.is_active()  # type: ignore

    def time_remaining(self) -> float:
        return self._receiver.time_remaining()  # type: ignore

    def cancel(self) -> None:
        self._receiver.cancel()  # type: ignore

    def add_callback(self, callback: Any) -> None:
        self._receiver.add_callback(callback)  # type: ignore


class AioStreamStream(Generic[_T, _T_co], AsyncIterator[_T_co], grpc.RpcContext):
    _sender: AioSender[_T]
    _receiver: AioStream[_T_co]

    def __init__(self, sender: AioSender[_T], receiver: AioStream[_T_co]) -> None:
        self._sender = sender
        self._receiver = receiver

    def __aiter__(self) -> AsyncIterator[_T_co]:
        return self._receiver.__aiter__()

    async def __anext__(self) -> _T_co:
        return await self._receiver.__aiter__().__anext__()

    async def send(self, item: _T) -> None:
        await self._sender.send(item)

    def send_nowait(self, item: _T) -> None:
        self._sender.send_nowait(item)

    def is_active(self) -> bool:
        return self._receiver.is_active()  # type: ignore

    def time_remaining(self) -> float:
        return self._receiver.time_remaining()  # type: ignore

    def cancel(self) -> None:
        self._receiver.cancel()  # type: ignore

    def add_callback(self, callback: Any) -> None:
        self._receiver.add_callback(callback)  # type: ignore
'''

# Runtime helper used by the oneof accessors inserted into `_pb2.py`. It checks
# for None rather than truthiness: a oneof member set to a falsy value (0, '',
# b'', False) is still the active variant and must be returned.
_UNWRAP_PY = 'def unwrap(x):\n    assert x is not None\n    return x\n'


_FILES: List[CodeGeneratorResponse.File] = []
_UTILS_FILES: Set[str] = set()


for file_name in _REQUEST.file_to_generate:
    file: FileDescriptorProto = next(f for f in _REQUEST.proto_file if f.name == file_name)
    stem = _proto_stem(file_name)

    _FILES.append(CodeGeneratorResponse.File(
        name=f'{stem}_pb2.py',
        insertion_point='module_scope',
        content=_UNWRAP_PY,
    ))

    pyi_imports: List[str] = []
    grpc_imports: List[str] = ['import grpc']
    grpc_aio_imports: List[str] = ['import grpc', 'import grpc.aio']

    enums = _join_lines(generate_enum(pyi_imports, file, enum, _FILES) for enum in file.enum_type)
    messages = _join_lines(generate_message(pyi_imports, file, message, _FILES) for message in file.message_type)

    services = _join_lines(generate_service(grpc_imports, file, service) for service in file.service)
    aio_services = _join_lines(generate_service(grpc_aio_imports, file, service, False) for service in file.service)

    servicers = _join_lines(generate_servicer(grpc_imports, file, service) for service in file.service)
    aio_servicers = _join_lines(generate_servicer(grpc_aio_imports, file, service, False) for service in file.service)

    add_servicer_methods = _join_lines(
        generate_add_servicer_to_server_method(grpc_imports, file, service) for service in file.service)
    aio_add_servicer_methods = _join_lines(
        generate_add_servicer_to_server_method(grpc_aio_imports, file, service, False) for service in file.service)

    pyi_imports_str: str = '\n'.join(sorted(pyi_imports))
    grpc_imports_str: str = '\n'.join(sorted(grpc_imports))
    grpc_aio_imports_str: str = '\n'.join(sorted(grpc_aio_imports))

    # One `_utils.py` per output directory, shared by all files generated in it.
    utils_filename = file_name.replace(os.path.basename(file_name), '_utils.py')
    if utils_filename not in _UTILS_FILES:
        _UTILS_FILES.add(utils_filename)
        _FILES.append(CodeGeneratorResponse.File(
            name=utils_filename,
            content=_UTILS_PY,
        ))

    _FILES.extend([
        CodeGeneratorResponse.File(
            name=f'{stem}_pb2.pyi',
            content=f'{_HEADER}\n\n{pyi_imports_str}\n\n{enums}\n\n{messages}\n',
        ),
        CodeGeneratorResponse.File(
            name=f'{stem}_grpc.py',
            content=f'{_HEADER}\n\n{grpc_imports_str}\n\n{services}\n\n{servicers}\n\n{add_servicer_methods}',
        ),
        CodeGeneratorResponse.File(
            name=f'{stem}_grpc_aio.py',
            content=f'{_HEADER}\n\n{grpc_aio_imports_str}\n\n{aio_services}\n\n{aio_servicers}\n\n{aio_add_servicer_methods}',
        ),
    ])


sys.stdout.buffer.write(CodeGeneratorResponse(file=_FILES).SerializeToString())