• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *
3  * Copyright 2016 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include <unordered_set>
20 #include <vector>
21 
22 #include <grpcpp/grpcpp.h>
23 
24 #include "src/cpp/ext/proto_server_reflection.h"
25 
26 using grpc::Status;
27 using grpc::StatusCode;
28 using grpc::reflection::v1alpha::ErrorResponse;
29 using grpc::reflection::v1alpha::ExtensionNumberResponse;
30 using grpc::reflection::v1alpha::ExtensionRequest;
31 using grpc::reflection::v1alpha::ListServiceResponse;
32 using grpc::reflection::v1alpha::ServerReflectionRequest;
33 using grpc::reflection::v1alpha::ServerReflectionResponse;
34 using grpc::reflection::v1alpha::ServiceResponse;
35 
36 namespace grpc {
37 
ProtoServerReflection()38 ProtoServerReflection::ProtoServerReflection()
39     : descriptor_pool_(protobuf::DescriptorPool::generated_pool()) {}
40 
SetServiceList(const std::vector<std::string> * services)41 void ProtoServerReflection::SetServiceList(
42     const std::vector<std::string>* services) {
43   services_ = services;
44 }
45 
ServerReflectionInfo(ServerContext * context,ServerReaderWriter<ServerReflectionResponse,ServerReflectionRequest> * stream)46 Status ProtoServerReflection::ServerReflectionInfo(
47     ServerContext* context,
48     ServerReaderWriter<ServerReflectionResponse, ServerReflectionRequest>*
49         stream) {
50   ServerReflectionRequest request;
51   ServerReflectionResponse response;
52   Status status;
53   while (stream->Read(&request)) {
54     switch (request.message_request_case()) {
55       case ServerReflectionRequest::MessageRequestCase::kFileByFilename:
56         status = GetFileByName(context, request.file_by_filename(), &response);
57         break;
58       case ServerReflectionRequest::MessageRequestCase::kFileContainingSymbol:
59         status = GetFileContainingSymbol(
60             context, request.file_containing_symbol(), &response);
61         break;
62       case ServerReflectionRequest::MessageRequestCase::
63           kFileContainingExtension:
64         status = GetFileContainingExtension(
65             context, &request.file_containing_extension(), &response);
66         break;
67       case ServerReflectionRequest::MessageRequestCase::
68           kAllExtensionNumbersOfType:
69         status = GetAllExtensionNumbers(
70             context, request.all_extension_numbers_of_type(),
71             response.mutable_all_extension_numbers_response());
72         break;
73       case ServerReflectionRequest::MessageRequestCase::kListServices:
74         status =
75             ListService(context, response.mutable_list_services_response());
76         break;
77       default:
78         status = Status(StatusCode::UNIMPLEMENTED, "");
79     }
80 
81     if (!status.ok()) {
82       FillErrorResponse(status, response.mutable_error_response());
83     }
84     response.set_valid_host(request.host());
85     response.set_allocated_original_request(
86         new ServerReflectionRequest(request));
87     stream->Write(response);
88   }
89 
90   return Status::OK;
91 }
92 
FillErrorResponse(const Status & status,ErrorResponse * error_response)93 void ProtoServerReflection::FillErrorResponse(const Status& status,
94                                               ErrorResponse* error_response) {
95   error_response->set_error_code(status.error_code());
96   error_response->set_error_message(status.error_message());
97 }
98 
ListService(ServerContext *,ListServiceResponse * response)99 Status ProtoServerReflection::ListService(ServerContext* /*context*/,
100                                           ListServiceResponse* response) {
101   if (services_ == nullptr) {
102     return Status(StatusCode::NOT_FOUND, "Services not found.");
103   }
104   for (const auto& value : *services_) {
105     ServiceResponse* service_response = response->add_service();
106     service_response->set_name(value);
107   }
108   return Status::OK;
109 }
110 
GetFileByName(ServerContext *,const std::string & file_name,ServerReflectionResponse * response)111 Status ProtoServerReflection::GetFileByName(
112     ServerContext* /*context*/, const std::string& file_name,
113     ServerReflectionResponse* response) {
114   if (descriptor_pool_ == nullptr) {
115     return Status::CANCELLED;
116   }
117 
118   const protobuf::FileDescriptor* file_desc =
119       descriptor_pool_->FindFileByName(file_name);
120   if (file_desc == nullptr) {
121     return Status(StatusCode::NOT_FOUND, "File not found.");
122   }
123   std::unordered_set<std::string> seen_files;
124   FillFileDescriptorResponse(file_desc, response, &seen_files);
125   return Status::OK;
126 }
127 
GetFileContainingSymbol(ServerContext *,const std::string & symbol,ServerReflectionResponse * response)128 Status ProtoServerReflection::GetFileContainingSymbol(
129     ServerContext* /*context*/, const std::string& symbol,
130     ServerReflectionResponse* response) {
131   if (descriptor_pool_ == nullptr) {
132     return Status::CANCELLED;
133   }
134 
135   const protobuf::FileDescriptor* file_desc =
136       descriptor_pool_->FindFileContainingSymbol(symbol);
137   if (file_desc == nullptr) {
138     return Status(StatusCode::NOT_FOUND, "Symbol not found.");
139   }
140   std::unordered_set<std::string> seen_files;
141   FillFileDescriptorResponse(file_desc, response, &seen_files);
142   return Status::OK;
143 }
144 
GetFileContainingExtension(ServerContext *,const ExtensionRequest * request,ServerReflectionResponse * response)145 Status ProtoServerReflection::GetFileContainingExtension(
146     ServerContext* /*context*/, const ExtensionRequest* request,
147     ServerReflectionResponse* response) {
148   if (descriptor_pool_ == nullptr) {
149     return Status::CANCELLED;
150   }
151 
152   const protobuf::Descriptor* desc =
153       descriptor_pool_->FindMessageTypeByName(request->containing_type());
154   if (desc == nullptr) {
155     return Status(StatusCode::NOT_FOUND, "Type not found.");
156   }
157 
158   const protobuf::FieldDescriptor* field_desc =
159       descriptor_pool_->FindExtensionByNumber(desc,
160                                               request->extension_number());
161   if (field_desc == nullptr) {
162     return Status(StatusCode::NOT_FOUND, "Extension not found.");
163   }
164   std::unordered_set<std::string> seen_files;
165   FillFileDescriptorResponse(field_desc->file(), response, &seen_files);
166   return Status::OK;
167 }
168 
GetAllExtensionNumbers(ServerContext *,const std::string & type,ExtensionNumberResponse * response)169 Status ProtoServerReflection::GetAllExtensionNumbers(
170     ServerContext* /*context*/, const std::string& type,
171     ExtensionNumberResponse* response) {
172   if (descriptor_pool_ == nullptr) {
173     return Status::CANCELLED;
174   }
175 
176   const protobuf::Descriptor* desc =
177       descriptor_pool_->FindMessageTypeByName(type);
178   if (desc == nullptr) {
179     return Status(StatusCode::NOT_FOUND, "Type not found.");
180   }
181 
182   std::vector<const protobuf::FieldDescriptor*> extensions;
183   descriptor_pool_->FindAllExtensions(desc, &extensions);
184   for (const auto& value : extensions) {
185     response->add_extension_number(value->number());
186   }
187   response->set_base_type_name(type);
188   return Status::OK;
189 }
190 
FillFileDescriptorResponse(const protobuf::FileDescriptor * file_desc,ServerReflectionResponse * response,std::unordered_set<std::string> * seen_files)191 void ProtoServerReflection::FillFileDescriptorResponse(
192     const protobuf::FileDescriptor* file_desc,
193     ServerReflectionResponse* response,
194     std::unordered_set<std::string>* seen_files) {
195   if (seen_files->find(file_desc->name()) != seen_files->end()) {
196     return;
197   }
198   seen_files->insert(file_desc->name());
199 
200   protobuf::FileDescriptorProto file_desc_proto;
201   std::string data;
202   file_desc->CopyTo(&file_desc_proto);
203   file_desc_proto.SerializeToString(&data);
204   response->mutable_file_descriptor_response()->add_file_descriptor_proto(data);
205 
206   for (int i = 0; i < file_desc->dependency_count(); ++i) {
207     FillFileDescriptorResponse(file_desc->dependency(i), response, seen_files);
208   }
209 }
210 
211 }  // namespace grpc
212