1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include <grpcpp/channel.h>
20 #include <grpcpp/client_context.h>
21 #include <grpcpp/create_channel.h>
22 #include <grpcpp/create_channel_posix.h>
23 #include <grpcpp/generic/generic_stub.h>
24 #include <grpcpp/impl/proto_utils.h>
25 #include <grpcpp/server.h>
26 #include <grpcpp/server_builder.h>
27 #include <grpcpp/server_context.h>
28 #include <grpcpp/server_posix.h>
29 #include <grpcpp/support/client_interceptor.h>
30 #include <gtest/gtest.h>
31
32 #include <memory>
33 #include <vector>
34
35 #include "absl/log/check.h"
36 #include "absl/memory/memory.h"
37 #include "src/core/lib/iomgr/port.h"
38 #include "src/proto/grpc/testing/echo.grpc.pb.h"
39 #include "test/core/test_util/port.h"
40 #include "test/core/test_util/test_config.h"
41 #include "test/cpp/end2end/interceptors_util.h"
42 #include "test/cpp/end2end/test_service_impl.h"
43 #include "test/cpp/util/byte_buffer_proto_helper.h"
44 #include "test/cpp/util/string_ref_helper.h"
45
46 #ifdef GRPC_POSIX_SOCKET
47 #include <fcntl.h>
48
49 #include "src/core/lib/iomgr/socket_utils_posix.h"
50 #endif // GRPC_POSIX_SOCKET
51
52 namespace grpc {
53 namespace testing {
54 namespace {
55
// The kind of RPC a parameterized test scenario issues. Enumerator values are
// spelled out explicitly; they match the implicit 0..7 numbering.
enum class RPCType {
  // Synchronous stub API.
  kSyncUnary = 0,
  kSyncClientStreaming = 1,
  kSyncServerStreaming = 2,
  kSyncBidiStreaming = 3,
  // Completion-queue based asynchronous API.
  kAsyncCQUnary = 4,
  kAsyncCQClientStreaming = 5,
  kAsyncCQServerStreaming = 6,
  kAsyncCQBidiStreaming = 7,
};
66
// How the client channel under test is connected to the server.
enum class ChannelType {
  kHttpChannel = 0,  // Regular channel to a "localhost:<port>" address.
  kFdChannel = 1,    // Channel over a pre-connected file descriptor.
};
71
72 // Hijacks Echo RPC and fills in the expected values
73 class HijackingInterceptor : public experimental::Interceptor {
74 public:
HijackingInterceptor(experimental::ClientRpcInfo * info)75 explicit HijackingInterceptor(experimental::ClientRpcInfo* info) {
76 info_ = info;
77 // Make sure it is the right method
78 EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
79 EXPECT_EQ(info->suffix_for_stats(), nullptr);
80 EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY);
81 }
82
Intercept(experimental::InterceptorBatchMethods * methods)83 void Intercept(experimental::InterceptorBatchMethods* methods) override {
84 bool hijack = false;
85 if (methods->QueryInterceptionHookPoint(
86 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
87 auto* map = methods->GetSendInitialMetadata();
88 // Check that we can see the test metadata
89 ASSERT_EQ(map->size(), 1);
90 auto iterator = map->begin();
91 EXPECT_EQ("testkey", iterator->first);
92 EXPECT_EQ("testvalue", iterator->second);
93 hijack = true;
94 }
95 if (methods->QueryInterceptionHookPoint(
96 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
97 EchoRequest req;
98 auto* buffer = methods->GetSerializedSendMessage();
99 auto copied_buffer = *buffer;
100 EXPECT_TRUE(
101 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
102 .ok());
103 EXPECT_EQ(req.message(), "Hello");
104 }
105 if (methods->QueryInterceptionHookPoint(
106 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
107 // Got nothing to do here for now
108 }
109 if (methods->QueryInterceptionHookPoint(
110 experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
111 auto* map = methods->GetRecvInitialMetadata();
112 // Got nothing better to do here for now
113 EXPECT_EQ(map->size(), 0);
114 }
115 if (methods->QueryInterceptionHookPoint(
116 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
117 EchoResponse* resp =
118 static_cast<EchoResponse*>(methods->GetRecvMessage());
119 // Check that we got the hijacked message, and re-insert the expected
120 // message
121 EXPECT_EQ(resp->message(), "Hello1");
122 resp->set_message("Hello");
123 }
124 if (methods->QueryInterceptionHookPoint(
125 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
126 auto* map = methods->GetRecvTrailingMetadata();
127 bool found = false;
128 // Check that we received the metadata as an echo
129 for (const auto& pair : *map) {
130 found = pair.first.starts_with("testkey") &&
131 pair.second.starts_with("testvalue");
132 if (found) break;
133 }
134 EXPECT_EQ(found, true);
135 auto* status = methods->GetRecvStatus();
136 EXPECT_EQ(status->ok(), true);
137 }
138 if (methods->QueryInterceptionHookPoint(
139 experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
140 auto* map = methods->GetRecvInitialMetadata();
141 // Got nothing better to do here at the moment
142 EXPECT_EQ(map->size(), 0);
143 }
144 if (methods->QueryInterceptionHookPoint(
145 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
146 // Insert a different message than expected
147 EchoResponse* resp =
148 static_cast<EchoResponse*>(methods->GetRecvMessage());
149 resp->set_message("Hello1");
150 }
151 if (methods->QueryInterceptionHookPoint(
152 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
153 auto* map = methods->GetRecvTrailingMetadata();
154 // insert the metadata that we want
155 EXPECT_EQ(map->size(), 0);
156 map->insert(std::make_pair("testkey", "testvalue"));
157 auto* status = methods->GetRecvStatus();
158 *status = Status(StatusCode::OK, "");
159 }
160 if (hijack) {
161 methods->Hijack();
162 } else {
163 methods->Proceed();
164 }
165 }
166
167 private:
168 experimental::ClientRpcInfo* info_;
169 };
170
171 class HijackingInterceptorFactory
172 : public experimental::ClientInterceptorFactoryInterface {
173 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)174 experimental::Interceptor* CreateClientInterceptor(
175 experimental::ClientRpcInfo* info) override {
176 return new HijackingInterceptor(info);
177 }
178 };
179
180 class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
181 public:
HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo * info)182 explicit HijackingInterceptorMakesAnotherCall(
183 experimental::ClientRpcInfo* info) {
184 info_ = info;
185 // Make sure it is the right method
186 EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
187 EXPECT_EQ(strcmp("TestSuffixForStats", info->suffix_for_stats()), 0);
188 }
189
Intercept(experimental::InterceptorBatchMethods * methods)190 void Intercept(experimental::InterceptorBatchMethods* methods) override {
191 if (methods->QueryInterceptionHookPoint(
192 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
193 auto* map = methods->GetSendInitialMetadata();
194 // Check that we can see the test metadata
195 ASSERT_EQ(map->size(), 1);
196 auto iterator = map->begin();
197 EXPECT_EQ("testkey", iterator->first);
198 EXPECT_EQ("testvalue", iterator->second);
199 // Make a copy of the map
200 metadata_map_ = *map;
201 }
202 if (methods->QueryInterceptionHookPoint(
203 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
204 EchoRequest req;
205 auto* buffer = methods->GetSerializedSendMessage();
206 auto copied_buffer = *buffer;
207 EXPECT_TRUE(
208 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
209 .ok());
210 EXPECT_EQ(req.message(), "Hello");
211 req_ = req;
212 stub_ = grpc::testing::EchoTestService::NewStub(
213 methods->GetInterceptedChannel());
214 ctx_.AddMetadata(metadata_map_.begin()->first,
215 metadata_map_.begin()->second);
216 stub_->async()->Echo(&ctx_, &req_, &resp_, [this, methods](Status s) {
217 EXPECT_EQ(s.ok(), true);
218 EXPECT_EQ(resp_.message(), "Hello");
219 methods->Hijack();
220 });
221 // This is a Unary RPC and we have got nothing interesting to do in the
222 // PRE_SEND_CLOSE interception hook point for this interceptor, so let's
223 // return here. (We do not want to call methods->Proceed(). When the new
224 // RPC returns, we will call methods->Hijack() instead.)
225 return;
226 }
227 if (methods->QueryInterceptionHookPoint(
228 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
229 // Got nothing to do here for now
230 }
231 if (methods->QueryInterceptionHookPoint(
232 experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
233 auto* map = methods->GetRecvInitialMetadata();
234 // Got nothing better to do here for now
235 EXPECT_EQ(map->size(), 0);
236 }
237 if (methods->QueryInterceptionHookPoint(
238 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
239 EchoResponse* resp =
240 static_cast<EchoResponse*>(methods->GetRecvMessage());
241 // Check that we got the hijacked message, and re-insert the expected
242 // message
243 EXPECT_EQ(resp->message(), "Hello");
244 }
245 if (methods->QueryInterceptionHookPoint(
246 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
247 auto* map = methods->GetRecvTrailingMetadata();
248 bool found = false;
249 // Check that we received the metadata as an echo
250 for (const auto& pair : *map) {
251 found = pair.first.starts_with("testkey") &&
252 pair.second.starts_with("testvalue");
253 if (found) break;
254 }
255 EXPECT_EQ(found, true);
256 auto* status = methods->GetRecvStatus();
257 EXPECT_EQ(status->ok(), true);
258 }
259 if (methods->QueryInterceptionHookPoint(
260 experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
261 auto* map = methods->GetRecvInitialMetadata();
262 // Got nothing better to do here at the moment
263 EXPECT_EQ(map->size(), 0);
264 }
265 if (methods->QueryInterceptionHookPoint(
266 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
267 // Insert a different message than expected
268 EchoResponse* resp =
269 static_cast<EchoResponse*>(methods->GetRecvMessage());
270 resp->set_message(resp_.message());
271 }
272 if (methods->QueryInterceptionHookPoint(
273 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
274 auto* map = methods->GetRecvTrailingMetadata();
275 // insert the metadata that we want
276 EXPECT_EQ(map->size(), 0);
277 map->insert(std::make_pair("testkey", "testvalue"));
278 auto* status = methods->GetRecvStatus();
279 *status = Status(StatusCode::OK, "");
280 }
281
282 methods->Proceed();
283 }
284
285 private:
286 experimental::ClientRpcInfo* info_;
287 std::multimap<std::string, std::string> metadata_map_;
288 ClientContext ctx_;
289 EchoRequest req_;
290 EchoResponse resp_;
291 std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
292 };
293
294 class HijackingInterceptorMakesAnotherCallFactory
295 : public experimental::ClientInterceptorFactoryInterface {
296 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)297 experimental::Interceptor* CreateClientInterceptor(
298 experimental::ClientRpcInfo* info) override {
299 return new HijackingInterceptorMakesAnotherCall(info);
300 }
301 };
302
303 class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
304 public:
BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)305 explicit BidiStreamingRpcHijackingInterceptor(
306 experimental::ClientRpcInfo* info) {
307 info_ = info;
308 EXPECT_EQ(info->suffix_for_stats(), nullptr);
309 }
310
Intercept(experimental::InterceptorBatchMethods * methods)311 void Intercept(experimental::InterceptorBatchMethods* methods) override {
312 bool hijack = false;
313 if (methods->QueryInterceptionHookPoint(
314 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
315 CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
316 hijack = true;
317 }
318 if (methods->QueryInterceptionHookPoint(
319 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
320 EchoRequest req;
321 auto* buffer = methods->GetSerializedSendMessage();
322 auto copied_buffer = *buffer;
323 EXPECT_TRUE(
324 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
325 .ok());
326 EXPECT_EQ(req.message().find("Hello"), 0u);
327 msg = req.message();
328 }
329 if (methods->QueryInterceptionHookPoint(
330 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
331 // Got nothing to do here for now
332 }
333 if (methods->QueryInterceptionHookPoint(
334 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
335 CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
336 "testvalue");
337 auto* status = methods->GetRecvStatus();
338 EXPECT_EQ(status->ok(), true);
339 }
340 if (methods->QueryInterceptionHookPoint(
341 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
342 EchoResponse* resp =
343 static_cast<EchoResponse*>(methods->GetRecvMessage());
344 resp->set_message(msg);
345 }
346 if (methods->QueryInterceptionHookPoint(
347 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
348 EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
349 ->message()
350 .find("Hello"),
351 0u);
352 }
353 if (methods->QueryInterceptionHookPoint(
354 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
355 auto* map = methods->GetRecvTrailingMetadata();
356 // insert the metadata that we want
357 EXPECT_EQ(map->size(), 0);
358 map->insert(std::make_pair("testkey", "testvalue"));
359 auto* status = methods->GetRecvStatus();
360 *status = Status(StatusCode::OK, "");
361 }
362 if (hijack) {
363 methods->Hijack();
364 } else {
365 methods->Proceed();
366 }
367 }
368
369 private:
370 experimental::ClientRpcInfo* info_;
371 std::string msg;
372 };
373
374 class ClientStreamingRpcHijackingInterceptor
375 : public experimental::Interceptor {
376 public:
ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)377 explicit ClientStreamingRpcHijackingInterceptor(
378 experimental::ClientRpcInfo* info) {
379 info_ = info;
380 EXPECT_EQ(
381 strcmp("/grpc.testing.EchoTestService/RequestStream", info->method()),
382 0);
383 EXPECT_EQ(strcmp("TestSuffixForStats", info->suffix_for_stats()), 0);
384 }
Intercept(experimental::InterceptorBatchMethods * methods)385 void Intercept(experimental::InterceptorBatchMethods* methods) override {
386 bool hijack = false;
387 if (methods->QueryInterceptionHookPoint(
388 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
389 hijack = true;
390 }
391 if (methods->QueryInterceptionHookPoint(
392 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
393 if (++count_ > 10) {
394 methods->FailHijackedSendMessage();
395 }
396 }
397 if (methods->QueryInterceptionHookPoint(
398 experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
399 EXPECT_FALSE(got_failed_send_);
400 got_failed_send_ = !methods->GetSendMessageStatus();
401 }
402 if (methods->QueryInterceptionHookPoint(
403 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
404 auto* status = methods->GetRecvStatus();
405 *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
406 }
407 if (hijack) {
408 methods->Hijack();
409 } else {
410 methods->Proceed();
411 }
412 }
413
GotFailedSend()414 static bool GotFailedSend() { return got_failed_send_; }
415
416 private:
417 experimental::ClientRpcInfo* info_;
418 int count_ = 0;
419 static bool got_failed_send_;
420 };
421
422 bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
423
424 class ClientStreamingRpcHijackingInterceptorFactory
425 : public experimental::ClientInterceptorFactoryInterface {
426 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)427 experimental::Interceptor* CreateClientInterceptor(
428 experimental::ClientRpcInfo* info) override {
429 return new ClientStreamingRpcHijackingInterceptor(info);
430 }
431 };
432
433 class ServerStreamingRpcHijackingInterceptor
434 : public experimental::Interceptor {
435 public:
ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)436 explicit ServerStreamingRpcHijackingInterceptor(
437 experimental::ClientRpcInfo* info) {
438 info_ = info;
439 got_failed_message_ = false;
440 EXPECT_EQ(info->suffix_for_stats(), nullptr);
441 }
442
Intercept(experimental::InterceptorBatchMethods * methods)443 void Intercept(experimental::InterceptorBatchMethods* methods) override {
444 bool hijack = false;
445 if (methods->QueryInterceptionHookPoint(
446 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
447 auto* map = methods->GetSendInitialMetadata();
448 // Check that we can see the test metadata
449 ASSERT_EQ(map->size(), 1);
450 auto iterator = map->begin();
451 EXPECT_EQ("testkey", iterator->first);
452 EXPECT_EQ("testvalue", iterator->second);
453 hijack = true;
454 }
455 if (methods->QueryInterceptionHookPoint(
456 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
457 EchoRequest req;
458 auto* buffer = methods->GetSerializedSendMessage();
459 auto copied_buffer = *buffer;
460 EXPECT_TRUE(
461 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
462 .ok());
463 EXPECT_EQ(req.message(), "Hello");
464 }
465 if (methods->QueryInterceptionHookPoint(
466 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
467 // Got nothing to do here for now
468 }
469 if (methods->QueryInterceptionHookPoint(
470 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
471 auto* map = methods->GetRecvTrailingMetadata();
472 bool found = false;
473 // Check that we received the metadata as an echo
474 for (const auto& pair : *map) {
475 found = pair.first.starts_with("testkey") &&
476 pair.second.starts_with("testvalue");
477 if (found) break;
478 }
479 EXPECT_EQ(found, true);
480 auto* status = methods->GetRecvStatus();
481 EXPECT_EQ(status->ok(), true);
482 }
483 if (methods->QueryInterceptionHookPoint(
484 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
485 if (++count_ > 10) {
486 methods->FailHijackedRecvMessage();
487 }
488 EchoResponse* resp =
489 static_cast<EchoResponse*>(methods->GetRecvMessage());
490 resp->set_message("Hello");
491 }
492 if (methods->QueryInterceptionHookPoint(
493 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
494 // Only the last message will be a failure
495 EXPECT_FALSE(got_failed_message_);
496 got_failed_message_ = methods->GetRecvMessage() == nullptr;
497 }
498 if (methods->QueryInterceptionHookPoint(
499 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
500 auto* map = methods->GetRecvTrailingMetadata();
501 // insert the metadata that we want
502 EXPECT_EQ(map->size(), 0);
503 map->insert(std::make_pair("testkey", "testvalue"));
504 auto* status = methods->GetRecvStatus();
505 *status = Status(StatusCode::OK, "");
506 }
507 if (hijack) {
508 methods->Hijack();
509 } else {
510 methods->Proceed();
511 }
512 }
513
GotFailedMessage()514 static bool GotFailedMessage() { return got_failed_message_; }
515
516 private:
517 experimental::ClientRpcInfo* info_;
518 static bool got_failed_message_;
519 int count_ = 0;
520 };
521
522 bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
523
524 class ServerStreamingRpcHijackingInterceptorFactory
525 : public experimental::ClientInterceptorFactoryInterface {
526 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)527 experimental::Interceptor* CreateClientInterceptor(
528 experimental::ClientRpcInfo* info) override {
529 return new ServerStreamingRpcHijackingInterceptor(info);
530 }
531 };
532
533 class BidiStreamingRpcHijackingInterceptorFactory
534 : public experimental::ClientInterceptorFactoryInterface {
535 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)536 experimental::Interceptor* CreateClientInterceptor(
537 experimental::ClientRpcInfo* info) override {
538 return new BidiStreamingRpcHijackingInterceptor(info);
539 }
540 };
541
542 // The logging interceptor is for testing purposes only. It is used to verify
543 // that all the appropriate hook points are invoked for an RPC. The counts are
544 // reset each time a new object of LoggingInterceptor is created, so only a
545 // single RPC should be made on the channel before calling the Verify methods.
546 class LoggingInterceptor : public experimental::Interceptor {
547 public:
LoggingInterceptor(experimental::ClientRpcInfo *)548 explicit LoggingInterceptor(experimental::ClientRpcInfo* /*info*/) {
549 pre_send_initial_metadata_ = false;
550 pre_send_message_count_ = 0;
551 pre_send_close_ = false;
552 post_recv_initial_metadata_ = false;
553 post_recv_message_count_ = 0;
554 post_recv_status_ = false;
555 }
556
Intercept(experimental::InterceptorBatchMethods * methods)557 void Intercept(experimental::InterceptorBatchMethods* methods) override {
558 if (methods->QueryInterceptionHookPoint(
559 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
560 auto* map = methods->GetSendInitialMetadata();
561 // Check that we can see the test metadata
562 ASSERT_EQ(map->size(), 1);
563 auto iterator = map->begin();
564 EXPECT_EQ("testkey", iterator->first);
565 EXPECT_EQ("testvalue", iterator->second);
566 ASSERT_FALSE(pre_send_initial_metadata_);
567 pre_send_initial_metadata_ = true;
568 }
569 if (methods->QueryInterceptionHookPoint(
570 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
571 EchoRequest req;
572 auto* send_msg = methods->GetSendMessage();
573 if (send_msg == nullptr) {
574 // We did not get the non-serialized form of the message. Get the
575 // serialized form.
576 auto* buffer = methods->GetSerializedSendMessage();
577 auto copied_buffer = *buffer;
578 EchoRequest req;
579 EXPECT_TRUE(
580 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
581 .ok());
582 EXPECT_EQ(req.message(), "Hello");
583 } else {
584 EXPECT_EQ(
585 static_cast<const EchoRequest*>(send_msg)->message().find("Hello"),
586 0u);
587 }
588 auto* buffer = methods->GetSerializedSendMessage();
589 auto copied_buffer = *buffer;
590 EXPECT_TRUE(
591 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
592 .ok());
593 EXPECT_TRUE(req.message().find("Hello") == 0u);
594 pre_send_message_count_++;
595 }
596 if (methods->QueryInterceptionHookPoint(
597 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
598 // Got nothing to do here for now
599 pre_send_close_ = true;
600 }
601 if (methods->QueryInterceptionHookPoint(
602 experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
603 auto* map = methods->GetRecvInitialMetadata();
604 // Got nothing better to do here for now
605 EXPECT_EQ(map->size(), 0);
606 post_recv_initial_metadata_ = true;
607 }
608 if (methods->QueryInterceptionHookPoint(
609 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
610 EchoResponse* resp =
611 static_cast<EchoResponse*>(methods->GetRecvMessage());
612 if (resp != nullptr) {
613 EXPECT_TRUE(resp->message().find("Hello") == 0u);
614 post_recv_message_count_++;
615 }
616 }
617 if (methods->QueryInterceptionHookPoint(
618 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
619 auto* map = methods->GetRecvTrailingMetadata();
620 bool found = false;
621 // Check that we received the metadata as an echo
622 for (const auto& pair : *map) {
623 found = pair.first.starts_with("testkey") &&
624 pair.second.starts_with("testvalue");
625 if (found) break;
626 }
627 EXPECT_EQ(found, true);
628 auto* status = methods->GetRecvStatus();
629 EXPECT_EQ(status->ok(), true);
630 post_recv_status_ = true;
631 }
632 methods->Proceed();
633 }
634
VerifyCall(RPCType type)635 static void VerifyCall(RPCType type) {
636 switch (type) {
637 case RPCType::kSyncUnary:
638 case RPCType::kAsyncCQUnary:
639 VerifyUnaryCall();
640 break;
641 case RPCType::kSyncClientStreaming:
642 case RPCType::kAsyncCQClientStreaming:
643 VerifyClientStreamingCall();
644 break;
645 case RPCType::kSyncServerStreaming:
646 case RPCType::kAsyncCQServerStreaming:
647 VerifyServerStreamingCall();
648 break;
649 case RPCType::kSyncBidiStreaming:
650 case RPCType::kAsyncCQBidiStreaming:
651 VerifyBidiStreamingCall();
652 break;
653 }
654 }
655
VerifyCallCommon()656 static void VerifyCallCommon() {
657 EXPECT_TRUE(pre_send_initial_metadata_);
658 EXPECT_TRUE(pre_send_close_);
659 EXPECT_TRUE(post_recv_initial_metadata_);
660 EXPECT_TRUE(post_recv_status_);
661 }
662
VerifyUnaryCall()663 static void VerifyUnaryCall() {
664 VerifyCallCommon();
665 EXPECT_EQ(pre_send_message_count_, 1);
666 EXPECT_EQ(post_recv_message_count_, 1);
667 }
668
VerifyClientStreamingCall()669 static void VerifyClientStreamingCall() {
670 VerifyCallCommon();
671 EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
672 EXPECT_EQ(post_recv_message_count_, 1);
673 }
674
VerifyServerStreamingCall()675 static void VerifyServerStreamingCall() {
676 VerifyCallCommon();
677 EXPECT_EQ(pre_send_message_count_, 1);
678 EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
679 }
680
VerifyBidiStreamingCall()681 static void VerifyBidiStreamingCall() {
682 VerifyCallCommon();
683 EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
684 EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
685 }
686
687 private:
688 static bool pre_send_initial_metadata_;
689 static int pre_send_message_count_;
690 static bool pre_send_close_;
691 static bool post_recv_initial_metadata_;
692 static int post_recv_message_count_;
693 static bool post_recv_status_;
694 };
695
696 bool LoggingInterceptor::pre_send_initial_metadata_;
697 int LoggingInterceptor::pre_send_message_count_;
698 bool LoggingInterceptor::pre_send_close_;
699 bool LoggingInterceptor::post_recv_initial_metadata_;
700 int LoggingInterceptor::post_recv_message_count_;
701 bool LoggingInterceptor::post_recv_status_;
702
703 class LoggingInterceptorFactory
704 : public experimental::ClientInterceptorFactoryInterface {
705 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)706 experimental::Interceptor* CreateClientInterceptor(
707 experimental::ClientRpcInfo* info) override {
708 return new LoggingInterceptor(info);
709 }
710 };
711
712 class TestScenario {
713 public:
TestScenario(const ChannelType & channel_type,const RPCType & rpc_type)714 explicit TestScenario(const ChannelType& channel_type,
715 const RPCType& rpc_type)
716 : channel_type_(channel_type), rpc_type_(rpc_type) {}
717
channel_type() const718 ChannelType channel_type() const { return channel_type_; }
719
rpc_type() const720 RPCType rpc_type() const { return rpc_type_; }
721
722 private:
723 const ChannelType channel_type_;
724 const RPCType rpc_type_;
725 };
726
CreateTestScenarios()727 std::vector<TestScenario> CreateTestScenarios() {
728 std::vector<TestScenario> scenarios;
729 std::vector<RPCType> rpc_types;
730 rpc_types.emplace_back(RPCType::kSyncUnary);
731 rpc_types.emplace_back(RPCType::kSyncClientStreaming);
732 rpc_types.emplace_back(RPCType::kSyncServerStreaming);
733 rpc_types.emplace_back(RPCType::kSyncBidiStreaming);
734 rpc_types.emplace_back(RPCType::kAsyncCQUnary);
735 rpc_types.emplace_back(RPCType::kAsyncCQServerStreaming);
736 for (const auto& rpc_type : rpc_types) {
737 scenarios.emplace_back(ChannelType::kHttpChannel, rpc_type);
738 // TODO(yashykt): Maybe add support for non-posix sockets too
739 #ifdef GRPC_POSIX_SOCKET
740 scenarios.emplace_back(ChannelType::kFdChannel, rpc_type);
741 #endif // GRPC_POSIX_SOCKET
742 }
743 return scenarios;
744 }
745
746 class ParameterizedClientInterceptorsEnd2endTest
747 : public ::testing::TestWithParam<TestScenario> {
748 protected:
ParameterizedClientInterceptorsEnd2endTest()749 ParameterizedClientInterceptorsEnd2endTest() {
750 ServerBuilder builder;
751 builder.RegisterService(&service_);
752 if (GetParam().channel_type() == ChannelType::kHttpChannel) {
753 int port = grpc_pick_unused_port_or_die();
754 server_address_ = "localhost:" + std::to_string(port);
755 builder.AddListeningPort(server_address_, InsecureServerCredentials());
756 server_ = builder.BuildAndStart();
757 }
758 #ifdef GRPC_POSIX_SOCKET
759 else if (GetParam().channel_type() == ChannelType::kFdChannel) {
760 int flags;
761 CHECK_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, sv_), 0);
762 flags = fcntl(sv_[0], F_GETFL, 0);
763 CHECK_EQ(fcntl(sv_[0], F_SETFL, flags | O_NONBLOCK), 0);
764 flags = fcntl(sv_[1], F_GETFL, 0);
765 CHECK_EQ(fcntl(sv_[1], F_SETFL, flags | O_NONBLOCK), 0);
766 CHECK(grpc_set_socket_no_sigpipe_if_possible(sv_[0]) == absl::OkStatus());
767 CHECK(grpc_set_socket_no_sigpipe_if_possible(sv_[1]) == absl::OkStatus());
768 server_ = builder.BuildAndStart();
769 AddInsecureChannelFromFd(server_.get(), sv_[1]);
770 }
771 #endif // GRPC_POSIX_SOCKET
772 }
773
~ParameterizedClientInterceptorsEnd2endTest()774 ~ParameterizedClientInterceptorsEnd2endTest() override {
775 server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0));
776 }
777
CreateClientChannel(std::vector<std::unique_ptr<grpc::experimental::ClientInterceptorFactoryInterface>> creators)778 std::shared_ptr<grpc::Channel> CreateClientChannel(
779 std::vector<std::unique_ptr<
780 grpc::experimental::ClientInterceptorFactoryInterface>>
781 creators) {
782 if (GetParam().channel_type() == ChannelType::kHttpChannel) {
783 return experimental::CreateCustomChannelWithInterceptors(
784 server_address_, InsecureChannelCredentials(), ChannelArguments(),
785 std::move(creators));
786 }
787 #ifdef GRPC_POSIX_SOCKET
788 else if (GetParam().channel_type() == ChannelType::kFdChannel) {
789 return experimental::CreateCustomInsecureChannelWithInterceptorsFromFd(
790 "", sv_[0], ChannelArguments(), std::move(creators));
791 }
792 #endif // GRPC_POSIX_SOCKET
793 return nullptr;
794 }
795
SendRPC(const std::shared_ptr<Channel> & channel)796 void SendRPC(const std::shared_ptr<Channel>& channel) {
797 switch (GetParam().rpc_type()) {
798 case RPCType::kSyncUnary:
799 MakeCall(channel);
800 break;
801 case RPCType::kSyncClientStreaming:
802 MakeClientStreamingCall(channel);
803 break;
804 case RPCType::kSyncServerStreaming:
805 MakeServerStreamingCall(channel);
806 break;
807 case RPCType::kSyncBidiStreaming:
808 MakeBidiStreamingCall(channel);
809 break;
810 case RPCType::kAsyncCQUnary:
811 MakeAsyncCQCall(channel);
812 break;
813 case RPCType::kAsyncCQClientStreaming:
814 // TODO(yashykt) : Fill this out
815 break;
816 case RPCType::kAsyncCQServerStreaming:
817 MakeAsyncCQServerStreamingCall(channel);
818 break;
819 case RPCType::kAsyncCQBidiStreaming:
820 // TODO(yashykt) : Fill this out
821 break;
822 }
823 }
824
825 std::string server_address_;
826 int sv_[2];
827 EchoTestServiceStreamingImpl service_;
828 std::unique_ptr<Server> server_;
829 };
830
// Verifies that every expected hook point fires for each scenario's RPC shape.
// It also checks that interceptors placed after the logging interceptor all
// run.
TEST_P(ParameterizedClientInterceptorsEnd2endTest,
       ClientInterceptorLoggingTest) {
  constexpr int kNumPhonyInterceptors = 20;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.reserve(kNumPhonyInterceptors + 1);
  creators.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int i = 0; i < kNumPhonyInterceptors; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = CreateClientChannel(std::move(creators));
  SendRPC(channel);
  LoggingInterceptor::VerifyCall(GetParam().rpc_type());
  // The logging interceptor always proceeds, so every phony interceptor runs.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyInterceptors);
}
848
849 INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end,
850 ParameterizedClientInterceptorsEnd2endTest,
851 ::testing::ValuesIn(CreateTestScenarios()));
852
853 class ClientInterceptorsEnd2endTest
854 : public ::testing::TestWithParam<TestScenario> {
855 protected:
ClientInterceptorsEnd2endTest()856 ClientInterceptorsEnd2endTest() {
857 int port = grpc_pick_unused_port_or_die();
858
859 ServerBuilder builder;
860 server_address_ = "localhost:" + std::to_string(port);
861 builder.AddListeningPort(server_address_, InsecureServerCredentials());
862 builder.RegisterService(&service_);
863 server_ = builder.BuildAndStart();
864 }
865
~ClientInterceptorsEnd2endTest()866 ~ClientInterceptorsEnd2endTest() override { server_->Shutdown(); }
867
868 std::string server_address_;
869 TestServiceImpl service_;
870 std::unique_ptr<Server> server_;
871 };
872
// Makes a unary call through a hijacking interceptor on a channel created with
// null credentials, which makes it a lame channel (hence the test name).
TEST_F(ClientInterceptorsEnd2endTest,
       LameChannelClientInterceptorHijackingTest) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  std::vector<FactoryPtr> interceptor_factories;
  interceptor_factories.emplace_back(
      std::make_unique<HijackingInterceptorFactory>());
  ChannelArguments channel_args;
  // Passing nullptr instead of channel credentials.
  auto lame_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, nullptr, channel_args,
      std::move(interceptor_factories));
  MakeCall(lame_channel);
}
883
TEST_F(ClientInterceptorsEnd2endTest,ClientInterceptorHijackingTest)884 TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) {
885 ChannelArguments args;
886 PhonyInterceptor::Reset();
887 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
888 creators;
889 // Add 20 phony interceptors before hijacking interceptor
890 creators.reserve(20);
891 for (auto i = 0; i < 20; i++) {
892 creators.push_back(std::make_unique<PhonyInterceptorFactory>());
893 }
894 creators.push_back(std::make_unique<HijackingInterceptorFactory>());
895 // Add 20 phony interceptors after hijacking interceptor
896 for (auto i = 0; i < 20; i++) {
897 creators.push_back(std::make_unique<PhonyInterceptorFactory>());
898 }
899 auto channel = experimental::CreateCustomChannelWithInterceptors(
900 server_address_, InsecureChannelCredentials(), args, std::move(creators));
901 MakeCall(channel);
902 // Make sure only 20 phony interceptors were run
903 EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
904 }
905
TEST_F(ClientInterceptorsEnd2endTest,ClientInterceptorLogThenHijackTest)906 TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) {
907 ChannelArguments args;
908 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
909 creators;
910 creators.push_back(std::make_unique<LoggingInterceptorFactory>());
911 creators.push_back(std::make_unique<HijackingInterceptorFactory>());
912 auto channel = experimental::CreateCustomChannelWithInterceptors(
913 server_address_, InsecureChannelCredentials(), args, std::move(creators));
914 MakeCall(channel);
915 LoggingInterceptor::VerifyUnaryCall();
916 }
917
// Places a hijacking interceptor that issues its own RPC between two groups of
// phony interceptors on an in-process channel. It checks the total number of
// phony interceptor runs.
TEST_F(ClientInterceptorsEnd2endTest,
       ClientInterceptorHijackingMakesAnotherCallTest) {
  constexpr int kNumPhonyBeforeHijack = 5;
  constexpr int kNumPhonyAfterHijack = 7;
  ChannelArguments args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  // Room for every factory pushed below: phony group, hijacker, phony group.
  creators.reserve(kNumPhonyBeforeHijack + 1 + kNumPhonyAfterHijack);
  // Phony interceptors before the hijacking interceptor.
  for (int i = 0; i < kNumPhonyBeforeHijack; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  creators.push_back(
      std::make_unique<HijackingInterceptorMakesAnotherCallFactory>());
  // Phony interceptors after the hijacking interceptor.
  for (int i = 0; i < kNumPhonyAfterHijack; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = server_->experimental().InProcessChannelWithInterceptors(
      args, std::move(creators));

  MakeCall(channel, StubOptions("TestSuffixForStats"));
  // Make sure all phony interceptors were run once, since the hijacking
  // interceptor makes an RPC on the intercepted channel.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(),
            kNumPhonyBeforeHijack + kNumPhonyAfterHijack);
}
944
945 class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test {
946 protected:
ClientInterceptorsCallbackEnd2endTest()947 ClientInterceptorsCallbackEnd2endTest() {
948 int port = grpc_pick_unused_port_or_die();
949
950 ServerBuilder builder;
951 server_address_ = "localhost:" + std::to_string(port);
952 builder.AddListeningPort(server_address_, InsecureServerCredentials());
953 builder.RegisterService(&service_);
954 server_ = builder.BuildAndStart();
955 }
956
~ClientInterceptorsCallbackEnd2endTest()957 ~ClientInterceptorsCallbackEnd2endTest() override { server_->Shutdown(); }
958
959 std::string server_address_;
960 TestServiceImpl service_;
961 std::unique_ptr<Server> server_;
962 };
963
// Issues a call via MakeCallbackCall through a LoggingInterceptor followed by
// 20 phony interceptors on an in-process channel.
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorLoggingTestWithCallback) {
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.emplace_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < 20; ++n) {
    factories.emplace_back(std::make_unique<PhonyInterceptorFactory>());
  }
  ChannelArguments channel_args;
  auto in_process_channel =
      server_->experimental().InProcessChannelWithInterceptors(
          channel_args, std::move(factories));
  MakeCallbackCall(in_process_channel);
  LoggingInterceptor::VerifyUnaryCall();
  // Every one of the 20 phony interceptors is expected to have run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
}
982
// Same chain as the logging callback test, with a hijacking interceptor
// appended after the 20 phony interceptors.
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorHijackingTestWithCallback) {
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.emplace_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < 20; ++n) {
    factories.emplace_back(std::make_unique<PhonyInterceptorFactory>());
  }
  // The hijacker sits at the very end of the chain.
  factories.emplace_back(std::make_unique<HijackingInterceptorFactory>());
  ChannelArguments channel_args;
  auto in_process_channel =
      server_->experimental().InProcessChannelWithInterceptors(
          channel_args, std::move(factories));
  MakeCallbackCall(in_process_channel);
  LoggingInterceptor::VerifyUnaryCall();
  // All 20 phony interceptors sit ahead of the hijacker and should have run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
}
1002
// Interleaves 20 phony factories with 20 NullInterceptorFactory instances,
// which presumably return nullptr from CreateClientInterceptor (per the test
// name). It checks that the call still succeeds through the remaining
// interceptors.
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorFactoryAllowsNullptrReturn) {
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.emplace_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < 20; ++n) {
    // Alternate: one phony factory, then one null factory.
    factories.emplace_back(std::make_unique<PhonyInterceptorFactory>());
    factories.emplace_back(std::make_unique<NullInterceptorFactory>());
  }
  ChannelArguments channel_args;
  auto in_process_channel =
      server_->experimental().InProcessChannelWithInterceptors(
          channel_args, std::move(factories));
  MakeCallbackCall(in_process_channel);
  LoggingInterceptor::VerifyUnaryCall();
  // Only the 20 phony interceptors count toward the run total.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
}
1022
1023 class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test {
1024 protected:
ClientInterceptorsStreamingEnd2endTest()1025 ClientInterceptorsStreamingEnd2endTest() {
1026 int port = grpc_pick_unused_port_or_die();
1027
1028 ServerBuilder builder;
1029 server_address_ = "localhost:" + std::to_string(port);
1030 builder.AddListeningPort(server_address_, InsecureServerCredentials());
1031 builder.RegisterService(&service_);
1032 server_ = builder.BuildAndStart();
1033 }
1034
~ClientInterceptorsStreamingEnd2endTest()1035 ~ClientInterceptorsStreamingEnd2endTest() override { server_->Shutdown(); }
1036
1037 std::string server_address_;
1038 EchoTestServiceStreamingImpl service_;
1039 std::unique_ptr<Server> server_;
1040 };
1041
TEST_F(ClientInterceptorsStreamingEnd2endTest,ClientStreamingTest)1042 TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) {
1043 ChannelArguments args;
1044 PhonyInterceptor::Reset();
1045 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1046 creators;
1047 creators.push_back(std::make_unique<LoggingInterceptorFactory>());
1048 // Add 20 phony interceptors
1049 for (auto i = 0; i < 20; i++) {
1050 creators.push_back(std::make_unique<PhonyInterceptorFactory>());
1051 }
1052 auto channel = experimental::CreateCustomChannelWithInterceptors(
1053 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1054 MakeClientStreamingCall(channel);
1055 LoggingInterceptor::VerifyClientStreamingCall();
1056 // Make sure all 20 phony interceptors were run
1057 EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
1058 }
1059
TEST_F(ClientInterceptorsStreamingEnd2endTest,ServerStreamingTest)1060 TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
1061 ChannelArguments args;
1062 PhonyInterceptor::Reset();
1063 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1064 creators;
1065 creators.push_back(std::make_unique<LoggingInterceptorFactory>());
1066 // Add 20 phony interceptors
1067 for (auto i = 0; i < 20; i++) {
1068 creators.push_back(std::make_unique<PhonyInterceptorFactory>());
1069 }
1070 auto channel = experimental::CreateCustomChannelWithInterceptors(
1071 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1072 MakeServerStreamingCall(channel);
1073 LoggingInterceptor::VerifyServerStreamingCall();
1074 // Make sure all 20 phony interceptors were run
1075 EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
1076 }
1077
TEST_F(ClientInterceptorsStreamingEnd2endTest,ClientStreamingHijackingTest)1078 TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
1079 ChannelArguments args;
1080 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1081 creators;
1082 creators.push_back(
1083 std::make_unique<ClientStreamingRpcHijackingInterceptorFactory>());
1084 auto channel = experimental::CreateCustomChannelWithInterceptors(
1085 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1086
1087 auto stub = grpc::testing::EchoTestService::NewStub(
1088 channel, StubOptions("TestSuffixForStats"));
1089 ClientContext ctx;
1090 EchoRequest req;
1091 EchoResponse resp;
1092 req.mutable_param()->set_echo_metadata(true);
1093 req.set_message("Hello");
1094 string expected_resp;
1095 auto writer = stub->RequestStream(&ctx, &resp);
1096 for (int i = 0; i < 10; i++) {
1097 EXPECT_TRUE(writer->Write(req));
1098 expected_resp += "Hello";
1099 }
1100 // The interceptor will reject the 11th message
1101 writer->Write(req);
1102 Status s = writer->Finish();
1103 EXPECT_EQ(s.ok(), false);
1104 EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
1105 }
1106
TEST_F(ClientInterceptorsStreamingEnd2endTest,ServerStreamingHijackingTest)1107 TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
1108 ChannelArguments args;
1109 PhonyInterceptor::Reset();
1110 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1111 creators;
1112 creators.push_back(
1113 std::make_unique<ServerStreamingRpcHijackingInterceptorFactory>());
1114 auto channel = experimental::CreateCustomChannelWithInterceptors(
1115 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1116 MakeServerStreamingCall(channel);
1117 EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
1118 }
1119
// Completion-queue variant of ServerStreamingHijackingTest: uses
// MakeAsyncCQServerStreamingCall with the same hijacking interceptor.
TEST_F(ClientInterceptorsStreamingEnd2endTest,
       AsyncCQServerStreamingHijackingTest) {
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.emplace_back(
      std::make_unique<ServerStreamingRpcHijackingInterceptorFactory>());
  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeAsyncCQServerStreamingCall(intercepted_channel);
  EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
1133
TEST_F(ClientInterceptorsStreamingEnd2endTest,BidiStreamingHijackingTest)1134 TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
1135 ChannelArguments args;
1136 PhonyInterceptor::Reset();
1137 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1138 creators;
1139 creators.push_back(
1140 std::make_unique<BidiStreamingRpcHijackingInterceptorFactory>());
1141 auto channel = experimental::CreateCustomChannelWithInterceptors(
1142 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1143 MakeBidiStreamingCall(channel);
1144 }
1145
TEST_F(ClientInterceptorsStreamingEnd2endTest,BidiStreamingTest)1146 TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
1147 ChannelArguments args;
1148 PhonyInterceptor::Reset();
1149 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1150 creators;
1151 creators.push_back(std::make_unique<LoggingInterceptorFactory>());
1152 // Add 20 phony interceptors
1153 for (auto i = 0; i < 20; i++) {
1154 creators.push_back(std::make_unique<PhonyInterceptorFactory>());
1155 }
1156 auto channel = experimental::CreateCustomChannelWithInterceptors(
1157 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1158 MakeBidiStreamingCall(channel);
1159 LoggingInterceptor::VerifyBidiStreamingCall();
1160 // Make sure all 20 phony interceptors were run
1161 EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
1162 }
1163
1164 class ClientGlobalInterceptorEnd2endTest : public ::testing::Test {
1165 protected:
ClientGlobalInterceptorEnd2endTest()1166 ClientGlobalInterceptorEnd2endTest() {
1167 int port = grpc_pick_unused_port_or_die();
1168
1169 ServerBuilder builder;
1170 server_address_ = "localhost:" + std::to_string(port);
1171 builder.AddListeningPort(server_address_, InsecureServerCredentials());
1172 builder.RegisterService(&service_);
1173 server_ = builder.BuildAndStart();
1174 }
1175
~ClientGlobalInterceptorEnd2endTest()1176 ~ClientGlobalInterceptorEnd2endTest() override { server_->Shutdown(); }
1177
1178 std::string server_address_;
1179 TestServiceImpl service_;
1180 std::unique_ptr<Server> server_;
1181 };
1182
TEST_F(ClientGlobalInterceptorEnd2endTest,PhonyGlobalInterceptor)1183 TEST_F(ClientGlobalInterceptorEnd2endTest, PhonyGlobalInterceptor) {
1184 // We should ideally be registering a global interceptor only once per
1185 // process, but for the purposes of testing, it should be fine to modify the
1186 // registered global interceptor when there are no ongoing gRPC operations
1187 PhonyInterceptorFactory global_factory;
1188 experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
1189 ChannelArguments args;
1190 PhonyInterceptor::Reset();
1191 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1192 creators;
1193 // Add 20 phony interceptors
1194 creators.reserve(20);
1195 for (auto i = 0; i < 20; i++) {
1196 creators.push_back(std::make_unique<PhonyInterceptorFactory>());
1197 }
1198 auto channel = experimental::CreateCustomChannelWithInterceptors(
1199 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1200 MakeCall(channel);
1201 // Make sure all 20 phony interceptors were run with the global interceptor
1202 EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 21);
1203 experimental::TestOnlyResetGlobalClientInterceptorFactory();
1204 }
1205
TEST_F(ClientGlobalInterceptorEnd2endTest,LoggingGlobalInterceptor)1206 TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) {
1207 // We should ideally be registering a global interceptor only once per
1208 // process, but for the purposes of testing, it should be fine to modify the
1209 // registered global interceptor when there are no ongoing gRPC operations
1210 LoggingInterceptorFactory global_factory;
1211 experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
1212 ChannelArguments args;
1213 PhonyInterceptor::Reset();
1214 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1215 creators;
1216 // Add 20 phony interceptors
1217 creators.reserve(20);
1218 for (auto i = 0; i < 20; i++) {
1219 creators.push_back(std::make_unique<PhonyInterceptorFactory>());
1220 }
1221 auto channel = experimental::CreateCustomChannelWithInterceptors(
1222 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1223 MakeCall(channel);
1224 LoggingInterceptor::VerifyUnaryCall();
1225 // Make sure all 20 phony interceptors were run
1226 EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
1227 experimental::TestOnlyResetGlobalClientInterceptorFactory();
1228 }
1229
TEST_F(ClientGlobalInterceptorEnd2endTest,HijackingGlobalInterceptor)1230 TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) {
1231 // We should ideally be registering a global interceptor only once per
1232 // process, but for the purposes of testing, it should be fine to modify the
1233 // registered global interceptor when there are no ongoing gRPC operations
1234 HijackingInterceptorFactory global_factory;
1235 experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
1236 ChannelArguments args;
1237 PhonyInterceptor::Reset();
1238 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1239 creators;
1240 // Add 20 phony interceptors
1241 creators.reserve(20);
1242 for (auto i = 0; i < 20; i++) {
1243 creators.push_back(std::make_unique<PhonyInterceptorFactory>());
1244 }
1245 auto channel = experimental::CreateCustomChannelWithInterceptors(
1246 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1247 MakeCall(channel);
1248 // Make sure all 20 phony interceptors were run
1249 EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
1250 experimental::TestOnlyResetGlobalClientInterceptorFactory();
1251 }
1252
1253 } // namespace
1254 } // namespace testing
1255 } // namespace grpc
1256
main(int argc,char ** argv)1257 int main(int argc, char** argv) {
1258 grpc::testing::TestEnvironment env(&argc, argv);
1259 ::testing::InitGoogleTest(&argc, argv);
1260 int ret = RUN_ALL_TESTS();
1261 // Make sure that gRPC shuts down cleanly
1262 CHECK(grpc_wait_until_shutdown(10));
1263 return ret;
1264 }
1265