• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base implementation of reflection servicer."""
15
16from google.protobuf import descriptor_pb2
17from google.protobuf import descriptor_pool
18import grpc
19from grpc_reflection.v1alpha import reflection_pb2 as _reflection_pb2
20from grpc_reflection.v1alpha import reflection_pb2_grpc as _reflection_pb2_grpc
21
# Process-wide default descriptor pool; BaseReflectionServicer falls back to
# it when the caller does not supply a pool of its own.
_POOL = descriptor_pool.Default()
23
24
def _not_found_error(original_request):
    """Build a NOT_FOUND reflection response for the given request.

    Args:
        original_request: The request being answered; it is echoed back in
            the response's `original_request` field.

    Returns:
        A ServerReflectionResponse whose `error_response` is filled from the
        value of grpc.StatusCode.NOT_FOUND: element 0 as the error code and
        element 1, encoded, as the error message.
    """
    not_found = grpc.StatusCode.NOT_FOUND.value
    error = _reflection_pb2.ErrorResponse(
        error_code=not_found[0],
        error_message=not_found[1].encode(),
    )
    return _reflection_pb2.ServerReflectionResponse(
        error_response=error,
        original_request=original_request,
    )
33
34
35def _collect_transitive_dependencies(descriptor, seen_files):
36    seen_files.update({descriptor.name: descriptor})
37    for dependency in descriptor.dependencies:
38        if not dependency.name in seen_files:
39            # descriptors cannot have circular dependencies
40            _collect_transitive_dependencies(dependency, seen_files)
41
42
def _file_descriptor_response(descriptor, original_request):
    """Build a reflection response carrying a file and its dependencies.

    Args:
        descriptor: The file descriptor to report; its transitive
            dependencies are gathered with _collect_transitive_dependencies.
        original_request: The request being answered; echoed back in the
            response's `original_request` field.

    Returns:
        A ServerReflectionResponse whose `file_descriptor_response` lists
        one serialized FileDescriptorProto per collected file, in the order
        the files were collected.
    """
    files_by_name = {}
    _collect_transitive_dependencies(descriptor, files_by_name)

    def _serialize(file_descriptor):
        # Copy into a FileDescriptorProto message, then serialize that.
        file_proto = descriptor_pb2.FileDescriptorProto()
        file_descriptor.CopyToProto(file_proto)
        return file_proto.SerializeToString()

    payload = _reflection_pb2.FileDescriptorResponse(
        file_descriptor_proto=[
            _serialize(file_descriptor)
            for file_descriptor in files_by_name.values()
        ]
    )
    return _reflection_pb2.ServerReflectionResponse(
        file_descriptor_response=payload,
        original_request=original_request,
    )
61
62
class BaseReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer):
    """Shared lookup helpers for reflection servicer implementations.

    Each helper takes the incoming request plus the lookup key(s) and returns
    a ServerReflectionResponse. A KeyError raised by a descriptor pool lookup
    is turned into a NOT_FOUND error response.
    """

    def __init__(self, service_names, pool=None):
        """Constructor.

        Args:
            service_names: Iterable of fully-qualified service names available.
            pool: An optional DescriptorPool instance; the module-level
                default pool is used when omitted.
        """
        if pool is None:
            pool = _POOL
        self._pool = pool
        # Stored sorted and immutable so listings come back in a stable order.
        self._service_names = tuple(sorted(service_names))

    def _file_by_filename(self, request, filename):
        """Answer with the file registered under `filename`, or NOT_FOUND."""
        try:
            file_descriptor = self._pool.FindFileByName(filename)
        except KeyError:
            return _not_found_error(request)
        return _file_descriptor_response(file_descriptor, request)

    def _file_containing_symbol(self, request, fully_qualified_name):
        """Answer with the file that contains the symbol, or NOT_FOUND."""
        try:
            file_descriptor = self._pool.FindFileContainingSymbol(
                fully_qualified_name
            )
        except KeyError:
            return _not_found_error(request)
        return _file_descriptor_response(file_descriptor, request)

    def _file_containing_extension(
        self, request, containing_type, extension_number
    ):
        """Answer with the file that contains an extension, or NOT_FOUND.

        The extension is identified by the name of the message it extends
        (`containing_type`) together with `extension_number`.
        """
        pool = self._pool
        try:
            # Message type -> extension -> file containing that extension.
            extended_message = pool.FindMessageTypeByName(containing_type)
            extension = pool.FindExtensionByNumber(
                extended_message, extension_number
            )
            file_descriptor = pool.FindFileContainingSymbol(
                extension.full_name
            )
        except KeyError:
            return _not_found_error(request)
        return _file_descriptor_response(file_descriptor, request)

    def _all_extension_numbers_of_type(self, request, containing_type):
        """Answer with the sorted extension numbers of a type, or NOT_FOUND."""
        try:
            extended_message = self._pool.FindMessageTypeByName(
                containing_type
            )
            numbers = [
                extension.number
                for extension in self._pool.FindAllExtensions(
                    extended_message
                )
            ]
            numbers.sort()
        except KeyError:
            return _not_found_error(request)

        payload = _reflection_pb2.ExtensionNumberResponse(
            base_type_name=extended_message.full_name,
            extension_number=tuple(numbers),
        )
        return _reflection_pb2.ServerReflectionResponse(
            all_extension_numbers_response=payload,
            original_request=request,
        )

    def _list_services(self, request):
        """Answer with one ServiceResponse per configured service name."""
        services = []
        for service_name in self._service_names:
            services.append(_reflection_pb2.ServiceResponse(name=service_name))
        listing = _reflection_pb2.ListServiceResponse(service=services)
        return _reflection_pb2.ServerReflectionResponse(
            list_services_response=listing,
            original_request=request,
        )
146
147
# Only the servicer base class is exported; the underscore-prefixed helpers
# in this module are left out of the public names.
__all__ = ["BaseReflectionServicer"]
149