• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 //
3 // Copyright 2016 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
#include "test/cpp/util/proto_reflection_descriptor_database.h"

#include <mutex>
#include <vector>

#include "absl/log/log.h"
#include "src/core/util/crash.h"
25 
26 using grpc::reflection::v1alpha::ErrorResponse;
27 using grpc::reflection::v1alpha::ListServiceResponse;
28 using grpc::reflection::v1alpha::ServerReflection;
29 using grpc::reflection::v1alpha::ServerReflectionRequest;
30 using grpc::reflection::v1alpha::ServerReflectionResponse;
31 
32 namespace grpc {
33 
ProtoReflectionDescriptorDatabase(std::unique_ptr<ServerReflection::Stub> stub)34 ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase(
35     std::unique_ptr<ServerReflection::Stub> stub)
36     : stub_(std::move(stub)) {}
37 
ProtoReflectionDescriptorDatabase(const std::shared_ptr<grpc::ChannelInterface> & channel)38 ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase(
39     const std::shared_ptr<grpc::ChannelInterface>& channel)
40     : stub_(ServerReflection::NewStub(channel)) {}
41 
~ProtoReflectionDescriptorDatabase()42 ProtoReflectionDescriptorDatabase::~ProtoReflectionDescriptorDatabase() {
43   if (stream_) {
44     stream_->WritesDone();
45     Status status = stream_->Finish();
46     if (!status.ok()) {
47       if (status.error_code() == StatusCode::UNIMPLEMENTED) {
48         fprintf(stderr,
49                 "Reflection request not implemented; "
50                 "is the ServerReflection service enabled?\n");
51       } else {
52         fprintf(stderr,
53                 "ServerReflectionInfo rpc failed. Error code: %d, message: %s, "
54                 "debug info: %s\n",
55                 static_cast<int>(status.error_code()),
56                 status.error_message().c_str(),
57                 ctx_.debug_error_string().c_str());
58       }
59     }
60   }
61 }
62 
FindFileByName(const string & filename,protobuf::FileDescriptorProto * output)63 bool ProtoReflectionDescriptorDatabase::FindFileByName(
64     const string& filename, protobuf::FileDescriptorProto* output) {
65   if (cached_db_.FindFileByName(filename, output)) {
66     return true;
67   }
68 
69   if (known_files_.find(filename) != known_files_.end()) {
70     return false;
71   }
72 
73   ServerReflectionRequest request;
74   request.set_file_by_filename(filename);
75   ServerReflectionResponse response;
76 
77   if (!DoOneRequest(request, response)) {
78     return false;
79   }
80 
81   if (response.message_response_case() ==
82       ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
83     AddFileFromResponse(response.file_descriptor_response());
84   } else if (response.message_response_case() ==
85              ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
86     const ErrorResponse& error = response.error_response();
87     if (error.error_code() == StatusCode::NOT_FOUND) {
88       LOG(INFO) << "NOT_FOUND from server for FindFileByName(" << filename
89                 << ")";
90     } else {
91       LOG(INFO) << "Error on FindFileByName(" << filename
92                 << ")\n\tError code: " << error.error_code()
93                 << "\n\tError Message: " << error.error_message();
94     }
95   } else {
96     LOG(INFO) << "Error on FindFileByName(" << filename
97               << ") response type\n\tExpecting: "
98               << ServerReflectionResponse::MessageResponseCase::
99                      kFileDescriptorResponse
100               << "\n\tReceived: " << response.message_response_case();
101   }
102 
103   return cached_db_.FindFileByName(filename, output);
104 }
105 
FindFileContainingSymbol(const string & symbol_name,protobuf::FileDescriptorProto * output)106 bool ProtoReflectionDescriptorDatabase::FindFileContainingSymbol(
107     const string& symbol_name, protobuf::FileDescriptorProto* output) {
108   if (cached_db_.FindFileContainingSymbol(symbol_name, output)) {
109     return true;
110   }
111 
112   if (missing_symbols_.find(symbol_name) != missing_symbols_.end()) {
113     return false;
114   }
115 
116   ServerReflectionRequest request;
117   request.set_file_containing_symbol(symbol_name);
118   ServerReflectionResponse response;
119 
120   if (!DoOneRequest(request, response)) {
121     return false;
122   }
123 
124   if (response.message_response_case() ==
125       ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
126     AddFileFromResponse(response.file_descriptor_response());
127   } else if (response.message_response_case() ==
128              ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
129     const ErrorResponse& error = response.error_response();
130     if (error.error_code() == StatusCode::NOT_FOUND) {
131       missing_symbols_.insert(symbol_name);
132       LOG(INFO) << "NOT_FOUND from server for FindFileContainingSymbol("
133                 << symbol_name << ")";
134     } else {
135       LOG(INFO) << "Error on FindFileContainingSymbol(" << symbol_name
136                 << ")\n\tError code: " << error.error_code()
137                 << "\n\tError Message: " << error.error_message();
138     }
139   } else {
140     LOG(INFO) << "Error on FindFileContainingSymbol(" << symbol_name
141               << ") response type\n\tExpecting: "
142               << ServerReflectionResponse::MessageResponseCase::
143                      kFileDescriptorResponse
144               << "\n\tReceived: " << response.message_response_case();
145   }
146   return cached_db_.FindFileContainingSymbol(symbol_name, output);
147 }
148 
FindFileContainingExtension(const string & containing_type,int field_number,protobuf::FileDescriptorProto * output)149 bool ProtoReflectionDescriptorDatabase::FindFileContainingExtension(
150     const string& containing_type, int field_number,
151     protobuf::FileDescriptorProto* output) {
152   if (cached_db_.FindFileContainingExtension(containing_type, field_number,
153                                              output)) {
154     return true;
155   }
156 
157   if (missing_extensions_.find(containing_type) != missing_extensions_.end() &&
158       missing_extensions_[containing_type].find(field_number) !=
159           missing_extensions_[containing_type].end()) {
160     LOG(INFO) << "nested map.";
161     return false;
162   }
163 
164   ServerReflectionRequest request;
165   request.mutable_file_containing_extension()->set_containing_type(
166       containing_type);
167   request.mutable_file_containing_extension()->set_extension_number(
168       field_number);
169   ServerReflectionResponse response;
170 
171   if (!DoOneRequest(request, response)) {
172     return false;
173   }
174 
175   if (response.message_response_case() ==
176       ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
177     AddFileFromResponse(response.file_descriptor_response());
178   } else if (response.message_response_case() ==
179              ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
180     const ErrorResponse& error = response.error_response();
181     if (error.error_code() == StatusCode::NOT_FOUND) {
182       if (missing_extensions_.find(containing_type) ==
183           missing_extensions_.end()) {
184         missing_extensions_[containing_type] = {};
185       }
186       missing_extensions_[containing_type].insert(field_number);
187       LOG(INFO) << "NOT_FOUND from server for FindFileContainingExtension("
188                 << containing_type << ", " << field_number << ")";
189     } else {
190       LOG(INFO) << "Error on FindFileContainingExtension(" << containing_type
191                 << ", " << field_number
192                 << ")\n\tError code: " << error.error_code()
193                 << "\n\tError Message: " << error.error_message();
194     }
195   } else {
196     LOG(INFO) << "Error on FindFileContainingExtension(" << containing_type
197               << ", " << field_number << ") response type\n\tExpecting: "
198               << ServerReflectionResponse::MessageResponseCase::
199                      kFileDescriptorResponse
200               << "\n\tReceived: " << response.message_response_case();
201   }
202 
203   return cached_db_.FindFileContainingExtension(containing_type, field_number,
204                                                 output);
205 }
206 
FindAllExtensionNumbers(const string & extendee_type,std::vector<int> * output)207 bool ProtoReflectionDescriptorDatabase::FindAllExtensionNumbers(
208     const string& extendee_type, std::vector<int>* output) {
209   if (cached_extension_numbers_.find(extendee_type) !=
210       cached_extension_numbers_.end()) {
211     *output = cached_extension_numbers_[extendee_type];
212     return true;
213   }
214 
215   ServerReflectionRequest request;
216   request.set_all_extension_numbers_of_type(extendee_type);
217   ServerReflectionResponse response;
218 
219   if (!DoOneRequest(request, response)) {
220     return false;
221   }
222 
223   if (response.message_response_case() ==
224       ServerReflectionResponse::MessageResponseCase::
225           kAllExtensionNumbersResponse) {
226     auto number = response.all_extension_numbers_response().extension_number();
227     *output = std::vector<int>(number.begin(), number.end());
228     cached_extension_numbers_[extendee_type] = *output;
229     return true;
230   } else if (response.message_response_case() ==
231              ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
232     const ErrorResponse& error = response.error_response();
233     if (error.error_code() == StatusCode::NOT_FOUND) {
234       LOG(INFO) << "NOT_FOUND from server for FindAllExtensionNumbers("
235                 << extendee_type << ")";
236     } else {
237       LOG(INFO) << "Error on FindAllExtensionNumbersExtension(" << extendee_type
238                 << ")\n\tError code: " << error.error_code()
239                 << "\n\tError Message: " << error.error_message();
240     }
241   }
242   return false;
243 }
244 
GetServices(std::vector<std::string> * output)245 bool ProtoReflectionDescriptorDatabase::GetServices(
246     std::vector<std::string>* output) {
247   ServerReflectionRequest request;
248   request.set_list_services("");
249   ServerReflectionResponse response;
250 
251   if (!DoOneRequest(request, response)) {
252     return false;
253   }
254 
255   if (response.message_response_case() ==
256       ServerReflectionResponse::MessageResponseCase::kListServicesResponse) {
257     const ListServiceResponse& ls_response = response.list_services_response();
258     for (int i = 0; i < ls_response.service_size(); ++i) {
259       (*output).push_back(ls_response.service(i).name());
260     }
261     return true;
262   } else if (response.message_response_case() ==
263              ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
264     const ErrorResponse& error = response.error_response();
265     LOG(INFO) << "Error on GetServices()\n\tError code: " << error.error_code()
266               << "\n\tError Message: " << error.error_message();
267   } else {
268     LOG(INFO)
269         << "Error on GetServices() response type\n\tExpecting: "
270         << ServerReflectionResponse::MessageResponseCase::kListServicesResponse
271         << "\n\tReceived: " << response.message_response_case();
272   }
273   return false;
274 }
275 
276 protobuf::FileDescriptorProto
ParseFileDescriptorProtoResponse(const std::string & byte_fd_proto)277 ProtoReflectionDescriptorDatabase::ParseFileDescriptorProtoResponse(
278     const std::string& byte_fd_proto) {
279   protobuf::FileDescriptorProto file_desc_proto;
280   file_desc_proto.ParseFromString(byte_fd_proto);
281   return file_desc_proto;
282 }
283 
AddFileFromResponse(const grpc::reflection::v1alpha::FileDescriptorResponse & response)284 void ProtoReflectionDescriptorDatabase::AddFileFromResponse(
285     const grpc::reflection::v1alpha::FileDescriptorResponse& response) {
286   for (int i = 0; i < response.file_descriptor_proto_size(); ++i) {
287     const protobuf::FileDescriptorProto file_proto =
288         ParseFileDescriptorProtoResponse(response.file_descriptor_proto(i));
289     if (known_files_.find(file_proto.name()) == known_files_.end()) {
290       known_files_.insert(file_proto.name());
291       cached_db_.Add(file_proto);
292     }
293   }
294 }
295 
296 std::shared_ptr<ProtoReflectionDescriptorDatabase::ClientStream>
GetStream()297 ProtoReflectionDescriptorDatabase::GetStream() {
298   if (!stream_) {
299     stream_ = stub_->ServerReflectionInfo(&ctx_);
300   }
301   return stream_;
302 }
303 
DoOneRequest(const ServerReflectionRequest & request,ServerReflectionResponse & response)304 bool ProtoReflectionDescriptorDatabase::DoOneRequest(
305     const ServerReflectionRequest& request,
306     ServerReflectionResponse& response) {
307   bool success = false;
308   stream_mutex_.lock();
309   if (GetStream()->Write(request) && GetStream()->Read(&response)) {
310     success = true;
311   }
312   stream_mutex_.unlock();
313   return success;
314 }
315 
316 }  // namespace grpc
317