1 /*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 #include <unordered_set>
20 #include <vector>
21
22 #include <grpcpp/grpcpp.h>
23
24 #include "src/cpp/ext/proto_server_reflection.h"
25
26 using grpc::Status;
27 using grpc::StatusCode;
28 using grpc::reflection::v1alpha::ErrorResponse;
29 using grpc::reflection::v1alpha::ExtensionNumberResponse;
30 using grpc::reflection::v1alpha::ExtensionRequest;
31 using grpc::reflection::v1alpha::ListServiceResponse;
32 using grpc::reflection::v1alpha::ServerReflectionRequest;
33 using grpc::reflection::v1alpha::ServerReflectionResponse;
34 using grpc::reflection::v1alpha::ServiceResponse;
35
36 namespace grpc {
37
ProtoServerReflection()38 ProtoServerReflection::ProtoServerReflection()
39 : descriptor_pool_(protobuf::DescriptorPool::generated_pool()) {}
40
SetServiceList(const std::vector<std::string> * services)41 void ProtoServerReflection::SetServiceList(
42 const std::vector<std::string>* services) {
43 services_ = services;
44 }
45
ServerReflectionInfo(ServerContext * context,ServerReaderWriter<ServerReflectionResponse,ServerReflectionRequest> * stream)46 Status ProtoServerReflection::ServerReflectionInfo(
47 ServerContext* context,
48 ServerReaderWriter<ServerReflectionResponse, ServerReflectionRequest>*
49 stream) {
50 ServerReflectionRequest request;
51 ServerReflectionResponse response;
52 Status status;
53 while (stream->Read(&request)) {
54 switch (request.message_request_case()) {
55 case ServerReflectionRequest::MessageRequestCase::kFileByFilename:
56 status = GetFileByName(context, request.file_by_filename(), &response);
57 break;
58 case ServerReflectionRequest::MessageRequestCase::kFileContainingSymbol:
59 status = GetFileContainingSymbol(
60 context, request.file_containing_symbol(), &response);
61 break;
62 case ServerReflectionRequest::MessageRequestCase::
63 kFileContainingExtension:
64 status = GetFileContainingExtension(
65 context, &request.file_containing_extension(), &response);
66 break;
67 case ServerReflectionRequest::MessageRequestCase::
68 kAllExtensionNumbersOfType:
69 status = GetAllExtensionNumbers(
70 context, request.all_extension_numbers_of_type(),
71 response.mutable_all_extension_numbers_response());
72 break;
73 case ServerReflectionRequest::MessageRequestCase::kListServices:
74 status =
75 ListService(context, response.mutable_list_services_response());
76 break;
77 default:
78 status = Status(StatusCode::UNIMPLEMENTED, "");
79 }
80
81 if (!status.ok()) {
82 FillErrorResponse(status, response.mutable_error_response());
83 }
84 response.set_valid_host(request.host());
85 response.set_allocated_original_request(
86 new ServerReflectionRequest(request));
87 stream->Write(response);
88 }
89
90 return Status::OK;
91 }
92
FillErrorResponse(const Status & status,ErrorResponse * error_response)93 void ProtoServerReflection::FillErrorResponse(const Status& status,
94 ErrorResponse* error_response) {
95 error_response->set_error_code(status.error_code());
96 error_response->set_error_message(status.error_message());
97 }
98
ListService(ServerContext *,ListServiceResponse * response)99 Status ProtoServerReflection::ListService(ServerContext* /*context*/,
100 ListServiceResponse* response) {
101 if (services_ == nullptr) {
102 return Status(StatusCode::NOT_FOUND, "Services not found.");
103 }
104 for (const auto& value : *services_) {
105 ServiceResponse* service_response = response->add_service();
106 service_response->set_name(value);
107 }
108 return Status::OK;
109 }
110
GetFileByName(ServerContext *,const std::string & file_name,ServerReflectionResponse * response)111 Status ProtoServerReflection::GetFileByName(
112 ServerContext* /*context*/, const std::string& file_name,
113 ServerReflectionResponse* response) {
114 if (descriptor_pool_ == nullptr) {
115 return Status::CANCELLED;
116 }
117
118 const protobuf::FileDescriptor* file_desc =
119 descriptor_pool_->FindFileByName(file_name);
120 if (file_desc == nullptr) {
121 return Status(StatusCode::NOT_FOUND, "File not found.");
122 }
123 std::unordered_set<std::string> seen_files;
124 FillFileDescriptorResponse(file_desc, response, &seen_files);
125 return Status::OK;
126 }
127
GetFileContainingSymbol(ServerContext *,const std::string & symbol,ServerReflectionResponse * response)128 Status ProtoServerReflection::GetFileContainingSymbol(
129 ServerContext* /*context*/, const std::string& symbol,
130 ServerReflectionResponse* response) {
131 if (descriptor_pool_ == nullptr) {
132 return Status::CANCELLED;
133 }
134
135 const protobuf::FileDescriptor* file_desc =
136 descriptor_pool_->FindFileContainingSymbol(symbol);
137 if (file_desc == nullptr) {
138 return Status(StatusCode::NOT_FOUND, "Symbol not found.");
139 }
140 std::unordered_set<std::string> seen_files;
141 FillFileDescriptorResponse(file_desc, response, &seen_files);
142 return Status::OK;
143 }
144
GetFileContainingExtension(ServerContext *,const ExtensionRequest * request,ServerReflectionResponse * response)145 Status ProtoServerReflection::GetFileContainingExtension(
146 ServerContext* /*context*/, const ExtensionRequest* request,
147 ServerReflectionResponse* response) {
148 if (descriptor_pool_ == nullptr) {
149 return Status::CANCELLED;
150 }
151
152 const protobuf::Descriptor* desc =
153 descriptor_pool_->FindMessageTypeByName(request->containing_type());
154 if (desc == nullptr) {
155 return Status(StatusCode::NOT_FOUND, "Type not found.");
156 }
157
158 const protobuf::FieldDescriptor* field_desc =
159 descriptor_pool_->FindExtensionByNumber(desc,
160 request->extension_number());
161 if (field_desc == nullptr) {
162 return Status(StatusCode::NOT_FOUND, "Extension not found.");
163 }
164 std::unordered_set<std::string> seen_files;
165 FillFileDescriptorResponse(field_desc->file(), response, &seen_files);
166 return Status::OK;
167 }
168
GetAllExtensionNumbers(ServerContext *,const std::string & type,ExtensionNumberResponse * response)169 Status ProtoServerReflection::GetAllExtensionNumbers(
170 ServerContext* /*context*/, const std::string& type,
171 ExtensionNumberResponse* response) {
172 if (descriptor_pool_ == nullptr) {
173 return Status::CANCELLED;
174 }
175
176 const protobuf::Descriptor* desc =
177 descriptor_pool_->FindMessageTypeByName(type);
178 if (desc == nullptr) {
179 return Status(StatusCode::NOT_FOUND, "Type not found.");
180 }
181
182 std::vector<const protobuf::FieldDescriptor*> extensions;
183 descriptor_pool_->FindAllExtensions(desc, &extensions);
184 for (const auto& value : extensions) {
185 response->add_extension_number(value->number());
186 }
187 response->set_base_type_name(type);
188 return Status::OK;
189 }
190
FillFileDescriptorResponse(const protobuf::FileDescriptor * file_desc,ServerReflectionResponse * response,std::unordered_set<std::string> * seen_files)191 void ProtoServerReflection::FillFileDescriptorResponse(
192 const protobuf::FileDescriptor* file_desc,
193 ServerReflectionResponse* response,
194 std::unordered_set<std::string>* seen_files) {
195 if (seen_files->find(file_desc->name()) != seen_files->end()) {
196 return;
197 }
198 seen_files->insert(file_desc->name());
199
200 protobuf::FileDescriptorProto file_desc_proto;
201 std::string data;
202 file_desc->CopyTo(&file_desc_proto);
203 file_desc_proto.SerializeToString(&data);
204 response->mutable_file_descriptor_response()->add_file_descriptor_proto(data);
205
206 for (int i = 0; i < file_desc->dependency_count(); ++i) {
207 FillFileDescriptorResponse(file_desc->dependency(i), response, seen_files);
208 }
209 }
210
211 } // namespace grpc
212