#!/usr/bin/env python3

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom mmi2grpc gRPC compiler."""

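# This script implements the protoc code-generator plugin protocol: protoc
# serializes a CodeGeneratorRequest to the plugin's stdin and expects a
# serialized CodeGeneratorResponse on stdout. A typical invocation might look
# like the following sketch (the plugin name `custom_grpc` and the paths are
# illustrative assumptions, not fixed by this file):
#
#   protoc \
#       --plugin=protoc-gen-custom_grpc=./custom_grpc_plugin.py \
#       --custom_grpc_out=./out \
#       --proto_path=. \
#       example.proto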
import os
import sys

from typing import Dict, List, Optional, Set, Tuple, Union

from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, CodeGeneratorResponse
from google.protobuf.descriptor import (
    FieldDescriptor
)
from google.protobuf.descriptor_pb2 import (
    FileDescriptorProto,
    EnumDescriptorProto,
    DescriptorProto,
    ServiceDescriptorProto,
    MethodDescriptorProto,
    FieldDescriptorProto,
)

_REQUEST = CodeGeneratorRequest.FromString(sys.stdin.buffer.read())

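# Type lookup helpers: find a top-level message or enum descriptor by proto
# package and type name across every file in the CodeGeneratorRequest.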
def find_type_in_file(proto_file: FileDescriptorProto, type_name: str) -> Optional[Union[DescriptorProto, EnumDescriptorProto]]:
    for enum in proto_file.enum_type:
        if enum.name == type_name:
            return enum
    for message in proto_file.message_type:
        if message.name == type_name:
            return message
    return None


def find_type(package: str, type_name: str) -> Tuple[FileDescriptorProto, Union[DescriptorProto, EnumDescriptorProto]]:
    for file in _REQUEST.proto_file:
        if file.package == package and (type := find_type_in_file(file, type_name)):
            return file, type
    raise Exception(f'Type {package}.{type_name} not found')


def add_import(imports: List[str], import_str: str) -> None:
    if import_str not in imports:
        imports.append(import_str)

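# Resolve a fully-qualified proto type name (e.g. '.package.Message') to the
# name used in generated Python code. Types from other files are referenced
# through their `<module>_pb2` module and the matching import is recorded;
# for enums, an extra import of the first enum value is also returned so it
# can serve as a default argument.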
def import_type(imports: List[str], type: str, local: Optional[FileDescriptorProto]) -> Tuple[str, Union[DescriptorProto, EnumDescriptorProto], str]:
    package = type[1:type.rindex('.')]
    type_name = type[type.rindex('.')+1:]
    file, desc = find_type(package, type_name)
    if file == local:
        return f'{type_name}', desc, ''
    python_path = file.name.replace('.proto', '').replace('/', '.')
    module_path = python_path[:python_path.rindex('.')]
    module_name = python_path[python_path.rindex('.')+1:] + '_pb2'
    add_import(imports, f'from {module_path} import {module_name}')
    dft_import = ''
    if isinstance(desc, EnumDescriptorProto):
        dft_import = f'from {module_path}.{module_name} import {desc.value[0].name}'
    return f'{module_name}.{type_name}', desc, dft_import

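# Map one proto field to a Python type annotation and a default-value
# expression for the generated stubs: scalars map to builtins, enums default
# to their first value, messages to an empty instance, map entries to Dict
# and repeated fields to List.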
def collect_type(imports: List[str], parent: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[str, str, str]:
    dft: str
    dft_import: str = ''
    if field.type == FieldDescriptor.TYPE_BYTES:
        type = 'bytes'
        dft = 'b\'\''
    elif field.type == FieldDescriptor.TYPE_STRING:
        type = 'str'
        dft = '\'\''
    elif field.type == FieldDescriptor.TYPE_BOOL:
        type = 'bool'
        dft = 'False'
    elif field.type in [
        FieldDescriptor.TYPE_FLOAT,
        FieldDescriptor.TYPE_DOUBLE
    ]:
        type = 'float'
        dft = '0.0'
    elif field.type in [
        FieldDescriptor.TYPE_INT64,
        FieldDescriptor.TYPE_UINT64,
        FieldDescriptor.TYPE_INT32,
        FieldDescriptor.TYPE_FIXED64,
        FieldDescriptor.TYPE_FIXED32,
        FieldDescriptor.TYPE_UINT32,
        FieldDescriptor.TYPE_SFIXED32,
        FieldDescriptor.TYPE_SFIXED64,
        FieldDescriptor.TYPE_SINT32,
        FieldDescriptor.TYPE_SINT64
    ]:
        type = 'int'
        dft = '0'
    elif field.type in [FieldDescriptor.TYPE_ENUM, FieldDescriptor.TYPE_MESSAGE]:
        parts = field.type_name.split(f".{parent.name}.", 2)
        if len(parts) == 2:
            type = parts[1]
            for nested_type in parent.nested_type:
                if nested_type.name == type:
                    assert nested_type.options.map_entry
                    assert field.label == FieldDescriptor.LABEL_REPEATED
                    key_type, _, _ = collect_type(imports, nested_type, nested_type.field[0], local)
                    val_type, _, _ = collect_type(imports, nested_type, nested_type.field[1], local)
                    add_import(imports, 'from typing import Dict')
                    return f'Dict[{key_type}, {val_type}]', '{}', ''
        type, desc, enum_dft = import_type(imports, field.type_name, local)
        if isinstance(desc, EnumDescriptorProto):
            dft_import = enum_dft
            dft = desc.value[0].name
        else:
            dft = f'{type}()'
    else:
        raise Exception(f'TODO: {field}')

    if field.label == FieldDescriptor.LABEL_REPEATED:
        add_import(imports, 'from typing import List')
        type = f'List[{type}]'
        dft = '[]'

    return type, dft, dft_import


def collect_field(imports: List[str], message: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[Optional[int], str, str, str, str]:
    type, dft, dft_import = collect_type(imports, message, field, local)
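    # 'oneof_index' only shows up in the field's text representation when it is
    # actually set, so this check distinguishes oneof members from plain fields.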
    oneof_index = field.oneof_index if 'oneof_index' in f'{field}' else None
    return oneof_index, field.name, type, dft, dft_import

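# Collect (name, type, default) triples for every field of a message. Oneof
# members are grouped per oneof and also re-exposed as Optional attributes.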
def collect_message(imports: List[str], message: DescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[
    List[Tuple[str, str, str]],
    Dict[str, List[Tuple[str, str]]],
]:
    fields: List[Tuple[str, str, str]] = []
    oneof: Dict[str, List[Tuple[str, str]]] = {}

    for field in message.field:
        idx, name, type, dft, dft_import = collect_field(imports, message, field, local)
        if idx is not None:
            oneof_name = message.oneof_decl[idx].name
            oneof.setdefault(oneof_name, [])
            oneof[oneof_name].append((name, type))
        else:
            add_import(imports, dft_import)
            fields.append((name, type, dft))

    for oneof_name, oneof_fields in oneof.items():
        for name, type in oneof_fields:
            add_import(imports, 'from typing import Optional')
            fields.append((name, f'Optional[{type}]', 'None'))

    return fields, oneof

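# Emit the .pyi stub lines for an enum. A placeholder class is also injected
# into the generated _pb2.py through the 'module_scope' insertion point that
# protoc's standard python plugin leaves in its output.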
def generate_enum(imports: List[str], file: FileDescriptorProto, enum: EnumDescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]:
    res.append(CodeGeneratorResponse.File(
        name=file.name.replace('.proto', '_pb2.py'),
        insertion_point=f'module_scope',
        content=f'class {enum.name}: ...\n\n'
    ))
    add_import(imports, 'from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper')
    return [
        f'class {enum.name}(int, EnumTypeWrapper):',
        f'  pass',
        f'',
        *[f'{value.name}: {enum.name}' for value in enum.value],
        ''
    ]

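# Emit the .pyi stub for a message: typed attributes, an __init__ signature
# with defaults and, for every oneof, a read-only property plus
# <oneof>_variant()/<oneof>_asdict() helpers. The runtime implementations of
# those helpers are patched into the generated _pb2.py via insertion points.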
def generate_message(imports: List[str], file: FileDescriptorProto, message: DescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]:
    nested_message_lines: List[str] = []
    message_lines: List[str] = [f'class {message.name}(Message):']

    add_import(imports, 'from google.protobuf.message import Message')
    fields, oneof = collect_message(imports, message, file)

    for (name, type, _) in fields:
        message_lines.append(f'  {name}: {type}')

    args = ', '.join([f'{name}: {type} = {dft}' for name, type, dft in fields])
    if args: args = ', ' + args
    message_lines.extend([
        f'',
        f'  def __init__(self{args}) -> None: ...',
        f''
    ])

    for oneof_name, oneof_fields in oneof.items():
        literals: str = ', '.join((f'Literal[\'{name}\']' for name, _ in oneof_fields))
        types: Set[str] = set((type for _, type in oneof_fields))
        if len(types) == 1:
            type = 'Optional[' + types.pop() + ']'
        else:
            types.add('None')
            type = 'Union[' + ', '.join(types) + ']'

        nested_message_lines.extend([
            f'class {message.name}_{oneof_name}_dict(TypedDict, total=False):',
            '\n'.join([f'  {name}: {type}' for name, type in oneof_fields]),
            f'',
        ])

        add_import(imports, 'from typing import Union')
        add_import(imports, 'from typing_extensions import TypedDict')
        add_import(imports, 'from typing_extensions import Literal')
        message_lines.extend([
            f'  @property',
            f'  def {oneof_name}(self) -> {type}: ...',
            f'',
            f'  def {oneof_name}_variant(self) -> Union[{literals}, None]: ...',
            f'',
            f'  def {oneof_name}_asdict(self) -> {message.name}_{oneof_name}_dict: ...',
            f'',
        ])

        return_variant = '\n  '.join([f'if variant == \'{name}\': return unwrap(self.{name})' for name, _ in oneof_fields])
        return_asdict = '\n  '.join([f'if variant == \'{name}\': return {{\'{name}\': unwrap(self.{name})}}  # type: ignore' for name, _ in oneof_fields])
        if return_variant: return_variant += '\n  '
        if return_asdict: return_asdict += '\n  '

        res.append(CodeGeneratorResponse.File(
            name=file.name.replace('.proto', '_pb2.py'),
            insertion_point=f'module_scope',
            content=f"""
def _{message.name}_{oneof_name}(self: {message.name}):
  variant = self.{oneof_name}_variant()
  if variant is None: return None
  {return_variant}raise Exception('Field `{oneof_name}` not found.')

def _{message.name}_{oneof_name}_variant(self: {message.name}):
  return self.WhichOneof('{oneof_name}')  # type: ignore

def _{message.name}_{oneof_name}_asdict(self: {message.name}):
  variant = self.{oneof_name}_variant()
  if variant is None: return {{}}
  {return_asdict}raise Exception('Field `{oneof_name}` not found.')

setattr({message.name}, '{oneof_name}', property(_{message.name}_{oneof_name}))
setattr({message.name}, '{oneof_name}_variant', _{message.name}_{oneof_name}_variant)
setattr({message.name}, '{oneof_name}_asdict', _{message.name}_{oneof_name}_asdict)
"""))

    return message_lines + nested_message_lines

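# Generate one typed client method wrapping the underlying grpc channel call.
# Unary requests expand the request message fields into keyword arguments;
# streaming RPCs return the Stream/StreamStream helpers from _utils.py (or
# their asyncio variants when sync is False).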
def generate_service_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, method: MethodDescriptorProto, sync: bool = True) -> List[str]:
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type, input_msg, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)

    input_type_pb2, _, _ = import_type(imports, method.input_type, None)
    output_type_pb2, _, _ = import_type(imports, method.output_type, None)

    if output_mode == 'stream':
        if input_mode == 'stream':
            output_type_hint = f'StreamStream[{input_type}, {output_type}]'
            if sync:
                add_import(imports, f'from ._utils import Sender')
                add_import(imports, f'from ._utils import Stream')
                add_import(imports, f'from ._utils import StreamStream')
            else:
                add_import(imports, f'from ._utils import AioSender as Sender')
                add_import(imports, f'from ._utils import AioStream as Stream')
                add_import(imports, f'from ._utils import AioStreamStream as StreamStream')
        else:
            output_type_hint = f'Stream[{output_type}]'
            if sync:
                add_import(imports, f'from ._utils import Stream')
            else:
                add_import(imports, f'from ._utils import AioStream as Stream')
    else:
        output_type_hint = output_type if sync else f'Awaitable[{output_type}]'
        if not sync: add_import(imports, f'from typing import Awaitable')

    if input_mode == 'stream' and output_mode == 'stream':
        add_import(imports, f'from typing import Optional')
        return (
            f'def {method.name}(self, timeout: Optional[float] = None) -> {output_type_hint}:\n'
            f'    tx: Sender[{input_type}] = Sender()\n'
            f'    rx: Stream[{output_type}] = self.channel.{input_mode}_{output_mode}(  # type: ignore\n'
            f"        '/{file.package}.{service.name}/{method.name}',\n"
            f'        request_serializer={input_type_pb2}.SerializeToString,  # type: ignore\n'
            f'        response_deserializer={output_type_pb2}.FromString  # type: ignore\n'
            f'    )(tx)\n'
            f'    return StreamStream(tx, rx)'
        ).split('\n')
    if input_mode == 'stream':
        iterator_type = 'Iterator' if sync else 'AsyncIterator'
        add_import(imports, f'from typing import {iterator_type}')
        add_import(imports, f'from typing import Optional')
        return (
            f'def {method.name}(self, iterator: {iterator_type}[{input_type}], timeout: Optional[float] = None) -> {output_type_hint}:\n'
            f'    return self.channel.{input_mode}_{output_mode}(  # type: ignore\n'
            f"        '/{file.package}.{service.name}/{method.name}',\n"
            f'        request_serializer={input_type_pb2}.SerializeToString,  # type: ignore\n'
            f'        response_deserializer={output_type_pb2}.FromString  # type: ignore\n'
            f'    )(iterator)'
        ).split('\n')
    else:
        add_import(imports, f'from typing import Optional')
        assert isinstance(input_msg, DescriptorProto)
        input_fields, _ = collect_message(imports, input_msg, None)
        args = ', '.join([f'{name}: {type} = {dft}' for name, type, dft in input_fields])
        args_name = ', '.join([f'{name}={name}' for name, _, _ in input_fields])
        if args: args = ', ' + args
        return (
            f'def {method.name}(self{args}, wait_for_ready: Optional[bool] = None, timeout: Optional[float] = None) -> {output_type_hint}:\n'
            f'    return self.channel.{input_mode}_{output_mode}(  # type: ignore\n'
            f"        '/{file.package}.{service.name}/{method.name}',\n"
            f'        request_serializer={input_type_pb2}.SerializeToString,  # type: ignore\n'
            f'        response_deserializer={output_type_pb2}.FromString  # type: ignore\n'
            f'    )({input_type_pb2}({args_name}), wait_for_ready=wait_for_ready, timeout=timeout)  # type: ignore'
        ).split('\n')

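# Generate the typed client class for a service: it stores the grpc channel
# and exposes one method per RPC.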
def generate_service(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    methods = '\n\n    '.join([
        '\n    '.join(
            generate_service_method(imports, file, service, method, sync)
        ) for method in service.method
    ])
    channel_type = 'grpc.Channel' if sync else 'grpc.aio.Channel'
    return (
        f'class {service.name}:\n'
        f'    channel: {channel_type}\n'
        f'\n'
        f'    def __init__(self, channel: {channel_type}) -> None:\n'
        f'        self.channel = channel\n'
        f'\n'
        f'    {methods}\n'
    ).split('\n')

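# Generate a default (unimplemented) servicer method stub, mirroring what the
# stock grpc python plugin generates: mark the call as UNIMPLEMENTED and raise
# NotImplementedError.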
def generate_servicer_method(imports: List[str], method: MethodDescriptorProto, sync: bool = True) -> List[str]:
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type, _, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)

    output_type_hint = output_type
    if output_mode == 'stream':
        if sync:
            output_type_hint = f'Generator[{output_type}, None, None]'
            add_import(imports, f'from typing import Generator')
        else:
            output_type_hint = f'AsyncGenerator[{output_type}, None]'
            add_import(imports, f'from typing import AsyncGenerator')

    if input_mode == 'stream':
        iterator_type = 'Iterator' if sync else 'AsyncIterator'
        add_import(imports, f'from typing import {iterator_type}')
        lines = (('' if sync else 'async ') + (
            f'def {method.name}(self, request: {iterator_type}[{input_type}], context: grpc.ServicerContext) -> {output_type_hint}:\n'
            f'    context.set_code(grpc.StatusCode.UNIMPLEMENTED)  # type: ignore\n'
            f'    context.set_details("Method not implemented!")  # type: ignore\n'
            f'    raise NotImplementedError("Method not implemented!")'
        )).split('\n')
    else:
        lines = (('' if sync else 'async ') + (
            f'def {method.name}(self, request: {input_type}, context: grpc.ServicerContext) -> {output_type_hint}:\n'
            f'    context.set_code(grpc.StatusCode.UNIMPLEMENTED)  # type: ignore\n'
            f'    context.set_details("Method not implemented!")  # type: ignore\n'
            f'    raise NotImplementedError("Method not implemented!")'
        )).split('\n')
    if output_mode == 'stream':
        lines.append(f'    yield {output_type}()  # no-op: to make the linter happy')
    return lines

def generate_servicer(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    methods = '\n\n    '.join([
        '\n    '.join(
            generate_servicer_method(imports, method, sync)
        ) for method in service.method
    ])
    if not methods:
        methods = 'pass'
    return (
        f'class {service.name}Servicer:\n'
        f'    {methods}\n'
    ).split('\n')

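# Produce one entry of the rpc_method_handlers dict handed to
# grpc.method_handlers_generic_handler.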
def generate_rpc_method_handler(imports: List[str], method: MethodDescriptorProto) -> List[str]:
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type, _, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)

    return (
        f"'{method.name}': grpc.{input_mode}_{output_mode}_rpc_method_handler(  # type: ignore\n"
        f'        servicer.{method.name},\n'
        f'        request_deserializer={input_type}.FromString,  # type: ignore\n'
        f'        response_serializer={output_type}.SerializeToString,  # type: ignore\n'
        f'    ),\n'
    ).split('\n')

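# Generate the add_<Service>Servicer_to_server() helper that registers a
# generic handler for all of the service's methods on a grpc server.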
def generate_add_servicer_to_server_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    method_handlers = '    '.join([
        '\n    '.join(
            generate_rpc_method_handler(imports, method)
        ) for method in service.method
    ])
    server_type = 'grpc.Server' if sync else 'grpc.aio.Server'
    return (
        f'def add_{service.name}Servicer_to_server(servicer: {service.name}Servicer, server: {server_type}) -> None:\n'
        f'    rpc_method_handlers = {{\n'
        f'        {method_handlers}\n'
        f'    }}\n'
        f'    generic_handler = grpc.method_handlers_generic_handler(  # type: ignore\n'
        f"        '{file.package}.{service.name}', rpc_method_handlers)\n"
        f'    server.add_generic_rpc_handlers((generic_handler,))  # type: ignore\n'
    ).split('\n')

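# _HEADER is prepended to every generated file; _UTILS_PY is the shared
# _utils.py module containing the stream helper classes referenced by the
# generated client code.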
_HEADER = '''# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generated python gRPC interfaces."""
'''

_UTILS_PY = f'''{_HEADER}

import asyncio
import queue
import grpc
import sys

from typing import Any, AsyncIterable, AsyncIterator, Generic, Iterator, TypeVar


_T_co = TypeVar('_T_co', covariant=True)
_T = TypeVar('_T')


class Stream(Iterator[_T_co], grpc.RpcContext): ...


class AioStream(AsyncIterable[_T_co], grpc.RpcContext): ...


class Sender(Iterator[_T]):
    if sys.version_info >= (3, 9):
        _inner: queue.Queue[_T]
    else:
        _inner: queue.Queue

    def __init__(self) -> None:
        self._inner = queue.Queue()

    def __iter__(self) -> Iterator[_T]:
        return self

    def __next__(self) -> _T:
        return self._inner.get()

    def send(self, item: _T) -> None:
        self._inner.put(item)


class AioSender(AsyncIterator[_T]):
    if sys.version_info >= (3, 9):
        _inner: asyncio.Queue[_T]
    else:
        _inner: asyncio.Queue

    def __init__(self) -> None:
        self._inner = asyncio.Queue()

    def __aiter__(self) -> AsyncIterator[_T]:
        return self

    async def __anext__(self) -> _T:
        return await self._inner.get()

    async def send(self, item: _T) -> None:
        await self._inner.put(item)

    def send_nowait(self, item: _T) -> None:
        self._inner.put_nowait(item)


class StreamStream(Generic[_T, _T_co], Iterator[_T_co], grpc.RpcContext):
    _sender: Sender[_T]
    _receiver: Stream[_T_co]

    def __init__(self, sender: Sender[_T], receiver: Stream[_T_co]) -> None:
        self._sender = sender
        self._receiver = receiver

    def send(self, item: _T) -> None:
        self._sender.send(item)

    def __iter__(self) -> Iterator[_T_co]:
        return self._receiver.__iter__()

    def __next__(self) -> _T_co:
        return self._receiver.__next__()

    def is_active(self) -> bool:
        return self._receiver.is_active()  # type: ignore

    def time_remaining(self) -> float:
        return self._receiver.time_remaining()  # type: ignore

    def cancel(self) -> None:
        self._receiver.cancel()  # type: ignore

    def add_callback(self, callback: Any) -> None:
        self._receiver.add_callback(callback)  # type: ignore


class AioStreamStream(Generic[_T, _T_co], AsyncIterator[_T_co], grpc.RpcContext):
    _sender: AioSender[_T]
    _receiver: AioStream[_T_co]

    def __init__(self, sender: AioSender[_T], receiver: AioStream[_T_co]) -> None:
        self._sender = sender
        self._receiver = receiver

    def __aiter__(self) -> AsyncIterator[_T_co]:
        return self._receiver.__aiter__()

    async def __anext__(self) -> _T_co:
        return await self._receiver.__aiter__().__anext__()

    async def send(self, item: _T) -> None:
        await self._sender.send(item)

    def send_nowait(self, item: _T) -> None:
        self._sender.send_nowait(item)

    def is_active(self) -> bool:
        return self._receiver.is_active()  # type: ignore

    def time_remaining(self) -> float:
        return self._receiver.time_remaining()  # type: ignore

    def cancel(self) -> None:
        self._receiver.cancel()  # type: ignore

    def add_callback(self, callback: Any) -> None:
        self._receiver.add_callback(callback)  # type: ignore
'''


_FILES: List[CodeGeneratorResponse.File] = []
_UTILS_FILES: Set[str] = set()


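# Main generation loop: for every .proto file protoc asked us to generate,
# emit the typed stub (_pb2.pyi), the sync and asyncio client/servicer modules
# (_grpc.py / _grpc_aio.py), and one shared _utils.py per output directory.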
for file_name in _REQUEST.file_to_generate:
    file: FileDescriptorProto = next(filter(lambda x: x.name == file_name, _REQUEST.proto_file))

    _FILES.append(CodeGeneratorResponse.File(
        name=file.name.replace('.proto', '_pb2.py'),
        insertion_point=f'module_scope',
        content='def unwrap(x):\n  assert x\n  return x\n'
    ))

    pyi_imports: List[str] = []
    grpc_imports: List[str] = ['import grpc']
    grpc_aio_imports: List[str] = ['import grpc', 'import grpc.aio']

    enums = '\n'.join(sum([generate_enum(pyi_imports, file, enum, _FILES) for enum in file.enum_type], []))
    messages = '\n'.join(sum([generate_message(pyi_imports, file, message, _FILES) for message in file.message_type], []))

    services = '\n'.join(sum([generate_service(grpc_imports, file, service) for service in file.service], []))
    aio_services = '\n'.join(sum([generate_service(grpc_aio_imports, file, service, False) for service in file.service], []))

    servicers = '\n'.join(sum([generate_servicer(grpc_imports, file, service) for service in file.service], []))
    aio_servicers = '\n'.join(sum([generate_servicer(grpc_aio_imports, file, service, False) for service in file.service], []))

    add_servicer_methods = '\n'.join(sum([generate_add_servicer_to_server_method(grpc_imports, file, service) for service in file.service], []))
    aio_add_servicer_methods = '\n'.join(sum([generate_add_servicer_to_server_method(grpc_aio_imports, file, service, False) for service in file.service], []))

    pyi_imports.sort()
    grpc_imports.sort()
    grpc_aio_imports.sort()

    pyi_imports_str: str = '\n'.join(pyi_imports)
    grpc_imports_str: str = '\n'.join(grpc_imports)
    grpc_aio_imports_str: str = '\n'.join(grpc_aio_imports)

    utils_filename = file_name.replace(os.path.basename(file_name), '_utils.py')
    if utils_filename not in _UTILS_FILES:
        _UTILS_FILES.add(utils_filename)
        _FILES.extend([
            CodeGeneratorResponse.File(
                name=utils_filename,
                content=_UTILS_PY,
            )
        ])

    _FILES.extend([
        CodeGeneratorResponse.File(
            name=file.name.replace('.proto', '_pb2.pyi'),
            content=f'{_HEADER}\n\n{pyi_imports_str}\n\n{enums}\n\n{messages}\n'
        ),
        CodeGeneratorResponse.File(
            name=file_name.replace('.proto', '_grpc.py'),
            content=f'{_HEADER}\n\n{grpc_imports_str}\n\n{services}\n\n{servicers}\n\n{add_servicer_methods}'
        ),
        CodeGeneratorResponse.File(
            name=file_name.replace('.proto', '_grpc_aio.py'),
            content=f'{_HEADER}\n\n{grpc_aio_imports_str}\n\n{aio_services}\n\n{aio_servicers}\n\n{aio_add_servicer_methods}'
        )
    ])

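# Hand the full CodeGeneratorResponse back to protoc on stdout.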
sys.stdout.buffer.write(CodeGeneratorResponse(file=_FILES).SerializeToString())