• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2
3# Copyright 2022 Google LLC
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Custom mmi2grpc gRPC compiler."""
18
19import sys
20
21from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, \
22    CodeGeneratorResponse
23
24
def eprint(*args, **kwargs):
    """Behave like ``print`` but write to standard error.

    All positional and keyword arguments are forwarded to ``print``; passing
    ``file`` explicitly is an error, as the stream is fixed to stderr.
    """
    target = sys.stderr
    print(*args, file=target, **kwargs)
27
28
# The whole of stdin is parsed as a serialized CodeGeneratorRequest.
request = CodeGeneratorRequest()
request.ParseFromString(sys.stdin.buffer.read())
30
31
def has_type(proto_file, type_name):
    """Return True if ``proto_file`` declares a top-level message ``type_name``.

    Only ``proto_file.message_type`` is searched; nested message types are
    not considered. The result depends solely on name equality, never on the
    truthiness of the matching descriptor object.
    """
    return any(message.name == type_name for message in proto_file.message_type)
34
35
def _defines_message(messages, path):
    """Return True if ``path`` (a list of message names) names a message.

    ``path[0]`` is looked up among ``messages``; remaining components are
    looked up recursively among that message's ``nested_type`` entries.
    """
    head, rest = path[0], path[1:]
    for message in messages:
        if message.name == head:
            return not rest or _defines_message(message.nested_type, rest)
    return False


def import_type(imports, type):
    """Import the generated ``_pb2`` module defining ``type`` and reference it.

    Args:
        imports: Set of import statements for the generated file. The
            statement importing the module that defines ``type`` is added.
        type: Fully-qualified protobuf message name as found in a method
            descriptor, e.g. ``'.pkg.Message'`` or ``'.pkg.Outer.Inner'``.

    Returns:
        A Python expression naming the message class, e.g.
        ``'pkg_dot_file__pb2.Message'`` (nested messages keep their dotted
        path, e.g. ``'pkg_dot_file__pb2.Outer.Inner'``).

    Raises:
        ValueError: If no file in the request defines ``type``.
    """
    qualified = type[1:] if type.startswith('.') else type

    # Find the file whose package prefixes the name and which declares the
    # remaining (possibly nested) message path.
    for proto_file in request.proto_file:
        prefix = f'{proto_file.package}.' if proto_file.package else ''
        if not qualified.startswith(prefix):
            continue
        type_name = qualified[len(prefix):]
        if _defines_message(proto_file.message_type, type_name.split('.')):
            break
    else:
        raise ValueError(f'Cannot find a proto file defining {type!r}')

    # Strip only the trailing '.proto' suffix, then map the file path to a
    # Python module path ('dir/my-file' -> 'dir.my_file'); hyphens are not
    # valid in Python identifiers.
    name = proto_file.name
    if name.endswith('.proto'):
        name = name[:-len('.proto')]
    python_path = name.replace('-', '_').replace('/', '.')

    as_name = python_path.replace('.', '_dot_') + '__pb2'
    module_path, _, base_name = python_path.rpartition('.')
    module_name = base_name + '_pb2'
    if module_path:
        imports.add(f'from {module_path} import {module_name} as {as_name}')
    else:
        # Files at the root of the include path have no parent package.
        imports.add(f'import {module_name} as {as_name}')
    return f'{as_name}.{type_name}'
48
49
def generate_service_method(imports, file, service, method):
    """Render one client-stub method of ``service`` as a list of source lines.

    Client-streaming methods take a request iterator and forward any extra
    keyword arguments to the channel call. Other methods build the request
    message from their keyword arguments and forward only ``wait_for_ready``.

    Args:
        imports: Set of import statements; receives the imports for the
            request and response message types.
        file: File descriptor declaring ``service``.
        service: Service descriptor owning ``method``.
        method: Method descriptor to render.

    Returns:
        The generated method's lines, without class-level indentation.
    """
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type = import_type(imports, method.input_type)
    output_type = import_type(imports, method.output_type)

    # A file without a package declaration has no package prefix in the
    # fully-qualified service name ('/Service/Method', not '/.Service/Method').
    service_path = (f'{file.package}.{service.name}' if file.package
                    else service.name)

    if input_mode == 'stream':
        signature = f'def {method.name}(self, iterator, **kwargs):'
        call_args = '(iterator, **kwargs)'
    else:
        signature = f'def {method.name}(self, wait_for_ready=None, **kwargs):'
        call_args = f'({input_type}(**kwargs), wait_for_ready=wait_for_ready)'

    return [
        signature,
        f'    return self.channel.{input_mode}_{output_mode}(',
        f"        '/{service_path}/{method.name}',",
        f'        request_serializer={input_type}.SerializeToString,',
        f'        response_deserializer={output_type}.FromString',
        f'    ){call_args}',
    ]
75
76
def generate_service(imports, file, service):
    """Render the client stub class for ``service`` as a list of source lines.

    The class stores the channel given to its constructor and gets one method
    per RPC, produced by ``generate_service_method`` and indented one level.
    """
    method_blocks = []
    for method in service.method:
        method_lines = generate_service_method(imports, file, service, method)
        method_blocks.append('\n    '.join(method_lines))
    body = '\n\n    '.join(method_blocks)

    header = [
        f'class {service.name}:',
        '    def __init__(self, channel):',
        '        self.channel = channel',
        '',
    ]
    return '\n'.join(header + [f'    {body}', '']).split('\n')
90
91
def generate_servicer_method(method):
    """Render a placeholder servicer method that reports UNIMPLEMENTED.

    Client-streaming methods receive ``request_iterator``; all others receive
    a single ``request``. The lines are returned without class indentation.
    """
    request_param = 'request_iterator' if method.client_streaming else 'request'
    lines = [
        f'def {method.name}(self, {request_param}, context):',
        '    context.set_code(grpc.StatusCode.UNIMPLEMENTED)',
        '    context.set_details("Method not implemented!")',
        '    raise NotImplementedError("Method not implemented!")',
    ]
    return lines
109
110
def generate_servicer(service):
    """Render the ``<Service>Servicer`` base class as a list of source lines.

    Each RPC becomes a placeholder method produced by
    ``generate_servicer_method``. A service that declares no RPCs gets a
    ``pass`` body so the generated class is still valid Python.
    """
    methods = '\n\n    '.join([
        '\n    '.join(generate_servicer_method(method))
        for method in service.method
    ])
    # A class body needs at least one statement.
    body = methods or 'pass'
    return (
        f'class {service.name}Servicer:\n'
        f'\n'
        f'    {body}\n'
    ).split('\n')
122
123
def generate_rpc_method_handler(imports, method):
    """Render one entry of the ``rpc_method_handlers`` dict as source lines.

    The entry maps the method name to the matching
    ``grpc.<in>_<out>_rpc_method_handler`` built from the servicer's method and
    the message (de)serializers. The final element is an empty string, so
    joining the lines leaves a trailing separator after the entry.
    """
    kind = '_'.join(
        'stream' if streaming else 'unary'
        for streaming in (method.client_streaming, method.server_streaming)
    )
    request_type = import_type(imports, method.input_type)
    response_type = import_type(imports, method.output_type)
    return [
        f"'{method.name}': grpc.{kind}_rpc_method_handler(",
        f'        servicer.{method.name},',
        f'        request_deserializer={request_type}.FromString,',
        f'        response_serializer={response_type}.SerializeToString,',
        '    ),',
        '',
    ]
138
139
def generate_add_servicer_to_server_method(imports, file, service):
    """Render ``add_<Service>Servicer_to_server`` as a list of source lines.

    The generated function builds a dict of per-RPC handlers and registers it
    on the server under the service's fully-qualified name.

    Args:
        imports: Set of import statements; receives the message-type imports.
        file: File descriptor declaring ``service``.
        service: Service descriptor whose RPCs are registered.

    Returns:
        The generated function's lines.
    """
    # Each handler's line list ends with '' so its joined text ends with
    # '\n    '; the extra '    ' separator brings the next entry to the dict's
    # 8-space indentation.
    method_handlers = '    '.join([
        '\n    '.join(generate_rpc_method_handler(imports, method))
        for method in service.method
    ])
    # A file without a package declaration registers the bare service name
    # rather than '.Service'.
    service_path = (f'{file.package}.{service.name}' if file.package
                    else service.name)
    return (
        f'def add_{service.name}Servicer_to_server(servicer, server):\n'
        f'    rpc_method_handlers = {{\n'
        f'        {method_handlers}\n'
        f'    }}\n'
        f'    generic_handler = grpc.method_handlers_generic_handler(\n'
        f"        '{service_path}', rpc_method_handlers)\n"
        f'    server.add_generic_rpc_handlers((generic_handler,))'
    ).split('\n')
155
156
# Each requested .proto file gets a matching '<name>_grpc.py' output holding
# the imports, client stubs, servicer base classes and registration helpers.
files = []

for file_name in request.file_to_generate:
    file = next(f for f in request.proto_file if f.name == file_name)

    imports = {'import grpc'}

    services = '\n'.join(
        line
        for service in file.service
        for line in generate_service(imports, file, service)
    )

    servicers = '\n'.join(
        line
        for service in file.service
        for line in generate_servicer(service)
    )

    add_servicer_methods = '\n'.join(
        line
        for service in file.service
        for line in generate_add_servicer_to_server_method(imports, file, service)
    )

    # Sort the imports: set iteration order depends on string hash
    # randomization, which would make the generated file differ between runs.
    header = '\n'.join(sorted(imports))

    # Replace only the trailing '.proto' suffix, not an occurrence elsewhere
    # in the path.
    if file_name.endswith('.proto'):
        base_name = file_name[:-len('.proto')]
    else:
        base_name = file_name

    files.append(CodeGeneratorResponse.File(
        name=base_name + '_grpc.py',
        content='\n\n'.join([header, services, servicers, add_servicer_methods]) + '\n',
    ))

response = CodeGeneratorResponse(file=files)

sys.stdout.buffer.write(response.SerializeToString())
184