1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #ifndef GRPC_TEST_CPP_END2END_INTERCEPTORS_UTIL_H
20 #define GRPC_TEST_CPP_END2END_INTERCEPTORS_UTIL_H
21
#include <grpcpp/channel.h>
#include <gtest/gtest.h>

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/log/check.h"
#include "absl/strings/str_format.h"
#include "src/core/util/crash.h"
#include "src/proto/grpc/testing/echo.grpc.pb.h"
#include "test/cpp/util/string_ref_helper.h"
32
33 namespace grpc {
34 namespace testing {
// This interceptor does not modify any batch; it only keeps global (static)
// counts of how often it sees the PRE_SEND_INITIAL_METADATA,
// POST_RECV_INITIAL_METADATA and PRE_SEND_CANCEL hook points.
37 class PhonyInterceptor : public experimental::Interceptor {
38 public:
PhonyInterceptor()39 PhonyInterceptor() {}
40
Intercept(experimental::InterceptorBatchMethods * methods)41 void Intercept(experimental::InterceptorBatchMethods* methods) override {
42 if (methods->QueryInterceptionHookPoint(
43 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
44 num_times_run_++;
45 } else if (methods->QueryInterceptionHookPoint(
46 experimental::InterceptionHookPoints::
47 POST_RECV_INITIAL_METADATA)) {
48 num_times_run_reverse_++;
49 } else if (methods->QueryInterceptionHookPoint(
50 experimental::InterceptionHookPoints::PRE_SEND_CANCEL)) {
51 num_times_cancel_++;
52 }
53 methods->Proceed();
54 }
55
Reset()56 static void Reset() {
57 num_times_run_.store(0);
58 num_times_run_reverse_.store(0);
59 num_times_cancel_.store(0);
60 }
61
GetNumTimesRun()62 static int GetNumTimesRun() {
63 EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load());
64 return num_times_run_.load();
65 }
66
GetNumTimesCancel()67 static int GetNumTimesCancel() { return num_times_cancel_.load(); }
68
69 private:
70 static std::atomic<int> num_times_run_;
71 static std::atomic<int> num_times_run_reverse_;
72 static std::atomic<int> num_times_cancel_;
73 };
74
75 class PhonyInterceptorFactory
76 : public experimental::ClientInterceptorFactoryInterface,
77 public experimental::ServerInterceptorFactoryInterface {
78 public:
CreateClientInterceptor(experimental::ClientRpcInfo *)79 experimental::Interceptor* CreateClientInterceptor(
80 experimental::ClientRpcInfo* /*info*/) override {
81 return new PhonyInterceptor();
82 }
83
CreateServerInterceptor(experimental::ServerRpcInfo *)84 experimental::Interceptor* CreateServerInterceptor(
85 experimental::ServerRpcInfo* /*info*/) override {
86 return new PhonyInterceptor();
87 }
88 };
89
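// This interceptor checks, when it is created, that the client RPC's method
// and suffix_for_stats match the expected values; it then passes every batch
// through unchanged.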
91 class TestInterceptor : public experimental::Interceptor {
92 public:
TestInterceptor(const std::string & method,const char * suffix_for_stats,experimental::ClientRpcInfo * info)93 TestInterceptor(const std::string& method, const char* suffix_for_stats,
94 experimental::ClientRpcInfo* info) {
95 EXPECT_EQ(info->method(), method);
96
97 if (suffix_for_stats == nullptr || info->suffix_for_stats() == nullptr) {
98 EXPECT_EQ(info->suffix_for_stats(), suffix_for_stats);
99 } else {
100 EXPECT_EQ(strcmp(info->suffix_for_stats(), suffix_for_stats), 0);
101 }
102 }
103
Intercept(experimental::InterceptorBatchMethods * methods)104 void Intercept(experimental::InterceptorBatchMethods* methods) override {
105 methods->Proceed();
106 }
107 };
108
109 class TestInterceptorFactory
110 : public experimental::ClientInterceptorFactoryInterface {
111 public:
TestInterceptorFactory(const std::string & method,const char * suffix_for_stats)112 TestInterceptorFactory(const std::string& method,
113 const char* suffix_for_stats)
114 : method_(method), suffix_for_stats_(suffix_for_stats) {}
115
CreateClientInterceptor(experimental::ClientRpcInfo * info)116 experimental::Interceptor* CreateClientInterceptor(
117 experimental::ClientRpcInfo* info) override {
118 return new TestInterceptor(method_, suffix_for_stats_, info);
119 }
120
121 private:
122 std::string method_;
123 const char* suffix_for_stats_;
124 };
125
126 // This interceptor factory returns nullptr on interceptor creation
127 class NullInterceptorFactory
128 : public experimental::ClientInterceptorFactoryInterface,
129 public experimental::ServerInterceptorFactoryInterface {
130 public:
CreateClientInterceptor(experimental::ClientRpcInfo *)131 experimental::Interceptor* CreateClientInterceptor(
132 experimental::ClientRpcInfo* /*info*/) override {
133 return nullptr;
134 }
135
CreateServerInterceptor(experimental::ServerRpcInfo *)136 experimental::Interceptor* CreateServerInterceptor(
137 experimental::ServerRpcInfo* /*info*/) override {
138 return nullptr;
139 }
140 };
141
142 class EchoTestServiceStreamingImpl : public EchoTestService::Service {
143 public:
~EchoTestServiceStreamingImpl()144 ~EchoTestServiceStreamingImpl() override {}
145
Echo(ServerContext * context,const EchoRequest * request,EchoResponse * response)146 Status Echo(ServerContext* context, const EchoRequest* request,
147 EchoResponse* response) override {
148 auto client_metadata = context->client_metadata();
149 for (const auto& pair : client_metadata) {
150 context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
151 }
152 response->set_message(request->message());
153 return Status::OK;
154 }
155
BidiStream(ServerContext * context,grpc::ServerReaderWriter<EchoResponse,EchoRequest> * stream)156 Status BidiStream(
157 ServerContext* context,
158 grpc::ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
159 EchoRequest req;
160 EchoResponse resp;
161 auto client_metadata = context->client_metadata();
162 for (const auto& pair : client_metadata) {
163 context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
164 }
165
166 while (stream->Read(&req)) {
167 resp.set_message(req.message());
168 EXPECT_TRUE(stream->Write(resp, grpc::WriteOptions()));
169 }
170 return Status::OK;
171 }
172
RequestStream(ServerContext * context,ServerReader<EchoRequest> * reader,EchoResponse * resp)173 Status RequestStream(ServerContext* context,
174 ServerReader<EchoRequest>* reader,
175 EchoResponse* resp) override {
176 auto client_metadata = context->client_metadata();
177 for (const auto& pair : client_metadata) {
178 context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
179 }
180
181 EchoRequest req;
182 string response_str;
183 while (reader->Read(&req)) {
184 response_str += req.message();
185 }
186 resp->set_message(response_str);
187 return Status::OK;
188 }
189
ResponseStream(ServerContext * context,const EchoRequest * req,ServerWriter<EchoResponse> * writer)190 Status ResponseStream(ServerContext* context, const EchoRequest* req,
191 ServerWriter<EchoResponse>* writer) override {
192 auto client_metadata = context->client_metadata();
193 for (const auto& pair : client_metadata) {
194 context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
195 }
196
197 EchoResponse resp;
198 resp.set_message(req->message());
199 for (int i = 0; i < 10; i++) {
200 EXPECT_TRUE(writer->Write(resp));
201 }
202 return Status::OK;
203 }
204 };
205
206 constexpr int kNumStreamingMessages = 10;
207
208 void MakeCall(const std::shared_ptr<Channel>& channel,
209 const StubOptions& options = StubOptions());
210
211 void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel);
212
213 void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel);
214
215 void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel);
216
217 void MakeAsyncCQCall(const std::shared_ptr<Channel>& channel);
218
219 void MakeAsyncCQClientStreamingCall(const std::shared_ptr<Channel>& channel);
220
221 void MakeAsyncCQServerStreamingCall(const std::shared_ptr<Channel>& channel);
222
223 void MakeAsyncCQBidiStreamingCall(const std::shared_ptr<Channel>& channel);
224
225 void MakeCallbackCall(const std::shared_ptr<Channel>& channel);
226
227 bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
228 const string& key, const string& value);
229
230 bool CheckMetadata(const std::multimap<std::string, std::string>& map,
231 const string& key, const string& value);
232
233 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
234 CreatePhonyClientInterceptors();
235
// Encodes a small integer as an opaque void* tag, as used by Verifier below.
// The conversion goes through std::intptr_t so the integer-to-pointer cast is
// between same-width types (a direct int -> void* cast narrows/widens on LP64
// targets), mirroring detag().
inline void* tag(int i) {
  return reinterpret_cast<void*>(static_cast<std::intptr_t>(i));
}

// Inverse of tag(): recovers the integer encoded in a tag produced by tag().
inline int detag(void* p) {
  return static_cast<int>(reinterpret_cast<std::intptr_t>(p));
}
240
241 class Verifier {
242 public:
Verifier()243 Verifier() : lambda_run_(false) {}
244 // Expect sets the expected ok value for a specific tag
Expect(int i,bool expect_ok)245 Verifier& Expect(int i, bool expect_ok) {
246 return ExpectUnless(i, expect_ok, false);
247 }
248 // ExpectUnless sets the expected ok value for a specific tag
249 // unless the tag was already marked seen (as a result of ExpectMaybe)
ExpectUnless(int i,bool expect_ok,bool seen)250 Verifier& ExpectUnless(int i, bool expect_ok, bool seen) {
251 if (!seen) {
252 expectations_[tag(i)] = expect_ok;
253 }
254 return *this;
255 }
256 // ExpectMaybe sets the expected ok value for a specific tag, but does not
257 // require it to appear
258 // If it does, sets *seen to true
ExpectMaybe(int i,bool expect_ok,bool * seen)259 Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) {
260 if (!*seen) {
261 maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen};
262 }
263 return *this;
264 }
265
266 // Next waits for 1 async tag to complete, checks its
267 // expectations, and returns the tag
Next(CompletionQueue * cq,bool ignore_ok)268 int Next(CompletionQueue* cq, bool ignore_ok) {
269 bool ok;
270 void* got_tag;
271 EXPECT_TRUE(cq->Next(&got_tag, &ok));
272 GotTag(got_tag, ok, ignore_ok);
273 return detag(got_tag);
274 }
275
276 template <typename T>
DoOnceThenAsyncNext(CompletionQueue * cq,void ** got_tag,bool * ok,T deadline,std::function<void (void)> lambda)277 CompletionQueue::NextStatus DoOnceThenAsyncNext(
278 CompletionQueue* cq, void** got_tag, bool* ok, T deadline,
279 std::function<void(void)> lambda) {
280 if (lambda_run_) {
281 return cq->AsyncNext(got_tag, ok, deadline);
282 } else {
283 lambda_run_ = true;
284 return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline);
285 }
286 }
287
288 // Verify keeps calling Next until all currently set
289 // expected tags are complete
Verify(CompletionQueue * cq)290 void Verify(CompletionQueue* cq) { Verify(cq, false); }
291
292 // This version of Verify allows optionally ignoring the
293 // outcome of the expectation
Verify(CompletionQueue * cq,bool ignore_ok)294 void Verify(CompletionQueue* cq, bool ignore_ok) {
295 CHECK(!expectations_.empty() || !maybe_expectations_.empty());
296 while (!expectations_.empty()) {
297 Next(cq, ignore_ok);
298 }
299 }
300
301 // This version of Verify stops after a certain deadline, and uses the
302 // DoThenAsyncNext API
303 // to call the lambda
Verify(CompletionQueue * cq,std::chrono::system_clock::time_point deadline,const std::function<void (void)> & lambda)304 void Verify(CompletionQueue* cq,
305 std::chrono::system_clock::time_point deadline,
306 const std::function<void(void)>& lambda) {
307 if (expectations_.empty()) {
308 bool ok;
309 void* got_tag;
310 EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
311 CompletionQueue::TIMEOUT);
312 } else {
313 while (!expectations_.empty()) {
314 bool ok;
315 void* got_tag;
316 EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
317 CompletionQueue::GOT_EVENT);
318 GotTag(got_tag, ok, false);
319 }
320 }
321 }
322
323 private:
GotTag(void * got_tag,bool ok,bool ignore_ok)324 void GotTag(void* got_tag, bool ok, bool ignore_ok) {
325 auto it = expectations_.find(got_tag);
326 if (it != expectations_.end()) {
327 if (!ignore_ok) {
328 EXPECT_EQ(it->second, ok);
329 }
330 expectations_.erase(it);
331 } else {
332 auto it2 = maybe_expectations_.find(got_tag);
333 if (it2 != maybe_expectations_.end()) {
334 if (it2->second.seen != nullptr) {
335 EXPECT_FALSE(*it2->second.seen);
336 *it2->second.seen = true;
337 }
338 if (!ignore_ok) {
339 EXPECT_EQ(it2->second.ok, ok);
340 }
341 } else {
342 grpc_core::Crash(absl::StrFormat("Unexpected tag: %p", got_tag));
343 }
344 }
345 }
346
347 struct MaybeExpect {
348 bool ok;
349 bool* seen;
350 };
351
352 std::map<void*, bool> expectations_;
353 std::map<void*, MaybeExpect> maybe_expectations_;
354 bool lambda_run_;
355 };
356
357 } // namespace testing
358 } // namespace grpc
359
360 #endif // GRPC_TEST_CPP_END2END_INTERCEPTORS_UTIL_H
361