• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import logging
15import re
16from typing import ClassVar, Dict, Optional
17
18# Workaround: `grpc` must be imported before `google.protobuf.json_format`,
19# to prevent "Segmentation fault". Ref https://github.com/grpc/grpc/issues/24897
20import grpc
21from google.protobuf import json_format
22import google.protobuf.message
23
# Module-level logger; RPC request logging in GrpcClientHelper goes through it.
logger = logging.getLogger(__name__)

# Type aliases
# Shorthand for the protobuf base message class, used to annotate RPC
# requests and responses.
Message = google.protobuf.message.Message
28
29
class GrpcClientHelper:
    """Thin wrapper around a gRPC stub that logs and issues unary RPCs.

    The stub is built as ``stub_class(channel)``. RPC methods are looked up on
    it by name, so callers write
    ``helper.call_unary_with_deadline(rpc='Method', req=request)``.
    """
    channel: grpc.Channel
    DEFAULT_RPC_DEADLINE_SEC = 90

    def __init__(self, channel: grpc.Channel, stub_class: type):
        """Create the stub for ``channel``.

        Args:
          channel: The channel the stub sends RPCs over.
          stub_class: Stub class whose constructor takes the channel.
        """
        self.channel = channel
        self.stub = stub_class(channel)
        # Purely cosmetic: drop a trailing "Stub" from the class name so RPC
        # logs read like method calls, e.g. FooService.Bar(...).
        self.log_service_name = re.sub(r'Stub$', '',
                                       self.stub.__class__.__name__)

    def call_unary_with_deadline(
            self,
            *,
            rpc: str,
            req: Message,
            deadline_sec: Optional[int] = DEFAULT_RPC_DEADLINE_SEC,
            log_level: Optional[int] = logging.DEBUG) -> Message:
        """Invoke the unary RPC ``rpc`` on the stub and return its result.

        The call is made with ``wait_for_ready=True`` and
        ``timeout=deadline_sec``, and is logged before being sent.

        Args:
          rpc: Name of the stub method to call.
          req: The request message.
          deadline_sec: RPC timeout in seconds. ``None`` falls back to
            ``DEFAULT_RPC_DEADLINE_SEC``.
          log_level: Level for the request log line. ``None`` means
            ``logging.DEBUG``.

        Returns:
          Whatever the stub method returns for ``req``.

        Raises:
          AttributeError: If the stub has no attribute named ``rpc``.
          Any exception raised by the stub call itself propagates unchanged.
        """
        if deadline_sec is None:
            deadline_sec = self.DEFAULT_RPC_DEADLINE_SEC

        call_kwargs = dict(wait_for_ready=True, timeout=deadline_sec)
        self._log_rpc_request(rpc, req, call_kwargs, log_level)

        # Call RPC, e.g. RpcStub(channel).RpcMethod(req, ...options)
        rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
        return rpc_callable(req, **call_kwargs)

    def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG):
        """Log the outgoing RPC as ``Service.Method(request=Type({...}), kw=v)``.

        ``log_level=None`` is treated as ``logging.DEBUG``.
        """
        level = logging.DEBUG if log_level is None else log_level
        # MessageToDict() runs before logger.log() is entered, so lazy %-args
        # would not save it. Skip all formatting work when the record would
        # be discarded anyway.
        if not logger.isEnabledFor(level):
            return
        # Join in insertion order. A set comprehension here would make the
        # order vary between runs, because str hashing is randomized.
        logger.log(level, 'RPC %s.%s(request=%s(%r), %s)',
                   self.log_service_name, rpc, req.__class__.__name__,
                   json_format.MessageToDict(req),
                   ', '.join(f'{k}={v}' for k, v in call_kwargs.items()))
63
64
class GrpcApp:
    """Holds insecure gRPC channels to ``rpc_host``, cached one per port.

    Usable as a context manager: leaving the ``with`` block closes every
    cached channel.
    """
    channels: Dict[int, grpc.Channel]

    class NotFound(Exception):
        """Requested resource not found"""

        def __init__(self, message):
            # Keep the message accessible as an attribute as well as via str().
            self.message = message
            super().__init__(message)

    def __init__(self, rpc_host):
        self.rpc_host = rpc_host
        # Cache gRPC channels per port
        self.channels = {}

    def _make_channel(self, port) -> grpc.Channel:
        """Return the cached channel to ``rpc_host:port``, creating it on first use."""
        if port not in self.channels:
            target = f'{self.rpc_host}:{port}'
            self.channels[port] = grpc.insecure_channel(target)
        return self.channels[port]

    def close(self):
        """Close all cached channels and remove them from the cache.

        Emptying the cache means a later ``_make_channel()`` opens a fresh
        channel instead of returning a closed one. It also makes a repeated
        ``close()`` (e.g. ``__exit__`` followed by ``__del__``) a no-op.
        """
        # Each channel is removed before it is closed. If one close() raises,
        # only the channels not yet visited stay cached for a later retry.
        while self.channels:
            _, channel = self.channels.popitem()
            channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        # Never suppress an exception raised inside the ``with`` block.
        return False

    def __del__(self):
        # ``channels`` is absent if __init__ never ran or failed early; avoid
        # raising AttributeError during garbage collection.
        if getattr(self, 'channels', None):
            self.close()
100