# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Any, Callable, ClassVar, Dict, Optional

# Workaround: `grpc` must be imported before `google.protobuf.json_format`,
# to prevent "Segmentation fault". Ref https://github.com/grpc/grpc/issues/24897
import grpc
from google.protobuf import json_format
import google.protobuf.message

logger = logging.getLogger(__name__)

# Type aliases
Message = google.protobuf.message.Message


class GrpcClientHelper:
    """Wraps a gRPC stub to make logged unary calls bounded by a deadline.

    The stub is built once as ``stub_class(channel)``; RPC methods are then
    looked up on it by name in ``call_unary_with_deadline``.
    """
    channel: grpc.Channel
    # Deadline (seconds) used when the caller passes ``deadline_sec=None``.
    DEFAULT_RPC_DEADLINE_SEC: ClassVar[int] = 90

    def __init__(self, channel: grpc.Channel,
                 stub_class: Callable[[grpc.Channel], Any]):
        """Create the stub over ``channel``.

        Args:
          channel: the channel the stub sends RPCs over.
          stub_class: a generated stub class (or any callable) invoked as
            ``stub_class(channel)`` to build the stub.
        """
        self.channel = channel
        self.stub = stub_class(channel)
        # This is purely cosmetic to make RPC logs look like method calls.
        self.log_service_name = re.sub(r'Stub$', '',
                                       self.stub.__class__.__name__)

    def call_unary_with_deadline(
            self,
            *,
            rpc: str,
            req: Message,
            deadline_sec: Optional[int] = DEFAULT_RPC_DEADLINE_SEC,
            log_level: Optional[int] = logging.DEBUG) -> Message:
        """Invoke the unary stub method named ``rpc`` with request ``req``.

        The call is made with ``wait_for_ready=True`` and
        ``timeout=deadline_sec``.

        Args:
          rpc: name of the method on the stub.
          req: the request message.
          deadline_sec: timeout in seconds; ``None`` selects
            ``DEFAULT_RPC_DEADLINE_SEC``.
          log_level: logging level for the request log line; ``None``
            selects ``logging.DEBUG``.

        Returns:
          Whatever the stub method returns (the response message).

        Raises:
          AttributeError: if the stub has no attribute named ``rpc``.
          Any error raised by the stub call (e.g. ``grpc.RpcError``)
          propagates unchanged.
        """
        if deadline_sec is None:
            deadline_sec = self.DEFAULT_RPC_DEADLINE_SEC

        call_kwargs = dict(wait_for_ready=True, timeout=deadline_sec)
        self._log_rpc_request(rpc, req, call_kwargs, log_level)

        # Call RPC, e.g. RpcStub(channel).RpcMethod(req, ...options)
        rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
        return rpc_callable(req, **call_kwargs)

    def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG):
        """Log the call as ``RPC Service.Method(request=Type({...}), k=v)``.

        ``log_level=None`` is treated as ``logging.DEBUG``.
        """
        level = logging.DEBUG if log_level is None else log_level
        # Converting the request to a dict can be costly; skip it entirely
        # when the record would be dropped anyway.
        if not logger.isEnabledFor(level):
            return
        # Join in call_kwargs insertion order (a set comprehension here would
        # make the option order vary between runs).
        call_options = ', '.join(f'{k}={v}' for k, v in call_kwargs.items())
        logger.log(level, 'RPC %s.%s(request=%s(%r), %s)',
                   self.log_service_name, rpc, req.__class__.__name__,
                   json_format.MessageToDict(req), call_options)


class GrpcApp:
    """Helper for talking to gRPC services on ``rpc_host``.

    Keeps one insecure channel per port. Usable as a context manager; all
    cached channels are closed on exit.
    """
    channels: Dict[int, grpc.Channel]

    class NotFound(Exception):
        """Requested resource not found"""

        def __init__(self, message):
            self.message = message
            super().__init__(message)

    def __init__(self, rpc_host):
        self.rpc_host = rpc_host
        # Cache gRPC channels per port
        self.channels = {}

    def _make_channel(self, port) -> grpc.Channel:
        """Return the cached insecure channel to ``rpc_host:port``.

        The channel is created on first use for a given port.
        """
        if port not in self.channels:
            target = f'{self.rpc_host}:{port}'
            self.channels[port] = grpc.insecure_channel(target)
        return self.channels[port]

    def close(self):
        """Close all cached channels and drop them from the cache.

        Removing each channel from the cache means a later ``_make_channel``
        opens a fresh channel instead of returning a closed one, and a
        second ``close`` (e.g. ``__exit__`` then ``__del__``) is a no-op.
        Channels are popped one at a time so that, if one ``close`` raises,
        the channels not yet closed stay cached.
        """
        while self.channels:
            _, channel = self.channels.popitem()
            channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False

    def __del__(self):
        # ``channels`` is missing if construction failed before
        # ``GrpcApp.__init__`` assigned it (e.g. a subclass __init__ raised
        # before calling super); there is nothing to close in that case.
        if hasattr(self, 'channels'):
            self.close()