1 /*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
#include "test/cpp/util/proto_reflection_descriptor_database.h"

#include <mutex>
#include <vector>

#include <grpc/support/log.h>
24
25 using grpc::reflection::v1alpha::ErrorResponse;
26 using grpc::reflection::v1alpha::ListServiceResponse;
27 using grpc::reflection::v1alpha::ServerReflection;
28 using grpc::reflection::v1alpha::ServerReflectionRequest;
29 using grpc::reflection::v1alpha::ServerReflectionResponse;
30
31 namespace grpc {
32
ProtoReflectionDescriptorDatabase(std::unique_ptr<ServerReflection::Stub> stub)33 ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase(
34 std::unique_ptr<ServerReflection::Stub> stub)
35 : stub_(std::move(stub)) {}
36
ProtoReflectionDescriptorDatabase(const std::shared_ptr<grpc::Channel> & channel)37 ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase(
38 const std::shared_ptr<grpc::Channel>& channel)
39 : stub_(ServerReflection::NewStub(channel)) {}
40
~ProtoReflectionDescriptorDatabase()41 ProtoReflectionDescriptorDatabase::~ProtoReflectionDescriptorDatabase() {
42 if (stream_) {
43 stream_->WritesDone();
44 Status status = stream_->Finish();
45 if (!status.ok()) {
46 if (status.error_code() == StatusCode::UNIMPLEMENTED) {
47 fprintf(stderr,
48 "Reflection request not implemented; "
49 "is the ServerReflection service enabled?\n");
50 } else {
51 fprintf(stderr,
52 "ServerReflectionInfo rpc failed. Error code: %d, message: %s, "
53 "debug info: %s\n",
54 static_cast<int>(status.error_code()),
55 status.error_message().c_str(),
56 ctx_.debug_error_string().c_str());
57 }
58 }
59 }
60 }
61
FindFileByName(const string & filename,protobuf::FileDescriptorProto * output)62 bool ProtoReflectionDescriptorDatabase::FindFileByName(
63 const string& filename, protobuf::FileDescriptorProto* output) {
64 if (cached_db_.FindFileByName(filename, output)) {
65 return true;
66 }
67
68 if (known_files_.find(filename) != known_files_.end()) {
69 return false;
70 }
71
72 ServerReflectionRequest request;
73 request.set_file_by_filename(filename);
74 ServerReflectionResponse response;
75
76 if (!DoOneRequest(request, response)) {
77 return false;
78 }
79
80 if (response.message_response_case() ==
81 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
82 AddFileFromResponse(response.file_descriptor_response());
83 } else if (response.message_response_case() ==
84 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
85 const ErrorResponse& error = response.error_response();
86 if (error.error_code() == StatusCode::NOT_FOUND) {
87 gpr_log(GPR_INFO, "NOT_FOUND from server for FindFileByName(%s)",
88 filename.c_str());
89 } else {
90 gpr_log(GPR_INFO,
91 "Error on FindFileByName(%s)\n\tError code: %d\n"
92 "\tError Message: %s",
93 filename.c_str(), error.error_code(),
94 error.error_message().c_str());
95 }
96 } else {
97 gpr_log(
98 GPR_INFO,
99 "Error on FindFileByName(%s) response type\n"
100 "\tExpecting: %d\n\tReceived: %d",
101 filename.c_str(),
102 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse,
103 response.message_response_case());
104 }
105
106 return cached_db_.FindFileByName(filename, output);
107 }
108
FindFileContainingSymbol(const string & symbol_name,protobuf::FileDescriptorProto * output)109 bool ProtoReflectionDescriptorDatabase::FindFileContainingSymbol(
110 const string& symbol_name, protobuf::FileDescriptorProto* output) {
111 if (cached_db_.FindFileContainingSymbol(symbol_name, output)) {
112 return true;
113 }
114
115 if (missing_symbols_.find(symbol_name) != missing_symbols_.end()) {
116 return false;
117 }
118
119 ServerReflectionRequest request;
120 request.set_file_containing_symbol(symbol_name);
121 ServerReflectionResponse response;
122
123 if (!DoOneRequest(request, response)) {
124 return false;
125 }
126
127 if (response.message_response_case() ==
128 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
129 AddFileFromResponse(response.file_descriptor_response());
130 } else if (response.message_response_case() ==
131 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
132 const ErrorResponse& error = response.error_response();
133 if (error.error_code() == StatusCode::NOT_FOUND) {
134 missing_symbols_.insert(symbol_name);
135 gpr_log(GPR_INFO,
136 "NOT_FOUND from server for FindFileContainingSymbol(%s)",
137 symbol_name.c_str());
138 } else {
139 gpr_log(GPR_INFO,
140 "Error on FindFileContainingSymbol(%s)\n"
141 "\tError code: %d\n\tError Message: %s",
142 symbol_name.c_str(), error.error_code(),
143 error.error_message().c_str());
144 }
145 } else {
146 gpr_log(
147 GPR_INFO,
148 "Error on FindFileContainingSymbol(%s) response type\n"
149 "\tExpecting: %d\n\tReceived: %d",
150 symbol_name.c_str(),
151 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse,
152 response.message_response_case());
153 }
154 return cached_db_.FindFileContainingSymbol(symbol_name, output);
155 }
156
FindFileContainingExtension(const string & containing_type,int field_number,protobuf::FileDescriptorProto * output)157 bool ProtoReflectionDescriptorDatabase::FindFileContainingExtension(
158 const string& containing_type, int field_number,
159 protobuf::FileDescriptorProto* output) {
160 if (cached_db_.FindFileContainingExtension(containing_type, field_number,
161 output)) {
162 return true;
163 }
164
165 if (missing_extensions_.find(containing_type) != missing_extensions_.end() &&
166 missing_extensions_[containing_type].find(field_number) !=
167 missing_extensions_[containing_type].end()) {
168 gpr_log(GPR_INFO, "nested map.");
169 return false;
170 }
171
172 ServerReflectionRequest request;
173 request.mutable_file_containing_extension()->set_containing_type(
174 containing_type);
175 request.mutable_file_containing_extension()->set_extension_number(
176 field_number);
177 ServerReflectionResponse response;
178
179 if (!DoOneRequest(request, response)) {
180 return false;
181 }
182
183 if (response.message_response_case() ==
184 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
185 AddFileFromResponse(response.file_descriptor_response());
186 } else if (response.message_response_case() ==
187 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
188 const ErrorResponse& error = response.error_response();
189 if (error.error_code() == StatusCode::NOT_FOUND) {
190 if (missing_extensions_.find(containing_type) ==
191 missing_extensions_.end()) {
192 missing_extensions_[containing_type] = {};
193 }
194 missing_extensions_[containing_type].insert(field_number);
195 gpr_log(GPR_INFO,
196 "NOT_FOUND from server for FindFileContainingExtension(%s, %d)",
197 containing_type.c_str(), field_number);
198 } else {
199 gpr_log(GPR_INFO,
200 "Error on FindFileContainingExtension(%s, %d)\n"
201 "\tError code: %d\n\tError Message: %s",
202 containing_type.c_str(), field_number, error.error_code(),
203 error.error_message().c_str());
204 }
205 } else {
206 gpr_log(
207 GPR_INFO,
208 "Error on FindFileContainingExtension(%s, %d) response type\n"
209 "\tExpecting: %d\n\tReceived: %d",
210 containing_type.c_str(), field_number,
211 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse,
212 response.message_response_case());
213 }
214
215 return cached_db_.FindFileContainingExtension(containing_type, field_number,
216 output);
217 }
218
FindAllExtensionNumbers(const string & extendee_type,std::vector<int> * output)219 bool ProtoReflectionDescriptorDatabase::FindAllExtensionNumbers(
220 const string& extendee_type, std::vector<int>* output) {
221 if (cached_extension_numbers_.find(extendee_type) !=
222 cached_extension_numbers_.end()) {
223 *output = cached_extension_numbers_[extendee_type];
224 return true;
225 }
226
227 ServerReflectionRequest request;
228 request.set_all_extension_numbers_of_type(extendee_type);
229 ServerReflectionResponse response;
230
231 if (!DoOneRequest(request, response)) {
232 return false;
233 }
234
235 if (response.message_response_case() ==
236 ServerReflectionResponse::MessageResponseCase::
237 kAllExtensionNumbersResponse) {
238 auto number = response.all_extension_numbers_response().extension_number();
239 *output = std::vector<int>(number.begin(), number.end());
240 cached_extension_numbers_[extendee_type] = *output;
241 return true;
242 } else if (response.message_response_case() ==
243 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
244 const ErrorResponse& error = response.error_response();
245 if (error.error_code() == StatusCode::NOT_FOUND) {
246 gpr_log(GPR_INFO, "NOT_FOUND from server for FindAllExtensionNumbers(%s)",
247 extendee_type.c_str());
248 } else {
249 gpr_log(GPR_INFO,
250 "Error on FindAllExtensionNumbersExtension(%s)\n"
251 "\tError code: %d\n\tError Message: %s",
252 extendee_type.c_str(), error.error_code(),
253 error.error_message().c_str());
254 }
255 }
256 return false;
257 }
258
GetServices(std::vector<std::string> * output)259 bool ProtoReflectionDescriptorDatabase::GetServices(
260 std::vector<std::string>* output) {
261 ServerReflectionRequest request;
262 request.set_list_services("");
263 ServerReflectionResponse response;
264
265 if (!DoOneRequest(request, response)) {
266 return false;
267 }
268
269 if (response.message_response_case() ==
270 ServerReflectionResponse::MessageResponseCase::kListServicesResponse) {
271 const ListServiceResponse& ls_response = response.list_services_response();
272 for (int i = 0; i < ls_response.service_size(); ++i) {
273 (*output).push_back(ls_response.service(i).name());
274 }
275 return true;
276 } else if (response.message_response_case() ==
277 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
278 const ErrorResponse& error = response.error_response();
279 gpr_log(GPR_INFO,
280 "Error on GetServices()\n\tError code: %d\n"
281 "\tError Message: %s",
282 error.error_code(), error.error_message().c_str());
283 } else {
284 gpr_log(
285 GPR_INFO,
286 "Error on GetServices() response type\n\tExpecting: %d\n\tReceived: %d",
287 ServerReflectionResponse::MessageResponseCase::kListServicesResponse,
288 response.message_response_case());
289 }
290 return false;
291 }
292
293 const protobuf::FileDescriptorProto
ParseFileDescriptorProtoResponse(const std::string & byte_fd_proto)294 ProtoReflectionDescriptorDatabase::ParseFileDescriptorProtoResponse(
295 const std::string& byte_fd_proto) {
296 protobuf::FileDescriptorProto file_desc_proto;
297 file_desc_proto.ParseFromString(byte_fd_proto);
298 return file_desc_proto;
299 }
300
AddFileFromResponse(const grpc::reflection::v1alpha::FileDescriptorResponse & response)301 void ProtoReflectionDescriptorDatabase::AddFileFromResponse(
302 const grpc::reflection::v1alpha::FileDescriptorResponse& response) {
303 for (int i = 0; i < response.file_descriptor_proto_size(); ++i) {
304 const protobuf::FileDescriptorProto file_proto =
305 ParseFileDescriptorProtoResponse(response.file_descriptor_proto(i));
306 if (known_files_.find(file_proto.name()) == known_files_.end()) {
307 known_files_.insert(file_proto.name());
308 cached_db_.Add(file_proto);
309 }
310 }
311 }
312
313 const std::shared_ptr<ProtoReflectionDescriptorDatabase::ClientStream>
GetStream()314 ProtoReflectionDescriptorDatabase::GetStream() {
315 if (!stream_) {
316 stream_ = stub_->ServerReflectionInfo(&ctx_);
317 }
318 return stream_;
319 }
320
DoOneRequest(const ServerReflectionRequest & request,ServerReflectionResponse & response)321 bool ProtoReflectionDescriptorDatabase::DoOneRequest(
322 const ServerReflectionRequest& request,
323 ServerReflectionResponse& response) {
324 bool success = false;
325 stream_mutex_.lock();
326 if (GetStream()->Write(request) && GetStream()->Read(&response)) {
327 success = true;
328 }
329 stream_mutex_.unlock();
330 return success;
331 }
332
333 } // namespace grpc
334