1 /*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
#include "test/cpp/util/proto_reflection_descriptor_database.h"

#include <mutex>
#include <vector>

#include <grpc/support/log.h>
24
25 using grpc::reflection::v1alpha::ErrorResponse;
26 using grpc::reflection::v1alpha::ListServiceResponse;
27 using grpc::reflection::v1alpha::ServerReflection;
28 using grpc::reflection::v1alpha::ServerReflectionRequest;
29 using grpc::reflection::v1alpha::ServerReflectionResponse;
30
31 namespace grpc {
32
ProtoReflectionDescriptorDatabase(std::unique_ptr<ServerReflection::Stub> stub)33 ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase(
34 std::unique_ptr<ServerReflection::Stub> stub)
35 : stub_(std::move(stub)) {}
36
ProtoReflectionDescriptorDatabase(const std::shared_ptr<grpc::Channel> & channel)37 ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase(
38 const std::shared_ptr<grpc::Channel>& channel)
39 : stub_(ServerReflection::NewStub(channel)) {}
40
~ProtoReflectionDescriptorDatabase()41 ProtoReflectionDescriptorDatabase::~ProtoReflectionDescriptorDatabase() {
42 if (stream_) {
43 stream_->WritesDone();
44 Status status = stream_->Finish();
45 if (!status.ok()) {
46 if (status.error_code() == StatusCode::UNIMPLEMENTED) {
47 gpr_log(GPR_INFO,
48 "Reflection request not implemented; "
49 "is the ServerReflection service enabled?");
50 }
51 gpr_log(GPR_INFO,
52 "ServerReflectionInfo rpc failed. Error code: %d, details: %s",
53 static_cast<int>(status.error_code()),
54 status.error_message().c_str());
55 }
56 }
57 }
58
FindFileByName(const string & filename,protobuf::FileDescriptorProto * output)59 bool ProtoReflectionDescriptorDatabase::FindFileByName(
60 const string& filename, protobuf::FileDescriptorProto* output) {
61 if (cached_db_.FindFileByName(filename, output)) {
62 return true;
63 }
64
65 if (known_files_.find(filename) != known_files_.end()) {
66 return false;
67 }
68
69 ServerReflectionRequest request;
70 request.set_file_by_filename(filename);
71 ServerReflectionResponse response;
72
73 if (!DoOneRequest(request, response)) {
74 return false;
75 }
76
77 if (response.message_response_case() ==
78 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
79 AddFileFromResponse(response.file_descriptor_response());
80 } else if (response.message_response_case() ==
81 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
82 const ErrorResponse& error = response.error_response();
83 if (error.error_code() == StatusCode::NOT_FOUND) {
84 gpr_log(GPR_INFO, "NOT_FOUND from server for FindFileByName(%s)",
85 filename.c_str());
86 } else {
87 gpr_log(GPR_INFO,
88 "Error on FindFileByName(%s)\n\tError code: %d\n"
89 "\tError Message: %s",
90 filename.c_str(), error.error_code(),
91 error.error_message().c_str());
92 }
93 } else {
94 gpr_log(
95 GPR_INFO,
96 "Error on FindFileByName(%s) response type\n"
97 "\tExpecting: %d\n\tReceived: %d",
98 filename.c_str(),
99 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse,
100 response.message_response_case());
101 }
102
103 return cached_db_.FindFileByName(filename, output);
104 }
105
FindFileContainingSymbol(const string & symbol_name,protobuf::FileDescriptorProto * output)106 bool ProtoReflectionDescriptorDatabase::FindFileContainingSymbol(
107 const string& symbol_name, protobuf::FileDescriptorProto* output) {
108 if (cached_db_.FindFileContainingSymbol(symbol_name, output)) {
109 return true;
110 }
111
112 if (missing_symbols_.find(symbol_name) != missing_symbols_.end()) {
113 return false;
114 }
115
116 ServerReflectionRequest request;
117 request.set_file_containing_symbol(symbol_name);
118 ServerReflectionResponse response;
119
120 if (!DoOneRequest(request, response)) {
121 return false;
122 }
123
124 if (response.message_response_case() ==
125 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
126 AddFileFromResponse(response.file_descriptor_response());
127 } else if (response.message_response_case() ==
128 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
129 const ErrorResponse& error = response.error_response();
130 if (error.error_code() == StatusCode::NOT_FOUND) {
131 missing_symbols_.insert(symbol_name);
132 gpr_log(GPR_INFO,
133 "NOT_FOUND from server for FindFileContainingSymbol(%s)",
134 symbol_name.c_str());
135 } else {
136 gpr_log(GPR_INFO,
137 "Error on FindFileContainingSymbol(%s)\n"
138 "\tError code: %d\n\tError Message: %s",
139 symbol_name.c_str(), error.error_code(),
140 error.error_message().c_str());
141 }
142 } else {
143 gpr_log(
144 GPR_INFO,
145 "Error on FindFileContainingSymbol(%s) response type\n"
146 "\tExpecting: %d\n\tReceived: %d",
147 symbol_name.c_str(),
148 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse,
149 response.message_response_case());
150 }
151 return cached_db_.FindFileContainingSymbol(symbol_name, output);
152 }
153
FindFileContainingExtension(const string & containing_type,int field_number,protobuf::FileDescriptorProto * output)154 bool ProtoReflectionDescriptorDatabase::FindFileContainingExtension(
155 const string& containing_type, int field_number,
156 protobuf::FileDescriptorProto* output) {
157 if (cached_db_.FindFileContainingExtension(containing_type, field_number,
158 output)) {
159 return true;
160 }
161
162 if (missing_extensions_.find(containing_type) != missing_extensions_.end() &&
163 missing_extensions_[containing_type].find(field_number) !=
164 missing_extensions_[containing_type].end()) {
165 gpr_log(GPR_INFO, "nested map.");
166 return false;
167 }
168
169 ServerReflectionRequest request;
170 request.mutable_file_containing_extension()->set_containing_type(
171 containing_type);
172 request.mutable_file_containing_extension()->set_extension_number(
173 field_number);
174 ServerReflectionResponse response;
175
176 if (!DoOneRequest(request, response)) {
177 return false;
178 }
179
180 if (response.message_response_case() ==
181 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
182 AddFileFromResponse(response.file_descriptor_response());
183 } else if (response.message_response_case() ==
184 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
185 const ErrorResponse& error = response.error_response();
186 if (error.error_code() == StatusCode::NOT_FOUND) {
187 if (missing_extensions_.find(containing_type) ==
188 missing_extensions_.end()) {
189 missing_extensions_[containing_type] = {};
190 }
191 missing_extensions_[containing_type].insert(field_number);
192 gpr_log(GPR_INFO,
193 "NOT_FOUND from server for FindFileContainingExtension(%s, %d)",
194 containing_type.c_str(), field_number);
195 } else {
196 gpr_log(GPR_INFO,
197 "Error on FindFileContainingExtension(%s, %d)\n"
198 "\tError code: %d\n\tError Message: %s",
199 containing_type.c_str(), field_number, error.error_code(),
200 error.error_message().c_str());
201 }
202 } else {
203 gpr_log(
204 GPR_INFO,
205 "Error on FindFileContainingExtension(%s, %d) response type\n"
206 "\tExpecting: %d\n\tReceived: %d",
207 containing_type.c_str(), field_number,
208 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse,
209 response.message_response_case());
210 }
211
212 return cached_db_.FindFileContainingExtension(containing_type, field_number,
213 output);
214 }
215
FindAllExtensionNumbers(const string & extendee_type,std::vector<int> * output)216 bool ProtoReflectionDescriptorDatabase::FindAllExtensionNumbers(
217 const string& extendee_type, std::vector<int>* output) {
218 if (cached_extension_numbers_.find(extendee_type) !=
219 cached_extension_numbers_.end()) {
220 *output = cached_extension_numbers_[extendee_type];
221 return true;
222 }
223
224 ServerReflectionRequest request;
225 request.set_all_extension_numbers_of_type(extendee_type);
226 ServerReflectionResponse response;
227
228 if (!DoOneRequest(request, response)) {
229 return false;
230 }
231
232 if (response.message_response_case() ==
233 ServerReflectionResponse::MessageResponseCase::
234 kAllExtensionNumbersResponse) {
235 auto number = response.all_extension_numbers_response().extension_number();
236 *output = std::vector<int>(number.begin(), number.end());
237 cached_extension_numbers_[extendee_type] = *output;
238 return true;
239 } else if (response.message_response_case() ==
240 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
241 const ErrorResponse& error = response.error_response();
242 if (error.error_code() == StatusCode::NOT_FOUND) {
243 gpr_log(GPR_INFO, "NOT_FOUND from server for FindAllExtensionNumbers(%s)",
244 extendee_type.c_str());
245 } else {
246 gpr_log(GPR_INFO,
247 "Error on FindAllExtensionNumbersExtension(%s)\n"
248 "\tError code: %d\n\tError Message: %s",
249 extendee_type.c_str(), error.error_code(),
250 error.error_message().c_str());
251 }
252 }
253 return false;
254 }
255
GetServices(std::vector<grpc::string> * output)256 bool ProtoReflectionDescriptorDatabase::GetServices(
257 std::vector<grpc::string>* output) {
258 ServerReflectionRequest request;
259 request.set_list_services("");
260 ServerReflectionResponse response;
261
262 if (!DoOneRequest(request, response)) {
263 return false;
264 }
265
266 if (response.message_response_case() ==
267 ServerReflectionResponse::MessageResponseCase::kListServicesResponse) {
268 const ListServiceResponse& ls_response = response.list_services_response();
269 for (int i = 0; i < ls_response.service_size(); ++i) {
270 (*output).push_back(ls_response.service(i).name());
271 }
272 return true;
273 } else if (response.message_response_case() ==
274 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
275 const ErrorResponse& error = response.error_response();
276 gpr_log(GPR_INFO,
277 "Error on GetServices()\n\tError code: %d\n"
278 "\tError Message: %s",
279 error.error_code(), error.error_message().c_str());
280 } else {
281 gpr_log(
282 GPR_INFO,
283 "Error on GetServices() response type\n\tExpecting: %d\n\tReceived: %d",
284 ServerReflectionResponse::MessageResponseCase::kListServicesResponse,
285 response.message_response_case());
286 }
287 return false;
288 }
289
290 const protobuf::FileDescriptorProto
ParseFileDescriptorProtoResponse(const grpc::string & byte_fd_proto)291 ProtoReflectionDescriptorDatabase::ParseFileDescriptorProtoResponse(
292 const grpc::string& byte_fd_proto) {
293 protobuf::FileDescriptorProto file_desc_proto;
294 file_desc_proto.ParseFromString(byte_fd_proto);
295 return file_desc_proto;
296 }
297
AddFileFromResponse(const grpc::reflection::v1alpha::FileDescriptorResponse & response)298 void ProtoReflectionDescriptorDatabase::AddFileFromResponse(
299 const grpc::reflection::v1alpha::FileDescriptorResponse& response) {
300 for (int i = 0; i < response.file_descriptor_proto_size(); ++i) {
301 const protobuf::FileDescriptorProto file_proto =
302 ParseFileDescriptorProtoResponse(response.file_descriptor_proto(i));
303 if (known_files_.find(file_proto.name()) == known_files_.end()) {
304 known_files_.insert(file_proto.name());
305 cached_db_.Add(file_proto);
306 }
307 }
308 }
309
310 const std::shared_ptr<ProtoReflectionDescriptorDatabase::ClientStream>
GetStream()311 ProtoReflectionDescriptorDatabase::GetStream() {
312 if (!stream_) {
313 stream_ = stub_->ServerReflectionInfo(&ctx_);
314 }
315 return stream_;
316 }
317
DoOneRequest(const ServerReflectionRequest & request,ServerReflectionResponse & response)318 bool ProtoReflectionDescriptorDatabase::DoOneRequest(
319 const ServerReflectionRequest& request,
320 ServerReflectionResponse& response) {
321 bool success = false;
322 stream_mutex_.lock();
323 if (GetStream()->Write(request) && GetStream()->Read(&response)) {
324 success = true;
325 }
326 stream_mutex_.unlock();
327 return success;
328 }
329
330 } // namespace grpc
331