1 /*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 #include <memory>
20 #include <vector>
21
22 #include <grpcpp/channel.h>
23 #include <grpcpp/client_context.h>
24 #include <grpcpp/create_channel.h>
25 #include <grpcpp/generic/generic_stub.h>
26 #include <grpcpp/impl/codegen/proto_utils.h>
27 #include <grpcpp/server.h>
28 #include <grpcpp/server_builder.h>
29 #include <grpcpp/server_context.h>
30 #include <grpcpp/support/client_interceptor.h>
31
32 #include "src/proto/grpc/testing/echo.grpc.pb.h"
33 #include "test/core/util/port.h"
34 #include "test/core/util/test_config.h"
35 #include "test/cpp/end2end/interceptors_util.h"
36 #include "test/cpp/end2end/test_service_impl.h"
37 #include "test/cpp/util/byte_buffer_proto_helper.h"
38 #include "test/cpp/util/string_ref_helper.h"
39
40 #include <gtest/gtest.h>
41
42 namespace grpc {
43 namespace testing {
44 namespace {
45
// Kinds of RPC a test can issue: each of the four call shapes (unary, client
// streaming, server streaming, bidi streaming), made either through the
// synchronous API or through the asynchronous completion-queue (CQ) API.
enum class RPCType {
  // --- Synchronous API ---
  kSyncUnary,
  kSyncClientStreaming,
  kSyncServerStreaming,
  kSyncBidiStreaming,
  // --- Asynchronous completion-queue API ---
  kAsyncCQUnary,
  kAsyncCQClientStreaming,
  kAsyncCQServerStreaming,
  kAsyncCQBidiStreaming,
};
56
57 /* Hijacks Echo RPC and fills in the expected values */
58 class HijackingInterceptor : public experimental::Interceptor {
59 public:
HijackingInterceptor(experimental::ClientRpcInfo * info)60 HijackingInterceptor(experimental::ClientRpcInfo* info) {
61 info_ = info;
62 // Make sure it is the right method
63 EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
64 EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY);
65 }
66
Intercept(experimental::InterceptorBatchMethods * methods)67 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
68 bool hijack = false;
69 if (methods->QueryInterceptionHookPoint(
70 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
71 auto* map = methods->GetSendInitialMetadata();
72 // Check that we can see the test metadata
73 ASSERT_EQ(map->size(), static_cast<unsigned>(1));
74 auto iterator = map->begin();
75 EXPECT_EQ("testkey", iterator->first);
76 EXPECT_EQ("testvalue", iterator->second);
77 hijack = true;
78 }
79 if (methods->QueryInterceptionHookPoint(
80 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
81 EchoRequest req;
82 auto* buffer = methods->GetSerializedSendMessage();
83 auto copied_buffer = *buffer;
84 EXPECT_TRUE(
85 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
86 .ok());
87 EXPECT_EQ(req.message(), "Hello");
88 }
89 if (methods->QueryInterceptionHookPoint(
90 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
91 // Got nothing to do here for now
92 }
93 if (methods->QueryInterceptionHookPoint(
94 experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
95 auto* map = methods->GetRecvInitialMetadata();
96 // Got nothing better to do here for now
97 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
98 }
99 if (methods->QueryInterceptionHookPoint(
100 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
101 EchoResponse* resp =
102 static_cast<EchoResponse*>(methods->GetRecvMessage());
103 // Check that we got the hijacked message, and re-insert the expected
104 // message
105 EXPECT_EQ(resp->message(), "Hello1");
106 resp->set_message("Hello");
107 }
108 if (methods->QueryInterceptionHookPoint(
109 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
110 auto* map = methods->GetRecvTrailingMetadata();
111 bool found = false;
112 // Check that we received the metadata as an echo
113 for (const auto& pair : *map) {
114 found = pair.first.starts_with("testkey") &&
115 pair.second.starts_with("testvalue");
116 if (found) break;
117 }
118 EXPECT_EQ(found, true);
119 auto* status = methods->GetRecvStatus();
120 EXPECT_EQ(status->ok(), true);
121 }
122 if (methods->QueryInterceptionHookPoint(
123 experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
124 auto* map = methods->GetRecvInitialMetadata();
125 // Got nothing better to do here at the moment
126 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
127 }
128 if (methods->QueryInterceptionHookPoint(
129 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
130 // Insert a different message than expected
131 EchoResponse* resp =
132 static_cast<EchoResponse*>(methods->GetRecvMessage());
133 resp->set_message("Hello1");
134 }
135 if (methods->QueryInterceptionHookPoint(
136 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
137 auto* map = methods->GetRecvTrailingMetadata();
138 // insert the metadata that we want
139 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
140 map->insert(std::make_pair("testkey", "testvalue"));
141 auto* status = methods->GetRecvStatus();
142 *status = Status(StatusCode::OK, "");
143 }
144 if (hijack) {
145 methods->Hijack();
146 } else {
147 methods->Proceed();
148 }
149 }
150
151 private:
152 experimental::ClientRpcInfo* info_;
153 };
154
155 class HijackingInterceptorFactory
156 : public experimental::ClientInterceptorFactoryInterface {
157 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)158 virtual experimental::Interceptor* CreateClientInterceptor(
159 experimental::ClientRpcInfo* info) override {
160 return new HijackingInterceptor(info);
161 }
162 };
163
164 class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
165 public:
HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo * info)166 HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) {
167 info_ = info;
168 // Make sure it is the right method
169 EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
170 }
171
Intercept(experimental::InterceptorBatchMethods * methods)172 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
173 if (methods->QueryInterceptionHookPoint(
174 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
175 auto* map = methods->GetSendInitialMetadata();
176 // Check that we can see the test metadata
177 ASSERT_EQ(map->size(), static_cast<unsigned>(1));
178 auto iterator = map->begin();
179 EXPECT_EQ("testkey", iterator->first);
180 EXPECT_EQ("testvalue", iterator->second);
181 // Make a copy of the map
182 metadata_map_ = *map;
183 }
184 if (methods->QueryInterceptionHookPoint(
185 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
186 EchoRequest req;
187 auto* buffer = methods->GetSerializedSendMessage();
188 auto copied_buffer = *buffer;
189 EXPECT_TRUE(
190 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
191 .ok());
192 EXPECT_EQ(req.message(), "Hello");
193 req_ = req;
194 stub_ = grpc::testing::EchoTestService::NewStub(
195 methods->GetInterceptedChannel());
196 ctx_.AddMetadata(metadata_map_.begin()->first,
197 metadata_map_.begin()->second);
198 stub_->experimental_async()->Echo(&ctx_, &req_, &resp_,
199 [this, methods](Status s) {
200 EXPECT_EQ(s.ok(), true);
201 EXPECT_EQ(resp_.message(), "Hello");
202 methods->Hijack();
203 });
204 // This is a Unary RPC and we have got nothing interesting to do in the
205 // PRE_SEND_CLOSE interception hook point for this interceptor, so let's
206 // return here. (We do not want to call methods->Proceed(). When the new
207 // RPC returns, we will call methods->Hijack() instead.)
208 return;
209 }
210 if (methods->QueryInterceptionHookPoint(
211 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
212 // Got nothing to do here for now
213 }
214 if (methods->QueryInterceptionHookPoint(
215 experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
216 auto* map = methods->GetRecvInitialMetadata();
217 // Got nothing better to do here for now
218 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
219 }
220 if (methods->QueryInterceptionHookPoint(
221 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
222 EchoResponse* resp =
223 static_cast<EchoResponse*>(methods->GetRecvMessage());
224 // Check that we got the hijacked message, and re-insert the expected
225 // message
226 EXPECT_EQ(resp->message(), "Hello");
227 }
228 if (methods->QueryInterceptionHookPoint(
229 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
230 auto* map = methods->GetRecvTrailingMetadata();
231 bool found = false;
232 // Check that we received the metadata as an echo
233 for (const auto& pair : *map) {
234 found = pair.first.starts_with("testkey") &&
235 pair.second.starts_with("testvalue");
236 if (found) break;
237 }
238 EXPECT_EQ(found, true);
239 auto* status = methods->GetRecvStatus();
240 EXPECT_EQ(status->ok(), true);
241 }
242 if (methods->QueryInterceptionHookPoint(
243 experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
244 auto* map = methods->GetRecvInitialMetadata();
245 // Got nothing better to do here at the moment
246 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
247 }
248 if (methods->QueryInterceptionHookPoint(
249 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
250 // Insert a different message than expected
251 EchoResponse* resp =
252 static_cast<EchoResponse*>(methods->GetRecvMessage());
253 resp->set_message(resp_.message());
254 }
255 if (methods->QueryInterceptionHookPoint(
256 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
257 auto* map = methods->GetRecvTrailingMetadata();
258 // insert the metadata that we want
259 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
260 map->insert(std::make_pair("testkey", "testvalue"));
261 auto* status = methods->GetRecvStatus();
262 *status = Status(StatusCode::OK, "");
263 }
264
265 methods->Proceed();
266 }
267
268 private:
269 experimental::ClientRpcInfo* info_;
270 std::multimap<std::string, std::string> metadata_map_;
271 ClientContext ctx_;
272 EchoRequest req_;
273 EchoResponse resp_;
274 std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
275 };
276
277 class HijackingInterceptorMakesAnotherCallFactory
278 : public experimental::ClientInterceptorFactoryInterface {
279 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)280 virtual experimental::Interceptor* CreateClientInterceptor(
281 experimental::ClientRpcInfo* info) override {
282 return new HijackingInterceptorMakesAnotherCall(info);
283 }
284 };
285
286 class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
287 public:
BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)288 BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
289 info_ = info;
290 }
291
Intercept(experimental::InterceptorBatchMethods * methods)292 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
293 bool hijack = false;
294 if (methods->QueryInterceptionHookPoint(
295 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
296 CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
297 hijack = true;
298 }
299 if (methods->QueryInterceptionHookPoint(
300 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
301 EchoRequest req;
302 auto* buffer = methods->GetSerializedSendMessage();
303 auto copied_buffer = *buffer;
304 EXPECT_TRUE(
305 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
306 .ok());
307 EXPECT_EQ(req.message().find("Hello"), 0u);
308 msg = req.message();
309 }
310 if (methods->QueryInterceptionHookPoint(
311 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
312 // Got nothing to do here for now
313 }
314 if (methods->QueryInterceptionHookPoint(
315 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
316 CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
317 "testvalue");
318 auto* status = methods->GetRecvStatus();
319 EXPECT_EQ(status->ok(), true);
320 }
321 if (methods->QueryInterceptionHookPoint(
322 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
323 EchoResponse* resp =
324 static_cast<EchoResponse*>(methods->GetRecvMessage());
325 resp->set_message(msg);
326 }
327 if (methods->QueryInterceptionHookPoint(
328 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
329 EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
330 ->message()
331 .find("Hello"),
332 0u);
333 }
334 if (methods->QueryInterceptionHookPoint(
335 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
336 auto* map = methods->GetRecvTrailingMetadata();
337 // insert the metadata that we want
338 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
339 map->insert(std::make_pair("testkey", "testvalue"));
340 auto* status = methods->GetRecvStatus();
341 *status = Status(StatusCode::OK, "");
342 }
343 if (hijack) {
344 methods->Hijack();
345 } else {
346 methods->Proceed();
347 }
348 }
349
350 private:
351 experimental::ClientRpcInfo* info_;
352 std::string msg;
353 };
354
355 class ClientStreamingRpcHijackingInterceptor
356 : public experimental::Interceptor {
357 public:
ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)358 ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
359 info_ = info;
360 }
Intercept(experimental::InterceptorBatchMethods * methods)361 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
362 bool hijack = false;
363 if (methods->QueryInterceptionHookPoint(
364 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
365 hijack = true;
366 }
367 if (methods->QueryInterceptionHookPoint(
368 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
369 if (++count_ > 10) {
370 methods->FailHijackedSendMessage();
371 }
372 }
373 if (methods->QueryInterceptionHookPoint(
374 experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
375 EXPECT_FALSE(got_failed_send_);
376 got_failed_send_ = !methods->GetSendMessageStatus();
377 }
378 if (methods->QueryInterceptionHookPoint(
379 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
380 auto* status = methods->GetRecvStatus();
381 *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
382 }
383 if (hijack) {
384 methods->Hijack();
385 } else {
386 methods->Proceed();
387 }
388 }
389
GotFailedSend()390 static bool GotFailedSend() { return got_failed_send_; }
391
392 private:
393 experimental::ClientRpcInfo* info_;
394 int count_ = 0;
395 static bool got_failed_send_;
396 };
397
398 bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
399
400 class ClientStreamingRpcHijackingInterceptorFactory
401 : public experimental::ClientInterceptorFactoryInterface {
402 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)403 virtual experimental::Interceptor* CreateClientInterceptor(
404 experimental::ClientRpcInfo* info) override {
405 return new ClientStreamingRpcHijackingInterceptor(info);
406 }
407 };
408
409 class ServerStreamingRpcHijackingInterceptor
410 : public experimental::Interceptor {
411 public:
ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)412 ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
413 info_ = info;
414 got_failed_message_ = false;
415 }
416
Intercept(experimental::InterceptorBatchMethods * methods)417 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
418 bool hijack = false;
419 if (methods->QueryInterceptionHookPoint(
420 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
421 auto* map = methods->GetSendInitialMetadata();
422 // Check that we can see the test metadata
423 ASSERT_EQ(map->size(), static_cast<unsigned>(1));
424 auto iterator = map->begin();
425 EXPECT_EQ("testkey", iterator->first);
426 EXPECT_EQ("testvalue", iterator->second);
427 hijack = true;
428 }
429 if (methods->QueryInterceptionHookPoint(
430 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
431 EchoRequest req;
432 auto* buffer = methods->GetSerializedSendMessage();
433 auto copied_buffer = *buffer;
434 EXPECT_TRUE(
435 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
436 .ok());
437 EXPECT_EQ(req.message(), "Hello");
438 }
439 if (methods->QueryInterceptionHookPoint(
440 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
441 // Got nothing to do here for now
442 }
443 if (methods->QueryInterceptionHookPoint(
444 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
445 auto* map = methods->GetRecvTrailingMetadata();
446 bool found = false;
447 // Check that we received the metadata as an echo
448 for (const auto& pair : *map) {
449 found = pair.first.starts_with("testkey") &&
450 pair.second.starts_with("testvalue");
451 if (found) break;
452 }
453 EXPECT_EQ(found, true);
454 auto* status = methods->GetRecvStatus();
455 EXPECT_EQ(status->ok(), true);
456 }
457 if (methods->QueryInterceptionHookPoint(
458 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
459 if (++count_ > 10) {
460 methods->FailHijackedRecvMessage();
461 }
462 EchoResponse* resp =
463 static_cast<EchoResponse*>(methods->GetRecvMessage());
464 resp->set_message("Hello");
465 }
466 if (methods->QueryInterceptionHookPoint(
467 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
468 // Only the last message will be a failure
469 EXPECT_FALSE(got_failed_message_);
470 got_failed_message_ = methods->GetRecvMessage() == nullptr;
471 }
472 if (methods->QueryInterceptionHookPoint(
473 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
474 auto* map = methods->GetRecvTrailingMetadata();
475 // insert the metadata that we want
476 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
477 map->insert(std::make_pair("testkey", "testvalue"));
478 auto* status = methods->GetRecvStatus();
479 *status = Status(StatusCode::OK, "");
480 }
481 if (hijack) {
482 methods->Hijack();
483 } else {
484 methods->Proceed();
485 }
486 }
487
GotFailedMessage()488 static bool GotFailedMessage() { return got_failed_message_; }
489
490 private:
491 experimental::ClientRpcInfo* info_;
492 static bool got_failed_message_;
493 int count_ = 0;
494 };
495
496 bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
497
498 class ServerStreamingRpcHijackingInterceptorFactory
499 : public experimental::ClientInterceptorFactoryInterface {
500 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)501 virtual experimental::Interceptor* CreateClientInterceptor(
502 experimental::ClientRpcInfo* info) override {
503 return new ServerStreamingRpcHijackingInterceptor(info);
504 }
505 };
506
507 class BidiStreamingRpcHijackingInterceptorFactory
508 : public experimental::ClientInterceptorFactoryInterface {
509 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)510 virtual experimental::Interceptor* CreateClientInterceptor(
511 experimental::ClientRpcInfo* info) override {
512 return new BidiStreamingRpcHijackingInterceptor(info);
513 }
514 };
515
516 // The logging interceptor is for testing purposes only. It is used to verify
517 // that all the appropriate hook points are invoked for an RPC. The counts are
518 // reset each time a new object of LoggingInterceptor is created, so only a
519 // single RPC should be made on the channel before calling the Verify methods.
520 class LoggingInterceptor : public experimental::Interceptor {
521 public:
LoggingInterceptor(experimental::ClientRpcInfo *)522 LoggingInterceptor(experimental::ClientRpcInfo* /*info*/) {
523 pre_send_initial_metadata_ = false;
524 pre_send_message_count_ = 0;
525 pre_send_close_ = false;
526 post_recv_initial_metadata_ = false;
527 post_recv_message_count_ = 0;
528 post_recv_status_ = false;
529 }
530
Intercept(experimental::InterceptorBatchMethods * methods)531 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
532 if (methods->QueryInterceptionHookPoint(
533 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
534 auto* map = methods->GetSendInitialMetadata();
535 // Check that we can see the test metadata
536 ASSERT_EQ(map->size(), static_cast<unsigned>(1));
537 auto iterator = map->begin();
538 EXPECT_EQ("testkey", iterator->first);
539 EXPECT_EQ("testvalue", iterator->second);
540 ASSERT_FALSE(pre_send_initial_metadata_);
541 pre_send_initial_metadata_ = true;
542 }
543 if (methods->QueryInterceptionHookPoint(
544 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
545 EchoRequest req;
546 auto* send_msg = methods->GetSendMessage();
547 if (send_msg == nullptr) {
548 // We did not get the non-serialized form of the message. Get the
549 // serialized form.
550 auto* buffer = methods->GetSerializedSendMessage();
551 auto copied_buffer = *buffer;
552 EchoRequest req;
553 EXPECT_TRUE(
554 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
555 .ok());
556 EXPECT_EQ(req.message(), "Hello");
557 } else {
558 EXPECT_EQ(
559 static_cast<const EchoRequest*>(send_msg)->message().find("Hello"),
560 0u);
561 }
562 auto* buffer = methods->GetSerializedSendMessage();
563 auto copied_buffer = *buffer;
564 EXPECT_TRUE(
565 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
566 .ok());
567 EXPECT_TRUE(req.message().find("Hello") == 0u);
568 pre_send_message_count_++;
569 }
570 if (methods->QueryInterceptionHookPoint(
571 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
572 // Got nothing to do here for now
573 pre_send_close_ = true;
574 }
575 if (methods->QueryInterceptionHookPoint(
576 experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
577 auto* map = methods->GetRecvInitialMetadata();
578 // Got nothing better to do here for now
579 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
580 post_recv_initial_metadata_ = true;
581 }
582 if (methods->QueryInterceptionHookPoint(
583 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
584 EchoResponse* resp =
585 static_cast<EchoResponse*>(methods->GetRecvMessage());
586 if (resp != nullptr) {
587 EXPECT_TRUE(resp->message().find("Hello") == 0u);
588 post_recv_message_count_++;
589 }
590 }
591 if (methods->QueryInterceptionHookPoint(
592 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
593 auto* map = methods->GetRecvTrailingMetadata();
594 bool found = false;
595 // Check that we received the metadata as an echo
596 for (const auto& pair : *map) {
597 found = pair.first.starts_with("testkey") &&
598 pair.second.starts_with("testvalue");
599 if (found) break;
600 }
601 EXPECT_EQ(found, true);
602 auto* status = methods->GetRecvStatus();
603 EXPECT_EQ(status->ok(), true);
604 post_recv_status_ = true;
605 }
606 methods->Proceed();
607 }
608
VerifyCall(RPCType type)609 static void VerifyCall(RPCType type) {
610 switch (type) {
611 case RPCType::kSyncUnary:
612 case RPCType::kAsyncCQUnary:
613 VerifyUnaryCall();
614 break;
615 case RPCType::kSyncClientStreaming:
616 case RPCType::kAsyncCQClientStreaming:
617 VerifyClientStreamingCall();
618 break;
619 case RPCType::kSyncServerStreaming:
620 case RPCType::kAsyncCQServerStreaming:
621 VerifyServerStreamingCall();
622 break;
623 case RPCType::kSyncBidiStreaming:
624 case RPCType::kAsyncCQBidiStreaming:
625 VerifyBidiStreamingCall();
626 break;
627 }
628 }
629
VerifyCallCommon()630 static void VerifyCallCommon() {
631 EXPECT_TRUE(pre_send_initial_metadata_);
632 EXPECT_TRUE(pre_send_close_);
633 EXPECT_TRUE(post_recv_initial_metadata_);
634 EXPECT_TRUE(post_recv_status_);
635 }
636
VerifyUnaryCall()637 static void VerifyUnaryCall() {
638 VerifyCallCommon();
639 EXPECT_EQ(pre_send_message_count_, 1);
640 EXPECT_EQ(post_recv_message_count_, 1);
641 }
642
VerifyClientStreamingCall()643 static void VerifyClientStreamingCall() {
644 VerifyCallCommon();
645 EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
646 EXPECT_EQ(post_recv_message_count_, 1);
647 }
648
VerifyServerStreamingCall()649 static void VerifyServerStreamingCall() {
650 VerifyCallCommon();
651 EXPECT_EQ(pre_send_message_count_, 1);
652 EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
653 }
654
VerifyBidiStreamingCall()655 static void VerifyBidiStreamingCall() {
656 VerifyCallCommon();
657 EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
658 EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
659 }
660
661 private:
662 static bool pre_send_initial_metadata_;
663 static int pre_send_message_count_;
664 static bool pre_send_close_;
665 static bool post_recv_initial_metadata_;
666 static int post_recv_message_count_;
667 static bool post_recv_status_;
668 };
669
670 bool LoggingInterceptor::pre_send_initial_metadata_;
671 int LoggingInterceptor::pre_send_message_count_;
672 bool LoggingInterceptor::pre_send_close_;
673 bool LoggingInterceptor::post_recv_initial_metadata_;
674 int LoggingInterceptor::post_recv_message_count_;
675 bool LoggingInterceptor::post_recv_status_;
676
677 class LoggingInterceptorFactory
678 : public experimental::ClientInterceptorFactoryInterface {
679 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)680 virtual experimental::Interceptor* CreateClientInterceptor(
681 experimental::ClientRpcInfo* info) override {
682 return new LoggingInterceptor(info);
683 }
684 };
685
686 class TestScenario {
687 public:
TestScenario(const RPCType & type)688 explicit TestScenario(const RPCType& type) : type_(type) {}
689
type() const690 RPCType type() const { return type_; }
691
692 private:
693 RPCType type_;
694 };
695
CreateTestScenarios()696 std::vector<TestScenario> CreateTestScenarios() {
697 std::vector<TestScenario> scenarios;
698 scenarios.emplace_back(RPCType::kSyncUnary);
699 scenarios.emplace_back(RPCType::kSyncClientStreaming);
700 scenarios.emplace_back(RPCType::kSyncServerStreaming);
701 scenarios.emplace_back(RPCType::kSyncBidiStreaming);
702 scenarios.emplace_back(RPCType::kAsyncCQUnary);
703 scenarios.emplace_back(RPCType::kAsyncCQServerStreaming);
704 return scenarios;
705 }
706
707 class ParameterizedClientInterceptorsEnd2endTest
708 : public ::testing::TestWithParam<TestScenario> {
709 protected:
ParameterizedClientInterceptorsEnd2endTest()710 ParameterizedClientInterceptorsEnd2endTest() {
711 int port = grpc_pick_unused_port_or_die();
712
713 ServerBuilder builder;
714 server_address_ = "localhost:" + std::to_string(port);
715 builder.AddListeningPort(server_address_, InsecureServerCredentials());
716 builder.RegisterService(&service_);
717 server_ = builder.BuildAndStart();
718 }
719
~ParameterizedClientInterceptorsEnd2endTest()720 ~ParameterizedClientInterceptorsEnd2endTest() { server_->Shutdown(); }
721
SendRPC(const std::shared_ptr<Channel> & channel)722 void SendRPC(const std::shared_ptr<Channel>& channel) {
723 switch (GetParam().type()) {
724 case RPCType::kSyncUnary:
725 MakeCall(channel);
726 break;
727 case RPCType::kSyncClientStreaming:
728 MakeClientStreamingCall(channel);
729 break;
730 case RPCType::kSyncServerStreaming:
731 MakeServerStreamingCall(channel);
732 break;
733 case RPCType::kSyncBidiStreaming:
734 MakeBidiStreamingCall(channel);
735 break;
736 case RPCType::kAsyncCQUnary:
737 MakeAsyncCQCall(channel);
738 break;
739 case RPCType::kAsyncCQClientStreaming:
740 // TODO(yashykt) : Fill this out
741 break;
742 case RPCType::kAsyncCQServerStreaming:
743 MakeAsyncCQServerStreamingCall(channel);
744 break;
745 case RPCType::kAsyncCQBidiStreaming:
746 // TODO(yashykt) : Fill this out
747 break;
748 }
749 }
750
751 std::string server_address_;
752 EchoTestServiceStreamingImpl service_;
753 std::unique_ptr<Server> server_;
754 };
755
// Runs the parameterized RPC through a LoggingInterceptor followed by a chain
// of DummyInterceptors. Checks that the expected hook points were reached and
// that every dummy interceptor ran.
TEST_P(ParameterizedClientInterceptorsEnd2endTest,
       ClientInterceptorLoggingTest) {
  constexpr int kNumDummyInterceptors = 20;
  ChannelArguments args;
  DummyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
      new LoggingInterceptorFactory()));
  for (int i = 0; i < kNumDummyInterceptors; ++i) {
    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));
  SendRPC(channel);
  LoggingInterceptor::VerifyCall(GetParam().type());
  // Every dummy interceptor in the chain must have run.
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummyInterceptors);
}

// Instantiates the parameterized tests once per scenario returned by
// CreateTestScenarios().
INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end,
                         ParameterizedClientInterceptorsEnd2endTest,
                         ::testing::ValuesIn(CreateTestScenarios()));
780
781 class ClientInterceptorsEnd2endTest
782 : public ::testing::TestWithParam<TestScenario> {
783 protected:
ClientInterceptorsEnd2endTest()784 ClientInterceptorsEnd2endTest() {
785 int port = grpc_pick_unused_port_or_die();
786
787 ServerBuilder builder;
788 server_address_ = "localhost:" + std::to_string(port);
789 builder.AddListeningPort(server_address_, InsecureServerCredentials());
790 builder.RegisterService(&service_);
791 server_ = builder.BuildAndStart();
792 }
793
~ClientInterceptorsEnd2endTest()794 ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); }
795
796 std::string server_address_;
797 TestServiceImpl service_;
798 std::unique_ptr<Server> server_;
799 };
800
// A hijacking interceptor must be able to complete a unary call even on a
// channel created without credentials (a lame channel, per the test name).
TEST_F(ClientInterceptorsEnd2endTest,
       LameChannelClientInterceptorHijackingTest) {
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.push_back(std::unique_ptr<HijackingInterceptorFactory>(
      new HijackingInterceptorFactory()));
  ChannelArguments args;
  // Null credentials are passed deliberately.
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, nullptr, args, std::move(creators));
  MakeCall(channel);
}
812
// Places a hijacking interceptor between two groups of dummy interceptors.
// Only the dummies before it should run; the hijacked call never reaches the
// ones after it.
TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) {
  constexpr int kNumDummiesBefore = 20;
  constexpr int kNumDummiesAfter = 20;
  ChannelArguments args;
  DummyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  // Reserve room for every factory pushed below: both dummy groups plus the
  // hijacking interceptor. (Previously only 20 slots were reserved for 41.)
  creators.reserve(kNumDummiesBefore + 1 + kNumDummiesAfter);
  // Add dummy interceptors before hijacking interceptor
  for (int i = 0; i < kNumDummiesBefore; i++) {
    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }
  creators.push_back(std::unique_ptr<HijackingInterceptorFactory>(
      new HijackingInterceptorFactory()));
  // Add dummy interceptors after hijacking interceptor
  for (int i = 0; i < kNumDummiesAfter; i++) {
    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));
  MakeCall(channel);
  // Make sure only the dummy interceptors before the hijacker were run
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummiesBefore);
}
837
// A LoggingInterceptor placed in front of a HijackingInterceptor must still
// observe every hook point of the unary call, even though the hijacker answers
// the call.
TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) {
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
      new LoggingInterceptorFactory()));
  creators.push_back(std::unique_ptr<HijackingInterceptorFactory>(
      new HijackingInterceptorFactory()));

  ChannelArguments args;
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));

  MakeCall(channel);
  LoggingInterceptor::VerifyUnaryCall();
}
851
TEST_F(ClientInterceptorsEnd2endTest,
       ClientInterceptorHijackingMakesAnotherCallTest) {
  const int kNumDummiesBefore = 5;
  const int kNumDummiesAfter = 7;
  ChannelArguments args;
  DummyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  // Reserve room for every factory pushed below: the dummies on both sides of
  // the hijacker plus the hijacking interceptor itself.
  creators.reserve(kNumDummiesBefore + 1 + kNumDummiesAfter);
  // Add dummy interceptors before hijacking interceptor
  for (int i = 0; i < kNumDummiesBefore; i++) {
    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }
  creators.push_back(
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>(
          new HijackingInterceptorMakesAnotherCallFactory()));
  // Add dummy interceptors after hijacking interceptor
  for (int i = 0; i < kNumDummiesAfter; i++) {
    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }
  auto channel = server_->experimental().InProcessChannelWithInterceptors(
      args, std::move(creators));

  MakeCall(channel);
  // Make sure all interceptors were run once, since the hijacking interceptor
  // makes an RPC on the intercepted channel
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(),
            kNumDummiesBefore + kNumDummiesAfter);
}
880
881 class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test {
882 protected:
ClientInterceptorsCallbackEnd2endTest()883 ClientInterceptorsCallbackEnd2endTest() {
884 int port = grpc_pick_unused_port_or_die();
885
886 ServerBuilder builder;
887 server_address_ = "localhost:" + std::to_string(port);
888 builder.AddListeningPort(server_address_, InsecureServerCredentials());
889 builder.RegisterService(&service_);
890 server_ = builder.BuildAndStart();
891 }
892
~ClientInterceptorsCallbackEnd2endTest()893 ~ClientInterceptorsCallbackEnd2endTest() { server_->Shutdown(); }
894
895 std::string server_address_;
896 TestServiceImpl service_;
897 std::unique_ptr<Server> server_;
898 };
899
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorLoggingTestWithCallback) {
  // One logging interceptor followed by a run of dummy interceptors, exercised
  // through the callback API on an in-process channel.
  const int kNumDummies = 20;
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;

  ChannelArguments channel_args;
  DummyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.push_back(FactoryPtr(new LoggingInterceptorFactory()));
  for (int n = 0; n < kNumDummies; ++n) {
    factories.push_back(FactoryPtr(new DummyInterceptorFactory()));
  }

  auto channel = server_->experimental().InProcessChannelWithInterceptors(
      channel_args, std::move(factories));
  MakeCallbackCall(channel);
  LoggingInterceptor::VerifyUnaryCall();
  // The dummy run counter should match the number of dummies installed.
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummies);
}
920
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorFactoryAllowsNullptrReturn) {
  // Interleaves dummy factories with NullInterceptorFactory instances (which,
  // per the test name, presumably hand back nullptr) and checks the chain
  // still works and every dummy still runs.
  const int kNumDummies = 20;
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;

  ChannelArguments channel_args;
  DummyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.push_back(FactoryPtr(new LoggingInterceptorFactory()));
  for (int n = 0; n < kNumDummies; ++n) {
    factories.push_back(FactoryPtr(new DummyInterceptorFactory()));
    factories.push_back(FactoryPtr(new NullInterceptorFactory()));
  }

  auto channel = server_->experimental().InProcessChannelWithInterceptors(
      channel_args, std::move(factories));
  MakeCallbackCall(channel);
  LoggingInterceptor::VerifyUnaryCall();
  // The dummy run counter should match the number of dummies installed.
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummies);
}
943
944 class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test {
945 protected:
ClientInterceptorsStreamingEnd2endTest()946 ClientInterceptorsStreamingEnd2endTest() {
947 int port = grpc_pick_unused_port_or_die();
948
949 ServerBuilder builder;
950 server_address_ = "localhost:" + std::to_string(port);
951 builder.AddListeningPort(server_address_, InsecureServerCredentials());
952 builder.RegisterService(&service_);
953 server_ = builder.BuildAndStart();
954 }
955
~ClientInterceptorsStreamingEnd2endTest()956 ~ClientInterceptorsStreamingEnd2endTest() { server_->Shutdown(); }
957
958 std::string server_address_;
959 EchoTestServiceStreamingImpl service_;
960 std::unique_ptr<Server> server_;
961 };
962
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) {
  // Logging interceptor plus a run of dummies on a client-streaming RPC.
  const int kNumDummies = 20;
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;

  ChannelArguments channel_args;
  DummyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.push_back(FactoryPtr(new LoggingInterceptorFactory()));
  for (int n = 0; n < kNumDummies; ++n) {
    factories.push_back(FactoryPtr(new DummyInterceptorFactory()));
  }

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeClientStreamingCall(channel);
  LoggingInterceptor::VerifyClientStreamingCall();
  // The dummy run counter should match the number of dummies installed.
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummies);
}
982
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
  // Logging interceptor plus a run of dummies on a server-streaming RPC.
  const int kNumDummies = 20;
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;

  ChannelArguments channel_args;
  DummyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.push_back(FactoryPtr(new LoggingInterceptorFactory()));
  for (int n = 0; n < kNumDummies; ++n) {
    factories.push_back(FactoryPtr(new DummyInterceptorFactory()));
  }

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeServerStreamingCall(channel);
  LoggingInterceptor::VerifyServerStreamingCall();
  // The dummy run counter should match the number of dummies installed.
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummies);
}
1002
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
  // The hijacking interceptor accepts the first 10 messages and rejects the
  // 11th send, which should surface as a failed RPC status.
  const int kNumAcceptedWrites = 10;
  ChannelArguments args;
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.push_back(
      std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>(
          new ClientStreamingRpcHijackingInterceptorFactory()));
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));

  auto stub = grpc::testing::EchoTestService::NewStub(channel);
  ClientContext ctx;
  EchoRequest req;
  EchoResponse resp;
  req.mutable_param()->set_echo_metadata(true);
  req.set_message("Hello");
  auto writer = stub->RequestStream(&ctx, &resp);
  for (int i = 0; i < kNumAcceptedWrites; i++) {
    EXPECT_TRUE(writer->Write(req));
  }
  // The interceptor will reject the 11th message. The failure is checked via
  // Finish() and the interceptor's own bookkeeping below.
  writer->Write(req);
  Status s = writer->Finish();
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
}
1031
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
  // A lone hijacking interceptor on a synchronous server-streaming call; it
  // must report that it observed a failed message.
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;

  ChannelArguments channel_args;
  DummyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.push_back(
      FactoryPtr(new ServerStreamingRpcHijackingInterceptorFactory()));

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeServerStreamingCall(channel);
  EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
