• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Base implementation of reflection servicer."""
15
16import grpc
17from google.protobuf import descriptor_pb2
18from google.protobuf import descriptor_pool
19
20from grpc_reflection.v1alpha import reflection_pb2 as _reflection_pb2
21from grpc_reflection.v1alpha import reflection_pb2_grpc as _reflection_pb2_grpc
22
# protobuf's default descriptor pool. BaseReflectionServicer falls back to it
# when it is constructed without an explicit ``pool`` argument.
_POOL = descriptor_pool.Default()
24
25
def _not_found_error():
    """Builds a ServerReflectionResponse reporting a NOT_FOUND lookup failure.

    Returns:
        A ServerReflectionResponse whose ``error_response`` carries the gRPC
        NOT_FOUND status code and its description.
    """
    # Each grpc.StatusCode member's value is an (integer code, description)
    # pair.
    code, description = grpc.StatusCode.NOT_FOUND.value
    return _reflection_pb2.ServerReflectionResponse(
        error_response=_reflection_pb2.ErrorResponse(
            error_code=code,
            # ``error_message`` is declared as a proto ``string`` field in
            # reflection.proto, so pass text rather than UTF-8 encoded bytes
            # that protobuf would have to decode again.
            error_message=description,
        ))
32
33
def _collect_transitive_dependencies(descriptor, seen_files):
    """Records ``descriptor`` and every file it transitively depends on.

    Args:
        descriptor: The FileDescriptor to start the walk from.
        seen_files: A dict mapping file name to FileDescriptor, updated in
            place. Files already present are not visited again, so shared
            (diamond) imports are recorded once. Insertion order is discovery
            order, so ``descriptor`` itself is recorded first when the dict
            starts out empty.
    """
    # Iterative depth-first walk: avoids Python's recursion limit on deep
    # import chains. Pushing dependencies reversed keeps them in declaration
    # order when popped.
    pending = [descriptor]
    while pending:
        current = pending.pop()
        if current.name in seen_files:
            continue
        seen_files[current.name] = current
        pending.extend(reversed(current.dependencies))


def _file_descriptor_response(descriptor):
    """Builds a response carrying ``descriptor``'s file and its dependencies.

    The reflection protocol answers file_by_filename, file_containing_symbol
    and file_containing_extension requests with the matching file *and* its
    transitive dependencies, each serialized as a FileDescriptorProto.

    Args:
        descriptor: The FileDescriptor that matched the request.

    Returns:
        A ServerReflectionResponse whose ``file_descriptor_response`` lists the
        serialized requested file first, followed by its transitive
        dependencies (each file once).
    """
    files = {}
    _collect_transitive_dependencies(descriptor, files)

    serialized_protos = []
    for file_descriptor in files.values():
        proto = descriptor_pb2.FileDescriptorProto()
        file_descriptor.CopyToProto(proto)
        serialized_protos.append(proto.SerializeToString())

    return _reflection_pb2.ServerReflectionResponse(
        file_descriptor_response=_reflection_pb2.FileDescriptorResponse(
            file_descriptor_proto=serialized_protos))
41
42
class BaseReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer):
    """Common descriptor-lookup logic for gRPC reflection servicers.

    Each private helper resolves descriptors from a DescriptorPool and wraps
    the outcome in a ServerReflectionResponse. A failed lookup (KeyError from
    the pool) becomes a NOT_FOUND error response. Presumably concrete
    subclasses dispatch incoming reflection requests to these helpers.
    """

    def __init__(self, service_names, pool=None):
        """Initializes the servicer.

        Args:
            service_names: Iterable of fully-qualified service names to
                advertise. They are stored sorted, as a tuple.
            pool: Optional DescriptorPool to resolve descriptors from. When
                None, the module's default pool is used.
        """
        if pool is None:
            pool = _POOL
        ordered_names = sorted(service_names)
        self._service_names = tuple(ordered_names)
        self._pool = pool

    def _file_by_filename(self, filename):
        """Answers a lookup of the proto file named ``filename``."""
        try:
            file_descriptor = self._pool.FindFileByName(filename)
        except KeyError:
            return _not_found_error()
        return _file_descriptor_response(file_descriptor)

    def _file_containing_symbol(self, fully_qualified_name):
        """Answers a lookup of the file declaring ``fully_qualified_name``."""
        try:
            file_descriptor = self._pool.FindFileContainingSymbol(
                fully_qualified_name)
        except KeyError:
            return _not_found_error()
        return _file_descriptor_response(file_descriptor)

    def _file_containing_extension(self, containing_type, extension_number):
        """Answers a lookup of the file declaring an extension.

        The extension is identified by the message type it extends
        (``containing_type``) and its field number (``extension_number``).
        """
        pool = self._pool
        try:
            extended_message = pool.FindMessageTypeByName(containing_type)
            extension = pool.FindExtensionByNumber(extended_message,
                                                   extension_number)
            file_descriptor = pool.FindFileContainingSymbol(
                extension.full_name)
        except KeyError:
            return _not_found_error()
        return _file_descriptor_response(file_descriptor)

    def _all_extension_numbers_of_type(self, containing_type):
        """Lists, in ascending order, the extension numbers the pool knows
        for message type ``containing_type``."""
        try:
            base_type = self._pool.FindMessageTypeByName(containing_type)
            numbers = tuple(
                sorted(ext.number
                       for ext in self._pool.FindAllExtensions(base_type)))
        except KeyError:
            return _not_found_error()
        number_response = _reflection_pb2.ExtensionNumberResponse(
            base_type_name=base_type.full_name,
            extension_number=numbers)
        return _reflection_pb2.ServerReflectionResponse(
            all_extension_numbers_response=number_response)

    def _list_services(self):
        """Answers a list-services request with the configured names."""
        entries = [
            _reflection_pb2.ServiceResponse(name=name)
            for name in self._service_names
        ]
        listing = _reflection_pb2.ListServiceResponse(service=entries)
        return _reflection_pb2.ServerReflectionResponse(
            list_services_response=listing)
108
109
# Public API: only the servicer base class; module-level helpers stay private.
__all__ = ['BaseReflectionServicer']
111