# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base implementation of reflection servicer."""

from google.protobuf import descriptor_pb2
from google.protobuf import descriptor_pool
import grpc
from grpc_reflection.v1alpha import reflection_pb2 as _reflection_pb2
from grpc_reflection.v1alpha import reflection_pb2_grpc as _reflection_pb2_grpc

# Descriptor pool used when BaseReflectionServicer is built without one.
_POOL = descriptor_pool.Default()


def _not_found_error(original_request):
    """Build a NOT_FOUND error response for ``original_request``.

    Args:
      original_request: The ServerReflectionRequest being answered; it is
        echoed back in the response.

    Returns:
      A ServerReflectionResponse whose ``error_response`` carries gRPC's
      NOT_FOUND status code and its canonical message.
    """
    # grpc.StatusCode members hold a (numeric code, message text) tuple.
    status = grpc.StatusCode.NOT_FOUND.value
    return _reflection_pb2.ServerReflectionResponse(
        error_response=_reflection_pb2.ErrorResponse(
            error_code=status[0],
            # error_message is a proto `string` field: pass text directly
            # instead of bytes that protobuf would have to decode again.
            error_message=status[1],
        ),
        original_request=original_request,
    )


def _collect_transitive_dependencies(descriptor, seen_files):
    """Record ``descriptor`` and every file it transitively depends on.

    Args:
      descriptor: The FileDescriptor to start from.
      seen_files: Dict mapping file name -> FileDescriptor, updated in place.
        Entries are inserted depth-first, starting with ``descriptor``.
    """
    seen_files[descriptor.name] = descriptor
    for dependency in descriptor.dependencies:
        # Descriptors cannot have circular dependencies, but the same file can
        # be reached through several import paths; collect it only once.
        if dependency.name not in seen_files:
            _collect_transitive_dependencies(dependency, seen_files)


def _file_descriptor_response(descriptor, original_request):
    """Build a response carrying ``descriptor`` and all of its dependencies.

    Args:
      descriptor: The FileDescriptor that satisfied the request.
      original_request: The ServerReflectionRequest being answered; it is
        echoed back in the response.

    Returns:
      A ServerReflectionResponse whose ``file_descriptor_response`` holds the
      serialized FileDescriptorProto of ``descriptor`` first, followed by
      those of its transitive dependencies.
    """
    descriptors = {}
    _collect_transitive_dependencies(descriptor, descriptors)

    serialized_proto_list = []
    for file_descriptor in descriptors.values():
        proto = descriptor_pb2.FileDescriptorProto()
        file_descriptor.CopyToProto(proto)
        serialized_proto_list.append(proto.SerializeToString())

    return _reflection_pb2.ServerReflectionResponse(
        file_descriptor_response=_reflection_pb2.FileDescriptorResponse(
            file_descriptor_proto=serialized_proto_list
        ),
        original_request=original_request,
    )


class BaseReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer):
    """Base class for reflection servicer.

    Provides one helper per kind of reflection request, each returning a
    complete ServerReflectionResponse. The streaming RPC handler itself is
    not defined in this class.
    """

    def __init__(self, service_names, pool=None):
        """Constructor.

        Args:
          service_names: Iterable of fully-qualified service names available.
          pool: An optional DescriptorPool instance; the default descriptor
            pool is used when omitted.
        """
        self._service_names = tuple(sorted(service_names))
        self._pool = _POOL if pool is None else pool

    def _file_by_filename(self, request, filename):
        """Answer a file_by_filename request.

        Returns a file descriptor response for ``filename``, or a NOT_FOUND
        error response if the pool lookup raises KeyError.
        """
        try:
            descriptor = self._pool.FindFileByName(filename)
        except KeyError:
            return _not_found_error(request)
        else:
            return _file_descriptor_response(descriptor, request)

    def _file_containing_symbol(self, request, fully_qualified_name):
        """Answer a file_containing_symbol request.

        Returns a file descriptor response for the file defining
        ``fully_qualified_name``, or a NOT_FOUND error response if the pool
        lookup raises KeyError.
        """
        try:
            descriptor = self._pool.FindFileContainingSymbol(
                fully_qualified_name
            )
        except KeyError:
            return _not_found_error(request)
        else:
            return _file_descriptor_response(descriptor, request)

    def _file_containing_extension(
        self, request, containing_type, extension_number
    ):
        """Answer a file_containing_extension request.

        Resolves ``containing_type``, then the extension numbered
        ``extension_number`` on it, then the file defining that extension.
        Returns a NOT_FOUND error response if any of these lookups raises
        KeyError.
        """
        try:
            message_descriptor = self._pool.FindMessageTypeByName(
                containing_type
            )
            extension_descriptor = self._pool.FindExtensionByNumber(
                message_descriptor, extension_number
            )
            descriptor = self._pool.FindFileContainingSymbol(
                extension_descriptor.full_name
            )
        except KeyError:
            return _not_found_error(request)
        else:
            return _file_descriptor_response(descriptor, request)

    def _all_extension_numbers_of_type(self, request, containing_type):
        """Answer an all_extension_numbers_of_type request.

        Returns the sorted extension numbers the pool knows for
        ``containing_type``, or a NOT_FOUND error response if the type lookup
        raises KeyError.
        """
        try:
            message_descriptor = self._pool.FindMessageTypeByName(
                containing_type
            )
            extension_numbers = tuple(
                sorted(
                    extension.number
                    for extension in self._pool.FindAllExtensions(
                        message_descriptor
                    )
                )
            )
        except KeyError:
            return _not_found_error(request)
        else:
            return _reflection_pb2.ServerReflectionResponse(
                all_extension_numbers_response=_reflection_pb2.ExtensionNumberResponse(
                    base_type_name=message_descriptor.full_name,
                    extension_number=extension_numbers,
                ),
                original_request=request,
            )

    def _list_services(self, request):
        """Answer a list_services request with the configured service names.

        Names are reported in the sorted order fixed by the constructor.
        """
        return _reflection_pb2.ServerReflectionResponse(
            list_services_response=_reflection_pb2.ListServiceResponse(
                service=[
                    _reflection_pb2.ServiceResponse(name=service_name)
                    for service_name in self._service_names
                ]
            ),
            original_request=request,
        )


__all__ = ["BaseReflectionServicer"]