1 //
2 //
3 // Copyright 2016 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
#include "test/cpp/util/proto_reflection_descriptor_database.h"

#include <mutex>
#include <vector>

#include "absl/log/log.h"
#include "src/core/util/crash.h"
25
26 using grpc::reflection::v1alpha::ErrorResponse;
27 using grpc::reflection::v1alpha::ListServiceResponse;
28 using grpc::reflection::v1alpha::ServerReflection;
29 using grpc::reflection::v1alpha::ServerReflectionRequest;
30 using grpc::reflection::v1alpha::ServerReflectionResponse;
31
32 namespace grpc {
33
ProtoReflectionDescriptorDatabase(std::unique_ptr<ServerReflection::Stub> stub)34 ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase(
35 std::unique_ptr<ServerReflection::Stub> stub)
36 : stub_(std::move(stub)) {}
37
ProtoReflectionDescriptorDatabase(const std::shared_ptr<grpc::ChannelInterface> & channel)38 ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase(
39 const std::shared_ptr<grpc::ChannelInterface>& channel)
40 : stub_(ServerReflection::NewStub(channel)) {}
41
~ProtoReflectionDescriptorDatabase()42 ProtoReflectionDescriptorDatabase::~ProtoReflectionDescriptorDatabase() {
43 if (stream_) {
44 stream_->WritesDone();
45 Status status = stream_->Finish();
46 if (!status.ok()) {
47 if (status.error_code() == StatusCode::UNIMPLEMENTED) {
48 fprintf(stderr,
49 "Reflection request not implemented; "
50 "is the ServerReflection service enabled?\n");
51 } else {
52 fprintf(stderr,
53 "ServerReflectionInfo rpc failed. Error code: %d, message: %s, "
54 "debug info: %s\n",
55 static_cast<int>(status.error_code()),
56 status.error_message().c_str(),
57 ctx_.debug_error_string().c_str());
58 }
59 }
60 }
61 }
62
FindFileByName(const string & filename,protobuf::FileDescriptorProto * output)63 bool ProtoReflectionDescriptorDatabase::FindFileByName(
64 const string& filename, protobuf::FileDescriptorProto* output) {
65 if (cached_db_.FindFileByName(filename, output)) {
66 return true;
67 }
68
69 if (known_files_.find(filename) != known_files_.end()) {
70 return false;
71 }
72
73 ServerReflectionRequest request;
74 request.set_file_by_filename(filename);
75 ServerReflectionResponse response;
76
77 if (!DoOneRequest(request, response)) {
78 return false;
79 }
80
81 if (response.message_response_case() ==
82 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
83 AddFileFromResponse(response.file_descriptor_response());
84 } else if (response.message_response_case() ==
85 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
86 const ErrorResponse& error = response.error_response();
87 if (error.error_code() == StatusCode::NOT_FOUND) {
88 LOG(INFO) << "NOT_FOUND from server for FindFileByName(" << filename
89 << ")";
90 } else {
91 LOG(INFO) << "Error on FindFileByName(" << filename
92 << ")\n\tError code: " << error.error_code()
93 << "\n\tError Message: " << error.error_message();
94 }
95 } else {
96 LOG(INFO) << "Error on FindFileByName(" << filename
97 << ") response type\n\tExpecting: "
98 << ServerReflectionResponse::MessageResponseCase::
99 kFileDescriptorResponse
100 << "\n\tReceived: " << response.message_response_case();
101 }
102
103 return cached_db_.FindFileByName(filename, output);
104 }
105
FindFileContainingSymbol(const string & symbol_name,protobuf::FileDescriptorProto * output)106 bool ProtoReflectionDescriptorDatabase::FindFileContainingSymbol(
107 const string& symbol_name, protobuf::FileDescriptorProto* output) {
108 if (cached_db_.FindFileContainingSymbol(symbol_name, output)) {
109 return true;
110 }
111
112 if (missing_symbols_.find(symbol_name) != missing_symbols_.end()) {
113 return false;
114 }
115
116 ServerReflectionRequest request;
117 request.set_file_containing_symbol(symbol_name);
118 ServerReflectionResponse response;
119
120 if (!DoOneRequest(request, response)) {
121 return false;
122 }
123
124 if (response.message_response_case() ==
125 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
126 AddFileFromResponse(response.file_descriptor_response());
127 } else if (response.message_response_case() ==
128 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
129 const ErrorResponse& error = response.error_response();
130 if (error.error_code() == StatusCode::NOT_FOUND) {
131 missing_symbols_.insert(symbol_name);
132 LOG(INFO) << "NOT_FOUND from server for FindFileContainingSymbol("
133 << symbol_name << ")";
134 } else {
135 LOG(INFO) << "Error on FindFileContainingSymbol(" << symbol_name
136 << ")\n\tError code: " << error.error_code()
137 << "\n\tError Message: " << error.error_message();
138 }
139 } else {
140 LOG(INFO) << "Error on FindFileContainingSymbol(" << symbol_name
141 << ") response type\n\tExpecting: "
142 << ServerReflectionResponse::MessageResponseCase::
143 kFileDescriptorResponse
144 << "\n\tReceived: " << response.message_response_case();
145 }
146 return cached_db_.FindFileContainingSymbol(symbol_name, output);
147 }
148
FindFileContainingExtension(const string & containing_type,int field_number,protobuf::FileDescriptorProto * output)149 bool ProtoReflectionDescriptorDatabase::FindFileContainingExtension(
150 const string& containing_type, int field_number,
151 protobuf::FileDescriptorProto* output) {
152 if (cached_db_.FindFileContainingExtension(containing_type, field_number,
153 output)) {
154 return true;
155 }
156
157 if (missing_extensions_.find(containing_type) != missing_extensions_.end() &&
158 missing_extensions_[containing_type].find(field_number) !=
159 missing_extensions_[containing_type].end()) {
160 LOG(INFO) << "nested map.";
161 return false;
162 }
163
164 ServerReflectionRequest request;
165 request.mutable_file_containing_extension()->set_containing_type(
166 containing_type);
167 request.mutable_file_containing_extension()->set_extension_number(
168 field_number);
169 ServerReflectionResponse response;
170
171 if (!DoOneRequest(request, response)) {
172 return false;
173 }
174
175 if (response.message_response_case() ==
176 ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
177 AddFileFromResponse(response.file_descriptor_response());
178 } else if (response.message_response_case() ==
179 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
180 const ErrorResponse& error = response.error_response();
181 if (error.error_code() == StatusCode::NOT_FOUND) {
182 if (missing_extensions_.find(containing_type) ==
183 missing_extensions_.end()) {
184 missing_extensions_[containing_type] = {};
185 }
186 missing_extensions_[containing_type].insert(field_number);
187 LOG(INFO) << "NOT_FOUND from server for FindFileContainingExtension("
188 << containing_type << ", " << field_number << ")";
189 } else {
190 LOG(INFO) << "Error on FindFileContainingExtension(" << containing_type
191 << ", " << field_number
192 << ")\n\tError code: " << error.error_code()
193 << "\n\tError Message: " << error.error_message();
194 }
195 } else {
196 LOG(INFO) << "Error on FindFileContainingExtension(" << containing_type
197 << ", " << field_number << ") response type\n\tExpecting: "
198 << ServerReflectionResponse::MessageResponseCase::
199 kFileDescriptorResponse
200 << "\n\tReceived: " << response.message_response_case();
201 }
202
203 return cached_db_.FindFileContainingExtension(containing_type, field_number,
204 output);
205 }
206
FindAllExtensionNumbers(const string & extendee_type,std::vector<int> * output)207 bool ProtoReflectionDescriptorDatabase::FindAllExtensionNumbers(
208 const string& extendee_type, std::vector<int>* output) {
209 if (cached_extension_numbers_.find(extendee_type) !=
210 cached_extension_numbers_.end()) {
211 *output = cached_extension_numbers_[extendee_type];
212 return true;
213 }
214
215 ServerReflectionRequest request;
216 request.set_all_extension_numbers_of_type(extendee_type);
217 ServerReflectionResponse response;
218
219 if (!DoOneRequest(request, response)) {
220 return false;
221 }
222
223 if (response.message_response_case() ==
224 ServerReflectionResponse::MessageResponseCase::
225 kAllExtensionNumbersResponse) {
226 auto number = response.all_extension_numbers_response().extension_number();
227 *output = std::vector<int>(number.begin(), number.end());
228 cached_extension_numbers_[extendee_type] = *output;
229 return true;
230 } else if (response.message_response_case() ==
231 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
232 const ErrorResponse& error = response.error_response();
233 if (error.error_code() == StatusCode::NOT_FOUND) {
234 LOG(INFO) << "NOT_FOUND from server for FindAllExtensionNumbers("
235 << extendee_type << ")";
236 } else {
237 LOG(INFO) << "Error on FindAllExtensionNumbersExtension(" << extendee_type
238 << ")\n\tError code: " << error.error_code()
239 << "\n\tError Message: " << error.error_message();
240 }
241 }
242 return false;
243 }
244
GetServices(std::vector<std::string> * output)245 bool ProtoReflectionDescriptorDatabase::GetServices(
246 std::vector<std::string>* output) {
247 ServerReflectionRequest request;
248 request.set_list_services("");
249 ServerReflectionResponse response;
250
251 if (!DoOneRequest(request, response)) {
252 return false;
253 }
254
255 if (response.message_response_case() ==
256 ServerReflectionResponse::MessageResponseCase::kListServicesResponse) {
257 const ListServiceResponse& ls_response = response.list_services_response();
258 for (int i = 0; i < ls_response.service_size(); ++i) {
259 (*output).push_back(ls_response.service(i).name());
260 }
261 return true;
262 } else if (response.message_response_case() ==
263 ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
264 const ErrorResponse& error = response.error_response();
265 LOG(INFO) << "Error on GetServices()\n\tError code: " << error.error_code()
266 << "\n\tError Message: " << error.error_message();
267 } else {
268 LOG(INFO)
269 << "Error on GetServices() response type\n\tExpecting: "
270 << ServerReflectionResponse::MessageResponseCase::kListServicesResponse
271 << "\n\tReceived: " << response.message_response_case();
272 }
273 return false;
274 }
275
276 protobuf::FileDescriptorProto
ParseFileDescriptorProtoResponse(const std::string & byte_fd_proto)277 ProtoReflectionDescriptorDatabase::ParseFileDescriptorProtoResponse(
278 const std::string& byte_fd_proto) {
279 protobuf::FileDescriptorProto file_desc_proto;
280 file_desc_proto.ParseFromString(byte_fd_proto);
281 return file_desc_proto;
282 }
283
AddFileFromResponse(const grpc::reflection::v1alpha::FileDescriptorResponse & response)284 void ProtoReflectionDescriptorDatabase::AddFileFromResponse(
285 const grpc::reflection::v1alpha::FileDescriptorResponse& response) {
286 for (int i = 0; i < response.file_descriptor_proto_size(); ++i) {
287 const protobuf::FileDescriptorProto file_proto =
288 ParseFileDescriptorProtoResponse(response.file_descriptor_proto(i));
289 if (known_files_.find(file_proto.name()) == known_files_.end()) {
290 known_files_.insert(file_proto.name());
291 cached_db_.Add(file_proto);
292 }
293 }
294 }
295
296 std::shared_ptr<ProtoReflectionDescriptorDatabase::ClientStream>
GetStream()297 ProtoReflectionDescriptorDatabase::GetStream() {
298 if (!stream_) {
299 stream_ = stub_->ServerReflectionInfo(&ctx_);
300 }
301 return stream_;
302 }
303
DoOneRequest(const ServerReflectionRequest & request,ServerReflectionResponse & response)304 bool ProtoReflectionDescriptorDatabase::DoOneRequest(
305 const ServerReflectionRequest& request,
306 ServerReflectionResponse& response) {
307 bool success = false;
308 stream_mutex_.lock();
309 if (GetStream()->Write(request) && GetStream()->Read(&response)) {
310 success = true;
311 }
312 stream_mutex_.unlock();
313 return success;
314 }
315
316 } // namespace grpc
317