1045
TEST_F(ClientInterceptorsStreamingEnd2endTest,
       AsyncCQServerStreamingHijackingTest) {
  // Same as the synchronous server-streaming hijack test, but driven through
  // the completion-queue async API.
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;

  ChannelArguments channel_args;
  DummyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.push_back(
      FactoryPtr(new ServerStreamingRpcHijackingInterceptorFactory()));

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeAsyncCQServerStreamingCall(channel);
  EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
1060
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
  // A lone hijacking interceptor on a bidirectional-streaming call; the
  // checks live inside MakeBidiStreamingCall and the interceptor itself.
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;

  ChannelArguments channel_args;
  DummyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.push_back(
      FactoryPtr(new BidiStreamingRpcHijackingInterceptorFactory()));

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeBidiStreamingCall(channel);
}
1073
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
  // Logging interceptor plus a run of dummies on a bidi-streaming RPC.
  const int kNumDummies = 20;
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;

  ChannelArguments channel_args;
  DummyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.push_back(FactoryPtr(new LoggingInterceptorFactory()));
  for (int n = 0; n < kNumDummies; ++n) {
    factories.push_back(FactoryPtr(new DummyInterceptorFactory()));
  }

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeBidiStreamingCall(channel);
  LoggingInterceptor::VerifyBidiStreamingCall();
  // The dummy run counter should match the number of dummies installed.
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummies);
}
1093
1094 class ClientGlobalInterceptorEnd2endTest : public ::testing::Test {
1095 protected:
ClientGlobalInterceptorEnd2endTest()1096 ClientGlobalInterceptorEnd2endTest() {
1097 int port = grpc_pick_unused_port_or_die();
1098
1099 ServerBuilder builder;
1100 server_address_ = "localhost:" + std::to_string(port);
1101 builder.AddListeningPort(server_address_, InsecureServerCredentials());
1102 builder.RegisterService(&service_);
1103 server_ = builder.BuildAndStart();
1104 }
1105
~ClientGlobalInterceptorEnd2endTest()1106 ~ClientGlobalInterceptorEnd2endTest() { server_->Shutdown(); }
1107
1108 std::string server_address_;
1109 TestServiceImpl service_;
1110 std::unique_ptr<Server> server_;
1111 };
1112
TEST_F(ClientGlobalInterceptorEnd2endTest, DummyGlobalInterceptor) {
  // A global interceptor should normally be registered once per process; it
  // is swapped here only because no gRPC operations are in flight at this
  // point in the test.
  DummyInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);

  const int kNumDummies = 20;
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  DummyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.reserve(kNumDummies);
  for (int n = 0; n < kNumDummies; ++n) {
    factories.push_back(FactoryPtr(new DummyInterceptorFactory()));
  }

  ChannelArguments channel_args;
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(channel);
  // Every per-channel dummy plus the global dummy interceptor should run.
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummies + 1);
  // Unregister before global_factory goes out of scope.
  experimental::TestOnlyResetGlobalClientInterceptorFactory();
}
1136
TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) {
  // A global interceptor should normally be registered once per process; it
  // is swapped here only because no gRPC operations are in flight at this
  // point in the test.
  LoggingInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);

  const int kNumDummies = 20;
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  DummyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.reserve(kNumDummies);
  for (int n = 0; n < kNumDummies; ++n) {
    factories.push_back(FactoryPtr(new DummyInterceptorFactory()));
  }

  ChannelArguments channel_args;
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(channel);
  LoggingInterceptor::VerifyUnaryCall();
  // The dummy run counter should match the number of per-channel dummies.
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummies);
  // Unregister before global_factory goes out of scope.
  experimental::TestOnlyResetGlobalClientInterceptorFactory();
}
1161
TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) {
  // A global interceptor should normally be registered once per process; it
  // is swapped here only because no gRPC operations are in flight at this
  // point in the test.
  HijackingInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);

  const int kNumDummies = 20;
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  DummyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.reserve(kNumDummies);
  for (int n = 0; n < kNumDummies; ++n) {
    factories.push_back(FactoryPtr(new DummyInterceptorFactory()));
  }

  ChannelArguments channel_args;
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(channel);
  // All per-channel dummies are expected to run before the global hijacking
  // interceptor stops the chain.
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummies);
  // Unregister before global_factory goes out of scope.
  experimental::TestOnlyResetGlobalClientInterceptorFactory();
}
1185
1186 } // namespace
1187 } // namespace testing
1188 } // namespace grpc
1189
main(int argc,char ** argv)1190 int main(int argc, char** argv) {
1191 grpc::testing::TestEnvironment env(argc, argv);
1192 ::testing::InitGoogleTest(&argc, argv);
1193 return RUN_ALL_TESTS();
1194 }
1195