#!/usr/bin/env python3

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom mmi2grpc gRPC compiler.

protoc plugin: reads a serialized ``CodeGeneratorRequest`` from stdin and
writes a serialized ``CodeGeneratorResponse`` to stdout containing one
``<name>_grpc.py`` module per file to generate. Each module holds, for every
service of the file, a client stub class, a ``<Service>Servicer`` base class
and an ``add_<Service>Servicer_to_server`` registration function.
"""

import sys

from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, \
    CodeGeneratorResponse


def eprint(*args, **kwargs):
    """Print to stderr; stdout is reserved for the serialized response."""
    print(*args, file=sys.stderr, **kwargs)


# protoc sends the whole request as a single serialized message on stdin.
request = CodeGeneratorRequest.FromString(sys.stdin.buffer.read())

_PROTO_SUFFIX = '.proto'


def _strip_proto_suffix(path):
    """Return `path` without its trailing '.proto' extension, if present.

    Unlike ``str.replace``, '.proto' appearing elsewhere in the path (for
    example in a 'my.protos/' directory) is left untouched.
    """
    if path.endswith(_PROTO_SUFFIX):
        return path[:-len(_PROTO_SUFFIX)]
    return path


def _service_full_name(file, service):
    """Return the fully-qualified gRPC name of `service` declared in `file`.

    This is 'package.Service', or just 'Service' when the file has no
    package (an empty `package` must not produce a leading '.').
    """
    if file.package:
        return f'{file.package}.{service.name}'
    return service.name


def has_type(proto_file, type_name):
    """Return True if `proto_file` declares a top-level message `type_name`."""
    return any(message.name == type_name for message in proto_file.message_type)


def _defines_message(messages, dotted_name):
    """Return True if `dotted_name` names a message reachable from `messages`.

    `messages` is a repeated message-descriptor field (a file's
    `message_type` or a message's `nested_type`). The first component of
    `dotted_name` is matched against `messages`, and each following
    component against the previous match's `nested_type`.
    """
    head, _, rest = dotted_name.partition('.')
    for message in messages:
        if message.name == head:
            return not rest or _defines_message(message.nested_type, rest)
    return False


def _locate_message(full_name):
    """Find the request file defining message `full_name` (no leading '.').

    Returns:
        (proto_file, relative_name): the defining file descriptor and the
        message name relative to that file's package (e.g. 'Outer.Inner').

    Raises:
        LookupError: if no file in the request defines the message.
    """
    for proto_file in request.proto_file:
        if proto_file.package:
            prefix = proto_file.package + '.'
            if not full_name.startswith(prefix):
                continue
            relative_name = full_name[len(prefix):]
        else:
            relative_name = full_name
        if _defines_message(proto_file.message_type, relative_name):
            return proto_file, relative_name
    raise LookupError(f'message type .{full_name} not found in request')


def import_type(imports, type):
    """Register the import for message `type` and return its Python reference.

    Args:
        imports: set of import statements for the generated module; the
            statement importing the defining file's ``*_pb2`` module is added.
        type: fully-qualified protobuf type name as used in method
            descriptors, e.g. '.pkg.Msg' or '.pkg.Outer.Inner'.

    Returns:
        The expression naming the message class in generated code, e.g.
        'pkg_dot_file__pb2.Msg'.

    Raises:
        LookupError: if no file in the request defines `type`.
    """
    proto_file, relative_name = _locate_message(type.lstrip('.'))

    python_path = _strip_proto_suffix(proto_file.name).replace('/', '.')
    as_name = python_path.replace('.', '_dot_') + '__pb2'
    module_path, _, module_name = python_path.rpartition('.')
    module_name += '_pb2'
    if module_path:
        imports.add(f'from {module_path} import {module_name} as {as_name}')
    else:
        # File at the root of the proto path: there is no parent package.
        imports.add(f'import {module_name} as {as_name}')
    return f'{as_name}.{relative_name}'


def generate_method(imports, file, service, method):
    """Return the source lines of one client-stub method.

    The generated method calls ``self.channel.<in>_<out>(...)`` with the
    RPC path and the message (de)serializers:
      * stream -> stream: returns a ``mmi2grpc._streaming.StreamWrapper``
        built from the call object and the request message class;
      * stream -> unary: forwards a request iterator and ``**kwargs``;
      * unary -> unary/stream: builds the request message from ``**kwargs``
        and forwards ``wait_for_ready``.

    Imports for the referenced message modules are added to `imports`.
    """
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'
    rpc_path = f'/{_service_full_name(file, service)}/{method.name}'

    input_type = import_type(imports, method.input_type)
    output_type = import_type(imports, method.output_type)

    if input_mode == 'stream':
        if output_mode == 'stream':
            return (
                f'def {method.name}(self):\n'
                f'    from mmi2grpc._streaming import StreamWrapper\n'
                f'    return StreamWrapper(\n'
                f'        self.channel.{input_mode}_{output_mode}(\n'
                f"            '{rpc_path}',\n"
                f'            request_serializer={input_type}.SerializeToString,\n'
                f'            response_deserializer={output_type}.FromString\n'
                f'        ),\n'
                f'        {input_type})'
            ).split('\n')
        return (
            f'def {method.name}(self, iterator, **kwargs):\n'
            f'    return self.channel.{input_mode}_{output_mode}(\n'
            f"        '{rpc_path}',\n"
            f'        request_serializer={input_type}.SerializeToString,\n'
            f'        response_deserializer={output_type}.FromString\n'
            f'    )(iterator, **kwargs)'
        ).split('\n')
    return (
        f'def {method.name}(self, wait_for_ready=None, **kwargs):\n'
        f'    return self.channel.{input_mode}_{output_mode}(\n'
        f"        '{rpc_path}',\n"
        f'        request_serializer={input_type}.SerializeToString,\n'
        f'        response_deserializer={output_type}.FromString\n'
        f'    )({input_type}(**kwargs), wait_for_ready=wait_for_ready)'
    ).split('\n')


def generate_service(imports, file, service):
    """Return the source lines of the client stub class for `service`.

    The class keeps the channel given to its constructor and has one method
    per RPC (see `generate_method`), indented into the class body.
    """
    methods = '\n\n    '.join(
        '\n    '.join(generate_method(imports, file, service, method))
        for method in service.method
    )
    return (
        f'class {service.name}:\n'
        f'    def __init__(self, channel):\n'
        f'        self.channel = channel\n'
        f'\n'
        f'    {methods}\n'
    ).split('\n')


def generate_servicer_method(method):
    """Return the source lines of a servicer method stub.

    The stub sets UNIMPLEMENTED on the context and raises
    NotImplementedError. Client-streaming methods take `request_iterator`,
    all others a single `request`.
    """
    request_arg = 'request_iterator' if method.client_streaming else 'request'
    return (
        f'def {method.name}(self, {request_arg}, context):\n'
        f'    context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n'
        f'    context.set_details("Method not implemented!")\n'
        f'    raise NotImplementedError("Method not implemented!")'
    ).split('\n')


def generate_servicer(service):
    """Return the source lines of the `<Service>Servicer` base class."""
    methods = '\n\n    '.join(
        '\n    '.join(generate_servicer_method(method))
        for method in service.method
    )
    if not methods:
        # A class body cannot be empty.
        methods = 'pass'
    return (
        f'class {service.name}Servicer:\n'
        f'\n'
        f'    {methods}\n'
    ).split('\n')


def generate_rpc_method_handler(imports, method):
    """Return the source lines of one `rpc_method_handlers` dict entry.

    The entry maps the method name to ``grpc.<in>_<out>_rpc_method_handler``
    wrapping ``servicer.<name>`` with the message (de)serializers.
    """
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type = import_type(imports, method.input_type)
    output_type = import_type(imports, method.output_type)

    return (
        f"'{method.name}': grpc.{input_mode}_{output_mode}_rpc_method_handler(\n"
        f'        servicer.{method.name},\n'
        f'        request_deserializer={input_type}.FromString,\n'
        f'        response_serializer={output_type}.SerializeToString,\n'
        f'    ),'
    ).split('\n')


def generate_add_servicer_to_server_method(imports, file, service):
    """Return the source lines of `add_<Service>Servicer_to_server`.

    The generated function registers all RPC handlers of `service` on the
    server under the service's fully-qualified name.
    """
    method_handlers = '\n        '.join(
        '\n        '.join(generate_rpc_method_handler(imports, method))
        for method in service.method
    )
    return (
        f'def add_{service.name}Servicer_to_server(servicer, server):\n'
        f'    rpc_method_handlers = {{\n'
        f'        {method_handlers}\n'
        f'    }}\n'
        f'    generic_handler = grpc.method_handlers_generic_handler(\n'
        f"        '{_service_full_name(file, service)}', rpc_method_handlers)\n"
        f'    server.add_generic_rpc_handlers((generic_handler,))'
    ).split('\n')


def _join_line_lists(line_lists):
    """Concatenate lists of generated lines into one newline-joined string."""
    return '\n'.join(line for lines in line_lists for line in lines)


def _ordered_imports(imports):
    """Return the import statements in a deterministic order.

    Iterating a set of strings depends on per-process hash randomization;
    sorting keeps generated files identical across runs. Plain `import`
    statements come first, then `from ... import ...` statements.
    """
    return sorted(imports, key=lambda line: (not line.startswith('import '), line))


files = []

for file_name in request.file_to_generate:
    file = next(proto for proto in request.proto_file if proto.name == file_name)

    imports = {'import grpc'}

    services = _join_line_lists(
        generate_service(imports, file, service) for service in file.service)

    servicers = _join_line_lists(
        generate_servicer(service) for service in file.service)

    add_servicer_methods = _join_line_lists(
        generate_add_servicer_to_server_method(imports, file, service)
        for service in file.service)

    # Imports are emitted last so every generator above has registered its
    # message modules in `imports`.
    files.append(CodeGeneratorResponse.File(
        name=_strip_proto_suffix(file_name) + '_grpc.py',
        content='\n'.join(_ordered_imports(imports)) + '\n\n' + services
        + '\n\n' + servicers + '\n\n' + add_servicer_methods + '\n'
    ))

response = CodeGeneratorResponse(file=files)

sys.stdout.buffer.write(response.SerializeToString())