1 /*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 #include <unordered_set>
20 #include <vector>
21
22 #include <grpcpp/grpcpp.h>
23
24 #include "src/cpp/ext/proto_server_reflection.h"
25
26 using grpc::Status;
27 using grpc::StatusCode;
28 using grpc::reflection::v1alpha::ErrorResponse;
29 using grpc::reflection::v1alpha::ExtensionNumberResponse;
30 using grpc::reflection::v1alpha::ExtensionRequest;
31 using grpc::reflection::v1alpha::FileDescriptorResponse;
32 using grpc::reflection::v1alpha::ListServiceResponse;
33 using grpc::reflection::v1alpha::ServerReflectionRequest;
34 using grpc::reflection::v1alpha::ServerReflectionResponse;
35 using grpc::reflection::v1alpha::ServiceResponse;
36
37 namespace grpc {
38
ProtoServerReflection()39 ProtoServerReflection::ProtoServerReflection()
40 : descriptor_pool_(protobuf::DescriptorPool::generated_pool()) {}
41
SetServiceList(const std::vector<grpc::string> * services)42 void ProtoServerReflection::SetServiceList(
43 const std::vector<grpc::string>* services) {
44 services_ = services;
45 }
46
ServerReflectionInfo(ServerContext * context,ServerReaderWriter<ServerReflectionResponse,ServerReflectionRequest> * stream)47 Status ProtoServerReflection::ServerReflectionInfo(
48 ServerContext* context,
49 ServerReaderWriter<ServerReflectionResponse, ServerReflectionRequest>*
50 stream) {
51 ServerReflectionRequest request;
52 ServerReflectionResponse response;
53 Status status;
54 while (stream->Read(&request)) {
55 switch (request.message_request_case()) {
56 case ServerReflectionRequest::MessageRequestCase::kFileByFilename:
57 status = GetFileByName(context, request.file_by_filename(), &response);
58 break;
59 case ServerReflectionRequest::MessageRequestCase::kFileContainingSymbol:
60 status = GetFileContainingSymbol(
61 context, request.file_containing_symbol(), &response);
62 break;
63 case ServerReflectionRequest::MessageRequestCase::
64 kFileContainingExtension:
65 status = GetFileContainingExtension(
66 context, &request.file_containing_extension(), &response);
67 break;
68 case ServerReflectionRequest::MessageRequestCase::
69 kAllExtensionNumbersOfType:
70 status = GetAllExtensionNumbers(
71 context, request.all_extension_numbers_of_type(),
72 response.mutable_all_extension_numbers_response());
73 break;
74 case ServerReflectionRequest::MessageRequestCase::kListServices:
75 status =
76 ListService(context, response.mutable_list_services_response());
77 break;
78 default:
79 status = Status(StatusCode::UNIMPLEMENTED, "");
80 }
81
82 if (!status.ok()) {
83 FillErrorResponse(status, response.mutable_error_response());
84 }
85 response.set_valid_host(request.host());
86 response.set_allocated_original_request(
87 new ServerReflectionRequest(request));
88 stream->Write(response);
89 }
90
91 return Status::OK;
92 }
93
FillErrorResponse(const Status & status,ErrorResponse * error_response)94 void ProtoServerReflection::FillErrorResponse(const Status& status,
95 ErrorResponse* error_response) {
96 error_response->set_error_code(status.error_code());
97 error_response->set_error_message(status.error_message());
98 }
99
ListService(ServerContext * context,ListServiceResponse * response)100 Status ProtoServerReflection::ListService(ServerContext* context,
101 ListServiceResponse* response) {
102 if (services_ == nullptr) {
103 return Status(StatusCode::NOT_FOUND, "Services not found.");
104 }
105 for (auto it = services_->begin(); it != services_->end(); ++it) {
106 ServiceResponse* service_response = response->add_service();
107 service_response->set_name(*it);
108 }
109 return Status::OK;
110 }
111
GetFileByName(ServerContext * context,const grpc::string & filename,ServerReflectionResponse * response)112 Status ProtoServerReflection::GetFileByName(
113 ServerContext* context, const grpc::string& filename,
114 ServerReflectionResponse* response) {
115 if (descriptor_pool_ == nullptr) {
116 return Status::CANCELLED;
117 }
118
119 const protobuf::FileDescriptor* file_desc =
120 descriptor_pool_->FindFileByName(filename);
121 if (file_desc == nullptr) {
122 return Status(StatusCode::NOT_FOUND, "File not found.");
123 }
124 std::unordered_set<grpc::string> seen_files;
125 FillFileDescriptorResponse(file_desc, response, &seen_files);
126 return Status::OK;
127 }
128
GetFileContainingSymbol(ServerContext * context,const grpc::string & symbol,ServerReflectionResponse * response)129 Status ProtoServerReflection::GetFileContainingSymbol(
130 ServerContext* context, const grpc::string& symbol,
131 ServerReflectionResponse* response) {
132 if (descriptor_pool_ == nullptr) {
133 return Status::CANCELLED;
134 }
135
136 const protobuf::FileDescriptor* file_desc =
137 descriptor_pool_->FindFileContainingSymbol(symbol);
138 if (file_desc == nullptr) {
139 return Status(StatusCode::NOT_FOUND, "Symbol not found.");
140 }
141 std::unordered_set<grpc::string> seen_files;
142 FillFileDescriptorResponse(file_desc, response, &seen_files);
143 return Status::OK;
144 }
145
GetFileContainingExtension(ServerContext * context,const ExtensionRequest * request,ServerReflectionResponse * response)146 Status ProtoServerReflection::GetFileContainingExtension(
147 ServerContext* context, const ExtensionRequest* request,
148 ServerReflectionResponse* response) {
149 if (descriptor_pool_ == nullptr) {
150 return Status::CANCELLED;
151 }
152
153 const protobuf::Descriptor* desc =
154 descriptor_pool_->FindMessageTypeByName(request->containing_type());
155 if (desc == nullptr) {
156 return Status(StatusCode::NOT_FOUND, "Type not found.");
157 }
158
159 const protobuf::FieldDescriptor* field_desc =
160 descriptor_pool_->FindExtensionByNumber(desc,
161 request->extension_number());
162 if (field_desc == nullptr) {
163 return Status(StatusCode::NOT_FOUND, "Extension not found.");
164 }
165 std::unordered_set<grpc::string> seen_files;
166 FillFileDescriptorResponse(field_desc->file(), response, &seen_files);
167 return Status::OK;
168 }
169
GetAllExtensionNumbers(ServerContext * context,const grpc::string & type,ExtensionNumberResponse * response)170 Status ProtoServerReflection::GetAllExtensionNumbers(
171 ServerContext* context, const grpc::string& type,
172 ExtensionNumberResponse* response) {
173 if (descriptor_pool_ == nullptr) {
174 return Status::CANCELLED;
175 }
176
177 const protobuf::Descriptor* desc =
178 descriptor_pool_->FindMessageTypeByName(type);
179 if (desc == nullptr) {
180 return Status(StatusCode::NOT_FOUND, "Type not found.");
181 }
182
183 std::vector<const protobuf::FieldDescriptor*> extensions;
184 descriptor_pool_->FindAllExtensions(desc, &extensions);
185 for (auto it = extensions.begin(); it != extensions.end(); it++) {
186 response->add_extension_number((*it)->number());
187 }
188 response->set_base_type_name(type);
189 return Status::OK;
190 }
191
FillFileDescriptorResponse(const protobuf::FileDescriptor * file_desc,ServerReflectionResponse * response,std::unordered_set<grpc::string> * seen_files)192 void ProtoServerReflection::FillFileDescriptorResponse(
193 const protobuf::FileDescriptor* file_desc,
194 ServerReflectionResponse* response,
195 std::unordered_set<grpc::string>* seen_files) {
196 if (seen_files->find(file_desc->name()) != seen_files->end()) {
197 return;
198 }
199 seen_files->insert(file_desc->name());
200
201 protobuf::FileDescriptorProto file_desc_proto;
202 grpc::string data;
203 file_desc->CopyTo(&file_desc_proto);
204 file_desc_proto.SerializeToString(&data);
205 response->mutable_file_descriptor_response()->add_file_descriptor_proto(data);
206
207 for (int i = 0; i < file_desc->dependency_count(); ++i) {
208 FillFileDescriptorResponse(file_desc->dependency(i), response, seen_files);
209 }
210 }
211
212 } // namespace grpc
213