• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2
3# Copyright 2022 Google LLC
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Custom mmi2grpc gRPC compiler."""
18
19import sys
20
21from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, \
22    CodeGeneratorResponse
23
24
def eprint(*args, **kwargs):
    """Print ``args`` to standard error; other keywords go to ``print``."""
    stream = sys.stderr
    print(*args, **kwargs, file=stream)
27
28
# Decode the serialized CodeGeneratorRequest read from stdin. It stays a
# module-level global because import_type() searches request.proto_file to
# resolve message types.
request = CodeGeneratorRequest()
request.ParseFromString(sys.stdin.buffer.read())
30
31
def has_type(proto_file, type_name):
    """Return True if ``proto_file`` declares a top-level message ``type_name``.

    Only ``proto_file.message_type`` (top-level messages) is searched; nested
    messages are not. The result depends solely on the name comparison, not
    on the truthiness of the matched message object.
    """
    return any(message.name == type_name for message in proto_file.message_type)
34
35
def import_type(imports, type):
    """Import the Python message class for a fully-qualified proto type.

    ``type`` is a type reference as found in a ``MethodDescriptorProto``
    (``input_type`` / ``output_type``), e.g. ``'.pkg.Message'``, or
    ``'.pkg.Outer.Inner'`` for a nested message. The defining file is looked
    up in ``request.proto_file``. The matching ``*_pb2`` import statement is
    added to ``imports``, and the Python expression naming the class is
    returned.

    Args:
        imports: set of import statements, updated in place.
        type: fully-qualified proto type name (leading dot optional). The
            parameter keeps its historical name although it shadows the
            builtin.

    Returns:
        The message class expression, e.g. ``'pkg_dot_file__pb2.Message'``.

    Raises:
        LookupError: if no file in the request declares the type.
    """
    qualified = type[1:] if type.startswith('.') else type

    # Choose the file whose package is the longest prefix of the type name
    # and which declares the next name component as a top-level message.
    # Any remaining components name nested messages (Outer.Inner).
    file, relative_name = None, None
    for candidate in request.proto_file:
        prefix = candidate.package + '.' if candidate.package else ''
        if not qualified.startswith(prefix):
            continue
        remainder = qualified[len(prefix):]
        if not has_type(candidate, remainder.split('.')[0]):
            continue
        if file is None or len(candidate.package) > len(file.package):
            file, relative_name = candidate, remainder
    if file is None:
        raise LookupError(f'no .proto file in the request defines {type!r}')

    # Strip only the trailing extension, not every '.proto' substring.
    stem = file.name[:-len('.proto')] if file.name.endswith('.proto') else file.name
    python_path = stem.replace('/', '.')
    as_name = python_path.replace('.', '_dot_') + '__pb2'
    module_path, _, module_base = python_path.rpartition('.')
    module_name = module_base + '_pb2'
    # A .proto at the root of the include path has no parent package, so it
    # needs a plain `import` rather than `from <package> import ...`.
    if module_path:
        imports.add(f'from {module_path} import {module_name} as {as_name}')
    else:
        imports.add(f'import {module_name} as {as_name}')
    return f'{as_name}.{relative_name}'
48
49
def generate_method(imports, file, service, method):
    """Render the client stub method for a single RPC.

    The generated method calls the matching ``self.channel`` multi-callable
    (``unary_unary``, ``unary_stream``, ``stream_unary`` or
    ``stream_stream``) for the route ``/<package>.<Service>/<Method>``:

    * unary request: the request message is built from the method's keyword
      arguments and ``wait_for_ready`` is forwarded;
    * streaming request, unary response: the caller's ``iterator`` and extra
      keyword arguments are forwarded unchanged;
    * streaming both ways: the multi-callable is passed, together with the
      request class, to ``mmi2grpc._streaming.StreamWrapper``.

    Args:
        imports: set of import statements, updated in place.
        file: the ``FileDescriptorProto`` declaring ``service``.
        service: the ``ServiceDescriptorProto`` owning ``method``.
        method: the ``MethodDescriptorProto`` to render.

    Returns:
        The method source as a list of lines, with the ``def`` at column 0.
    """
    rpc_kind = '{}_{}'.format(
        'stream' if method.client_streaming else 'unary',
        'stream' if method.server_streaming else 'unary',
    )
    # Resolve the request type first, then the response type.
    request_type = import_type(imports, method.input_type)
    response_type = import_type(imports, method.output_type)
    route = f"'/{file.package}.{service.name}/{method.name}'"

    def open_channel_call(pad, head):
        # Lines from `self.channel.<kind>(` up to, excluding, its closing `)`.
        inner = pad + '    '
        return [
            f'{pad}{head}self.channel.{rpc_kind}(',
            f'{inner}{route},',
            f'{inner}request_serializer={request_type}.SerializeToString,',
            f'{inner}response_deserializer={response_type}.FromString',
        ]

    if not method.client_streaming:
        return [
            f'def {method.name}(self, wait_for_ready=None, **kwargs):',
            *open_channel_call('    ', 'return '),
            f'    )({request_type}(**kwargs), wait_for_ready=wait_for_ready)',
        ]
    if not method.server_streaming:
        return [
            f'def {method.name}(self, iterator, **kwargs):',
            *open_channel_call('    ', 'return '),
            '    )(iterator, **kwargs)',
        ]
    return [
        f'def {method.name}(self):',
        '    from mmi2grpc._streaming import StreamWrapper',
        '    return StreamWrapper(',
        *open_channel_call('        ', ''),
        '        ),',
        f'        {request_type})',
    ]
88
89
def generate_service(imports, file, service):
    """Render the client stub class for one gRPC service.

    The generated class stores the channel passed to its constructor and
    exposes one method per RPC, each rendered by ``generate_method`` and
    indented into the class body.

    Args:
        imports: set of import statements, updated in place.
        file: the ``FileDescriptorProto`` declaring ``service``.
        service: the ``ServiceDescriptorProto`` to render.

    Returns:
        The class source as a list of lines, ending with an empty line.
    """
    body = []
    for position, method in enumerate(service.method):
        if position:
            body.append('')  # blank line between consecutive methods
        body.extend(
            '    ' + line
            for line in generate_method(imports, file, service, method)
        )
    return [
        f'class {service.name}:',
        '    def __init__(self, channel):',
        '        self.channel = channel',
        '',
        *(body or ['    ']),
        '',
    ]
103
def generate_servicer_method(method):
    """Render a servicer method stub that reports UNIMPLEMENTED.

    Client-streaming RPCs receive ``request_iterator``; all others receive
    ``request``. The body sets the gRPC status code and details on the
    context, then raises ``NotImplementedError``.

    Args:
        method: the ``MethodDescriptorProto`` to render.

    Returns:
        The method source as a list of lines, with the ``def`` at column 0.
    """
    argument = 'request_iterator' if method.client_streaming else 'request'
    return [
        f'def {method.name}(self, {argument}, context):',
        '    context.set_code(grpc.StatusCode.UNIMPLEMENTED)',
        '    context.set_details("Method not implemented!")',
        '    raise NotImplementedError("Method not implemented!")',
    ]
121
122
def generate_servicer(service):
    """Render the ``<Service>Servicer`` base class for one service.

    Every RPC becomes a stub method rendered by ``generate_servicer_method``.
    A service without RPCs gets a ``pass`` body so the generated class is
    still valid Python.

    Args:
        service: the ``ServiceDescriptorProto`` to render.

    Returns:
        The class source as a list of lines, ending with an empty line.
    """
    methods = '\n\n    '.join(
        '\n    '.join(generate_servicer_method(method))
        for method in service.method
    ) or 'pass'
    return (
        f'class {service.name}Servicer:\n'
        '\n'
        f'    {methods}\n'
    ).split('\n')
136
def generate_rpc_method_handler(imports, method):
    """Render one entry of the generated ``rpc_method_handlers`` dict.

    The entry maps the RPC name to ``grpc.<in>_<out>_rpc_method_handler``,
    bound to ``servicer.<Method>`` with the request deserializer and the
    response serializer of the RPC's message types.

    Args:
        imports: set of import statements, updated in place.
        method: the ``MethodDescriptorProto`` to render.

    Returns:
        A list of lines. The first line is unindented, the following lines
        are indented relative to it, and the list ends with an empty element
        because the snippet is newline-terminated.
    """
    request_mode = 'stream' if method.client_streaming else 'unary'
    response_mode = 'stream' if method.server_streaming else 'unary'
    # Resolve the request type first, then the response type.
    deserializer = import_type(imports, method.input_type) + '.FromString'
    serializer = import_type(imports, method.output_type) + '.SerializeToString'
    return [
        f"'{method.name}': grpc.{request_mode}_{response_mode}_rpc_method_handler(",
        f'        servicer.{method.name},',
        f'        request_deserializer={deserializer},',
        f'        response_serializer={serializer},',
        '    ),',
        '',
    ]
151
def generate_add_servicer_to_server_method(imports, file, service):
    """Render ``add_<Service>Servicer_to_server`` for one service.

    The generated function builds a dict of per-RPC handlers, each rendered
    by ``generate_rpc_method_handler``. It wraps the dict with
    ``grpc.method_handlers_generic_handler`` under ``<package>.<Service>``
    and registers the result on ``server``.

    Args:
        imports: set of import statements, updated in place.
        file: the ``FileDescriptorProto`` declaring ``service``.
        service: the ``ServiceDescriptorProto`` to render.

    Returns:
        The generated source as a list of lines.
    """
    entries = []
    for method in service.method:
        handler = list(generate_rpc_method_handler(imports, method))
        # Drop blank tail lines (the snippet is newline-terminated) so the
        # generated dict contains no whitespace-only lines.
        while handler and not handler[-1].strip():
            handler.pop()
        if not handler:
            continue
        entries.append('        ' + handler[0])
        entries.extend('    ' + line for line in handler[1:])
    return [
        f'def add_{service.name}Servicer_to_server(servicer, server):',
        '    rpc_method_handlers = {',
        *entries,
        '    }',
        '    generic_handler = grpc.method_handlers_generic_handler(',
        f"        '{file.package}.{service.name}', rpc_method_handlers)",
        '    server.add_generic_rpc_handlers((generic_handler,))',
    ]
167
# Render one `<name>_grpc.py` module for every file listed in
# request.file_to_generate, then write the serialized response to stdout.
files = []

for file_name in request.file_to_generate:
    file = next(x for x in request.proto_file if x.name == file_name)

    # Filled while rendering: the client stubs and handler registrations add
    # the *_pb2 imports their message types need.
    imports = {'import grpc'}

    services = '\n'.join(
        line
        for service in file.service
        for line in generate_service(imports, file, service)
    )
    servicers = '\n'.join(
        line
        for service in file.service
        for line in generate_servicer(service)
    )
    add_servicer_methods = '\n'.join(
        line
        for service in file.service
        for line in generate_add_servicer_to_server_method(imports, file, service)
    )

    # Emit imports sorted. Iterating a set of str follows hash order, which
    # varies between runs (hash randomization), so the generated file would
    # otherwise not be reproducible.
    header = '\n'.join(sorted(imports))
    # Replace only the trailing extension, not every '.proto' substring.
    stem = file_name[:-len('.proto')] if file_name.endswith('.proto') else file_name
    files.append(CodeGeneratorResponse.File(
        name=stem + '_grpc.py',
        content='\n\n'.join([header, services, servicers, add_servicer_methods]) + '\n',
    ))

response = CodeGeneratorResponse(file=files)
# This plugin only references message classes by name and never inspects
# fields, so it can advertise proto3 `optional` support. protoc refuses to
# run plugins that lack this flag on files using proto3 optional fields.
# The guard keeps older protobuf releases (without the feature) working.
if hasattr(CodeGeneratorResponse, 'FEATURE_PROTO3_OPTIONAL'):
    response.supported_features = CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL

sys.stdout.buffer.write(response.SerializeToString())
195