#!/usr/bin/env python3

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom mmi2grpc gRPC compiler.

A ``protoc`` plugin: reads a serialized ``CodeGeneratorRequest`` from stdin
and writes a serialized ``CodeGeneratorResponse`` to stdout. For every file
listed in ``file_to_generate`` it emits a ``<name>_grpc.py`` module holding,
per service, a client stub class, a ``<Service>Servicer`` base class whose
methods report UNIMPLEMENTED, and an ``add_<Service>Servicer_to_server``
registration function.
"""

import sys

from google.protobuf.compiler.plugin_pb2 import (
    CodeGeneratorRequest,
    CodeGeneratorResponse,
)


def eprint(*args, **kwargs):
    """Print to stderr (stdout carries the serialized response)."""
    print(*args, file=sys.stderr, **kwargs)


# protoc hands the plugin its serialized request on stdin.
request = CodeGeneratorRequest.FromString(sys.stdin.buffer.read())


def _strip_proto_ext(path):
    """Return *path* without a trailing ``.proto`` extension, if present."""
    suffix = '.proto'
    return path[:-len(suffix)] if path.endswith(suffix) else path


def _full_service_name(file, service):
    """Return the package-qualified service name used on the wire.

    Files without a ``package`` statement must not get a leading dot.
    """
    if file.package:
        return f'{file.package}.{service.name}'
    return service.name


def has_type(proto_file, type_name):
    """Return True if *proto_file* declares the message *type_name*.

    *type_name* is relative to the file's package. A dotted name such as
    ``Outer.Inner`` is resolved through nested message declarations; an
    undotted name matches only a top-level message.
    """
    messages = proto_file.message_type
    for part in type_name.split('.'):
        found = next((m for m in messages if m.name == part), None)
        if found is None:
            return False
        messages = found.nested_type
    return True


def import_type(imports, type):
    """Record the import needed for a message type and return its Python name.

    Args:
        imports: set of import lines for the generated module; updated in
            place.
        type: fully qualified protobuf type name as it appears in a method
            descriptor, normally with a leading dot (e.g. ``.pkg.Msg`` or
            ``.pkg.Outer.Inner``).

    Returns:
        The expression naming the message class in generated code, e.g.
        ``pkg_dot_file__pb2.Msg``.

    Raises:
        LookupError: if no file in the request declares the type.
    """
    qualified = type[1:] if type.startswith('.') else type
    for proto_file in request.proto_file:
        prefix = f'{proto_file.package}.' if proto_file.package else ''
        if not qualified.startswith(prefix):
            continue
        type_name = qualified[len(prefix):]
        if has_type(proto_file, type_name):
            break
    else:
        raise LookupError(f'cannot find message type {type!r} in request')

    python_path = _strip_proto_ext(proto_file.name).replace('/', '.')
    as_name = python_path.replace('.', '_dot_') + '__pb2'
    # rpartition (not rindex) so protos at the root, with no package
    # directory, do not raise.
    module_path, _, module_name = python_path.rpartition('.')
    module_name += '_pb2'
    if module_path:
        imports.add(f'from {module_path} import {module_name} as {as_name}')
    else:
        imports.add(f'import {module_name} as {as_name}')
    return f'{as_name}.{type_name}'


def generate_service_method(imports, file, service, method):
    """Return the (unindented) source lines of one client stub method."""
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type = import_type(imports, method.input_type)
    output_type = import_type(imports, method.output_type)

    rpc_path = f'/{_full_service_name(file, service)}/{method.name}'
    channel_call = (
        f'    return self.channel.{input_mode}_{output_mode}(\n'
        f"        '{rpc_path}',\n"
        f'        request_serializer={input_type}.SerializeToString,\n'
        f'        response_deserializer={output_type}.FromString\n'
    )

    if method.client_streaming:
        # The caller supplies the request iterator; extra keyword arguments
        # are forwarded to the gRPC call itself.
        return (
            f'def {method.name}(self, iterator, **kwargs):\n'
            + channel_call
            + '    )(iterator, **kwargs)'
        ).split('\n')

    # Unary requests: keyword arguments become fields of the request message.
    return (
        f'def {method.name}(self, wait_for_ready=None, **kwargs):\n'
        + channel_call
        + f'    )({input_type}(**kwargs), wait_for_ready=wait_for_ready)'
    ).split('\n')


def generate_service(imports, file, service):
    """Return the source lines of the client stub class for *service*."""
    methods = '\n\n    '.join(
        '\n    '.join(generate_service_method(imports, file, service, method))
        for method in service.method
    )
    return (
        f'class {service.name}:\n'
        f'    def __init__(self, channel):\n'
        f'        self.channel = channel\n'
        f'\n'
        f'    {methods}\n'
    ).split('\n')


def generate_servicer_method(method):
    """Return the source lines of one servicer method reporting UNIMPLEMENTED."""
    # Client-streaming handlers receive an iterator of requests.
    request_arg = 'request_iterator' if method.client_streaming else 'request'
    return (
        f'def {method.name}(self, {request_arg}, context):\n'
        f'    context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n'
        f'    context.set_details("Method not implemented!")\n'
        f'    raise NotImplementedError("Method not implemented!")'
    ).split('\n')


def generate_servicer(service):
    """Return the source lines of the ``<Service>Servicer`` base class."""
    methods = '\n\n    '.join(
        '\n    '.join(generate_servicer_method(method))
        for method in service.method
    )
    # A service without RPCs would otherwise produce an empty class body,
    # which is a SyntaxError in the generated module.
    return (
        f'class {service.name}Servicer:\n'
        f'\n'
        f'    {methods or "pass"}\n'
    ).split('\n')


def generate_rpc_method_handler(imports, method):
    """Return the source lines of one ``rpc_method_handlers`` dict entry."""
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type = import_type(imports, method.input_type)
    output_type = import_type(imports, method.output_type)

    return (
        f"'{method.name}': grpc.{input_mode}_{output_mode}_rpc_method_handler(\n"
        f'    servicer.{method.name},\n'
        f'    request_deserializer={input_type}.FromString,\n'
        f'    response_serializer={output_type}.SerializeToString,\n'
        f'),\n'
    ).split('\n')


def generate_add_servicer_to_server_method(imports, file, service):
    """Return the source lines of ``add_<Service>Servicer_to_server``."""
    # Each entry is indented to sit inside the dict literal; rstrip drops the
    # trailing blank line produced by the entry's final newline.
    entries = [
        '\n        '.join(generate_rpc_method_handler(imports, method)).rstrip()
        for method in service.method
    ]
    method_handlers = '\n        '.join(entries)
    return (
        f'def add_{service.name}Servicer_to_server(servicer, server):\n'
        f'    rpc_method_handlers = {{\n'
        f'        {method_handlers}\n'
        f'    }}\n'
        f'    generic_handler = grpc.method_handlers_generic_handler(\n'
        f"        '{_full_service_name(file, service)}', rpc_method_handlers)\n"
        f'    server.add_generic_rpc_handlers((generic_handler,))'
    ).split('\n')


files = []

for file_name in request.file_to_generate:
    file = next(f for f in request.proto_file if f.name == file_name)

    # Message-module imports discovered while generating; 'import grpc' is
    # always emitted first.
    imports = set()

    services = '\n'.join(
        line
        for service in file.service
        for line in generate_service(imports, file, service)
    )

    servicers = '\n'.join(
        line
        for service in file.service
        for line in generate_servicer(service)
    )

    add_servicer_methods = '\n'.join(
        line
        for service in file.service
        for line in generate_add_servicer_to_server_method(
            imports, file, service)
    )

    # Sorted so the output is reproducible: iteration order of a set of str
    # depends on the per-process hash seed.
    header = '\n'.join(['import grpc', *sorted(imports)])

    files.append(CodeGeneratorResponse.File(
        name=_strip_proto_ext(file_name) + '_grpc.py',
        content='\n\n'.join(
            [header, services, servicers, add_servicer_methods]) + '\n',
    ))

response = CodeGeneratorResponse(file=files)

sys.stdout.buffer.write(response.SerializeToString())