• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpcpp/channel.h>
20 #include <grpcpp/client_context.h>
21 #include <grpcpp/create_channel.h>
22 #include <grpcpp/create_channel_posix.h>
23 #include <grpcpp/generic/generic_stub.h>
24 #include <grpcpp/impl/proto_utils.h>
25 #include <grpcpp/server.h>
26 #include <grpcpp/server_builder.h>
27 #include <grpcpp/server_context.h>
28 #include <grpcpp/server_posix.h>
29 #include <grpcpp/support/client_interceptor.h>
30 #include <gtest/gtest.h>
31 
32 #include <memory>
33 #include <vector>
34 
35 #include "absl/log/check.h"
36 #include "absl/memory/memory.h"
37 #include "src/core/lib/iomgr/port.h"
38 #include "src/proto/grpc/testing/echo.grpc.pb.h"
39 #include "test/core/test_util/port.h"
40 #include "test/core/test_util/test_config.h"
41 #include "test/cpp/end2end/interceptors_util.h"
42 #include "test/cpp/end2end/test_service_impl.h"
43 #include "test/cpp/util/byte_buffer_proto_helper.h"
44 #include "test/cpp/util/string_ref_helper.h"
45 
46 #ifdef GRPC_POSIX_SOCKET
47 #include <fcntl.h>
48 
49 #include "src/core/lib/iomgr/socket_utils_posix.h"
50 #endif  // GRPC_POSIX_SOCKET
51 
52 namespace grpc {
53 namespace testing {
54 namespace {
55 
// Kind of RPC a parameterized scenario issues: the synchronous API or the
// async completion-queue API, crossed with the four RPC shapes.
enum class RPCType : int {
  // Synchronous API.
  kSyncUnary = 0,
  kSyncClientStreaming = 1,
  kSyncServerStreaming = 2,
  kSyncBidiStreaming = 3,
  // Async completion-queue API.
  kAsyncCQUnary = 4,
  kAsyncCQClientStreaming = 5,
  kAsyncCQServerStreaming = 6,
  kAsyncCQBidiStreaming = 7,
};
66 
// Transport the test client uses to reach the in-process test server.
enum class ChannelType : int {
  kHttpChannel = 0,  // Connects to a "localhost:<port>" listening port.
  kFdChannel = 1,    // Uses one end of a socketpair shared with the server.
};
71 
72 // Hijacks Echo RPC and fills in the expected values
73 class HijackingInterceptor : public experimental::Interceptor {
74  public:
HijackingInterceptor(experimental::ClientRpcInfo * info)75   explicit HijackingInterceptor(experimental::ClientRpcInfo* info) {
76     info_ = info;
77     // Make sure it is the right method
78     EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
79     EXPECT_EQ(info->suffix_for_stats(), nullptr);
80     EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY);
81   }
82 
Intercept(experimental::InterceptorBatchMethods * methods)83   void Intercept(experimental::InterceptorBatchMethods* methods) override {
84     bool hijack = false;
85     if (methods->QueryInterceptionHookPoint(
86             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
87       auto* map = methods->GetSendInitialMetadata();
88       // Check that we can see the test metadata
89       ASSERT_EQ(map->size(), 1);
90       auto iterator = map->begin();
91       EXPECT_EQ("testkey", iterator->first);
92       EXPECT_EQ("testvalue", iterator->second);
93       hijack = true;
94     }
95     if (methods->QueryInterceptionHookPoint(
96             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
97       EchoRequest req;
98       auto* buffer = methods->GetSerializedSendMessage();
99       auto copied_buffer = *buffer;
100       EXPECT_TRUE(
101           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
102               .ok());
103       EXPECT_EQ(req.message(), "Hello");
104     }
105     if (methods->QueryInterceptionHookPoint(
106             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
107       // Got nothing to do here for now
108     }
109     if (methods->QueryInterceptionHookPoint(
110             experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
111       auto* map = methods->GetRecvInitialMetadata();
112       // Got nothing better to do here for now
113       EXPECT_EQ(map->size(), 0);
114     }
115     if (methods->QueryInterceptionHookPoint(
116             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
117       EchoResponse* resp =
118           static_cast<EchoResponse*>(methods->GetRecvMessage());
119       // Check that we got the hijacked message, and re-insert the expected
120       // message
121       EXPECT_EQ(resp->message(), "Hello1");
122       resp->set_message("Hello");
123     }
124     if (methods->QueryInterceptionHookPoint(
125             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
126       auto* map = methods->GetRecvTrailingMetadata();
127       bool found = false;
128       // Check that we received the metadata as an echo
129       for (const auto& pair : *map) {
130         found = pair.first.starts_with("testkey") &&
131                 pair.second.starts_with("testvalue");
132         if (found) break;
133       }
134       EXPECT_EQ(found, true);
135       auto* status = methods->GetRecvStatus();
136       EXPECT_EQ(status->ok(), true);
137     }
138     if (methods->QueryInterceptionHookPoint(
139             experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
140       auto* map = methods->GetRecvInitialMetadata();
141       // Got nothing better to do here at the moment
142       EXPECT_EQ(map->size(), 0);
143     }
144     if (methods->QueryInterceptionHookPoint(
145             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
146       // Insert a different message than expected
147       EchoResponse* resp =
148           static_cast<EchoResponse*>(methods->GetRecvMessage());
149       resp->set_message("Hello1");
150     }
151     if (methods->QueryInterceptionHookPoint(
152             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
153       auto* map = methods->GetRecvTrailingMetadata();
154       // insert the metadata that we want
155       EXPECT_EQ(map->size(), 0);
156       map->insert(std::make_pair("testkey", "testvalue"));
157       auto* status = methods->GetRecvStatus();
158       *status = Status(StatusCode::OK, "");
159     }
160     if (hijack) {
161       methods->Hijack();
162     } else {
163       methods->Proceed();
164     }
165   }
166 
167  private:
168   experimental::ClientRpcInfo* info_;
169 };
170 
171 class HijackingInterceptorFactory
172     : public experimental::ClientInterceptorFactoryInterface {
173  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)174   experimental::Interceptor* CreateClientInterceptor(
175       experimental::ClientRpcInfo* info) override {
176     return new HijackingInterceptor(info);
177   }
178 };
179 
180 class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
181  public:
HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo * info)182   explicit HijackingInterceptorMakesAnotherCall(
183       experimental::ClientRpcInfo* info) {
184     info_ = info;
185     // Make sure it is the right method
186     EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
187     EXPECT_EQ(strcmp("TestSuffixForStats", info->suffix_for_stats()), 0);
188   }
189 
Intercept(experimental::InterceptorBatchMethods * methods)190   void Intercept(experimental::InterceptorBatchMethods* methods) override {
191     if (methods->QueryInterceptionHookPoint(
192             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
193       auto* map = methods->GetSendInitialMetadata();
194       // Check that we can see the test metadata
195       ASSERT_EQ(map->size(), 1);
196       auto iterator = map->begin();
197       EXPECT_EQ("testkey", iterator->first);
198       EXPECT_EQ("testvalue", iterator->second);
199       // Make a copy of the map
200       metadata_map_ = *map;
201     }
202     if (methods->QueryInterceptionHookPoint(
203             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
204       EchoRequest req;
205       auto* buffer = methods->GetSerializedSendMessage();
206       auto copied_buffer = *buffer;
207       EXPECT_TRUE(
208           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
209               .ok());
210       EXPECT_EQ(req.message(), "Hello");
211       req_ = req;
212       stub_ = grpc::testing::EchoTestService::NewStub(
213           methods->GetInterceptedChannel());
214       ctx_.AddMetadata(metadata_map_.begin()->first,
215                        metadata_map_.begin()->second);
216       stub_->async()->Echo(&ctx_, &req_, &resp_, [this, methods](Status s) {
217         EXPECT_EQ(s.ok(), true);
218         EXPECT_EQ(resp_.message(), "Hello");
219         methods->Hijack();
220       });
221       // This is a Unary RPC and we have got nothing interesting to do in the
222       // PRE_SEND_CLOSE interception hook point for this interceptor, so let's
223       // return here. (We do not want to call methods->Proceed(). When the new
224       // RPC returns, we will call methods->Hijack() instead.)
225       return;
226     }
227     if (methods->QueryInterceptionHookPoint(
228             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
229       // Got nothing to do here for now
230     }
231     if (methods->QueryInterceptionHookPoint(
232             experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
233       auto* map = methods->GetRecvInitialMetadata();
234       // Got nothing better to do here for now
235       EXPECT_EQ(map->size(), 0);
236     }
237     if (methods->QueryInterceptionHookPoint(
238             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
239       EchoResponse* resp =
240           static_cast<EchoResponse*>(methods->GetRecvMessage());
241       // Check that we got the hijacked message, and re-insert the expected
242       // message
243       EXPECT_EQ(resp->message(), "Hello");
244     }
245     if (methods->QueryInterceptionHookPoint(
246             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
247       auto* map = methods->GetRecvTrailingMetadata();
248       bool found = false;
249       // Check that we received the metadata as an echo
250       for (const auto& pair : *map) {
251         found = pair.first.starts_with("testkey") &&
252                 pair.second.starts_with("testvalue");
253         if (found) break;
254       }
255       EXPECT_EQ(found, true);
256       auto* status = methods->GetRecvStatus();
257       EXPECT_EQ(status->ok(), true);
258     }
259     if (methods->QueryInterceptionHookPoint(
260             experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
261       auto* map = methods->GetRecvInitialMetadata();
262       // Got nothing better to do here at the moment
263       EXPECT_EQ(map->size(), 0);
264     }
265     if (methods->QueryInterceptionHookPoint(
266             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
267       // Insert a different message than expected
268       EchoResponse* resp =
269           static_cast<EchoResponse*>(methods->GetRecvMessage());
270       resp->set_message(resp_.message());
271     }
272     if (methods->QueryInterceptionHookPoint(
273             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
274       auto* map = methods->GetRecvTrailingMetadata();
275       // insert the metadata that we want
276       EXPECT_EQ(map->size(), 0);
277       map->insert(std::make_pair("testkey", "testvalue"));
278       auto* status = methods->GetRecvStatus();
279       *status = Status(StatusCode::OK, "");
280     }
281 
282     methods->Proceed();
283   }
284 
285  private:
286   experimental::ClientRpcInfo* info_;
287   std::multimap<std::string, std::string> metadata_map_;
288   ClientContext ctx_;
289   EchoRequest req_;
290   EchoResponse resp_;
291   std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
292 };
293 
294 class HijackingInterceptorMakesAnotherCallFactory
295     : public experimental::ClientInterceptorFactoryInterface {
296  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)297   experimental::Interceptor* CreateClientInterceptor(
298       experimental::ClientRpcInfo* info) override {
299     return new HijackingInterceptorMakesAnotherCall(info);
300   }
301 };
302 
303 class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
304  public:
BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)305   explicit BidiStreamingRpcHijackingInterceptor(
306       experimental::ClientRpcInfo* info) {
307     info_ = info;
308     EXPECT_EQ(info->suffix_for_stats(), nullptr);
309   }
310 
Intercept(experimental::InterceptorBatchMethods * methods)311   void Intercept(experimental::InterceptorBatchMethods* methods) override {
312     bool hijack = false;
313     if (methods->QueryInterceptionHookPoint(
314             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
315       CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
316       hijack = true;
317     }
318     if (methods->QueryInterceptionHookPoint(
319             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
320       EchoRequest req;
321       auto* buffer = methods->GetSerializedSendMessage();
322       auto copied_buffer = *buffer;
323       EXPECT_TRUE(
324           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
325               .ok());
326       EXPECT_EQ(req.message().find("Hello"), 0u);
327       msg = req.message();
328     }
329     if (methods->QueryInterceptionHookPoint(
330             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
331       // Got nothing to do here for now
332     }
333     if (methods->QueryInterceptionHookPoint(
334             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
335       CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
336                     "testvalue");
337       auto* status = methods->GetRecvStatus();
338       EXPECT_EQ(status->ok(), true);
339     }
340     if (methods->QueryInterceptionHookPoint(
341             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
342       EchoResponse* resp =
343           static_cast<EchoResponse*>(methods->GetRecvMessage());
344       resp->set_message(msg);
345     }
346     if (methods->QueryInterceptionHookPoint(
347             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
348       EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
349                     ->message()
350                     .find("Hello"),
351                 0u);
352     }
353     if (methods->QueryInterceptionHookPoint(
354             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
355       auto* map = methods->GetRecvTrailingMetadata();
356       // insert the metadata that we want
357       EXPECT_EQ(map->size(), 0);
358       map->insert(std::make_pair("testkey", "testvalue"));
359       auto* status = methods->GetRecvStatus();
360       *status = Status(StatusCode::OK, "");
361     }
362     if (hijack) {
363       methods->Hijack();
364     } else {
365       methods->Proceed();
366     }
367   }
368 
369  private:
370   experimental::ClientRpcInfo* info_;
371   std::string msg;
372 };
373 
374 class ClientStreamingRpcHijackingInterceptor
375     : public experimental::Interceptor {
376  public:
ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)377   explicit ClientStreamingRpcHijackingInterceptor(
378       experimental::ClientRpcInfo* info) {
379     info_ = info;
380     EXPECT_EQ(
381         strcmp("/grpc.testing.EchoTestService/RequestStream", info->method()),
382         0);
383     EXPECT_EQ(strcmp("TestSuffixForStats", info->suffix_for_stats()), 0);
384   }
Intercept(experimental::InterceptorBatchMethods * methods)385   void Intercept(experimental::InterceptorBatchMethods* methods) override {
386     bool hijack = false;
387     if (methods->QueryInterceptionHookPoint(
388             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
389       hijack = true;
390     }
391     if (methods->QueryInterceptionHookPoint(
392             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
393       if (++count_ > 10) {
394         methods->FailHijackedSendMessage();
395       }
396     }
397     if (methods->QueryInterceptionHookPoint(
398             experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
399       EXPECT_FALSE(got_failed_send_);
400       got_failed_send_ = !methods->GetSendMessageStatus();
401     }
402     if (methods->QueryInterceptionHookPoint(
403             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
404       auto* status = methods->GetRecvStatus();
405       *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
406     }
407     if (hijack) {
408       methods->Hijack();
409     } else {
410       methods->Proceed();
411     }
412   }
413 
GotFailedSend()414   static bool GotFailedSend() { return got_failed_send_; }
415 
416  private:
417   experimental::ClientRpcInfo* info_;
418   int count_ = 0;
419   static bool got_failed_send_;
420 };
421 
422 bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
423 
424 class ClientStreamingRpcHijackingInterceptorFactory
425     : public experimental::ClientInterceptorFactoryInterface {
426  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)427   experimental::Interceptor* CreateClientInterceptor(
428       experimental::ClientRpcInfo* info) override {
429     return new ClientStreamingRpcHijackingInterceptor(info);
430   }
431 };
432 
433 class ServerStreamingRpcHijackingInterceptor
434     : public experimental::Interceptor {
435  public:
ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)436   explicit ServerStreamingRpcHijackingInterceptor(
437       experimental::ClientRpcInfo* info) {
438     info_ = info;
439     got_failed_message_ = false;
440     EXPECT_EQ(info->suffix_for_stats(), nullptr);
441   }
442 
Intercept(experimental::InterceptorBatchMethods * methods)443   void Intercept(experimental::InterceptorBatchMethods* methods) override {
444     bool hijack = false;
445     if (methods->QueryInterceptionHookPoint(
446             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
447       auto* map = methods->GetSendInitialMetadata();
448       // Check that we can see the test metadata
449       ASSERT_EQ(map->size(), 1);
450       auto iterator = map->begin();
451       EXPECT_EQ("testkey", iterator->first);
452       EXPECT_EQ("testvalue", iterator->second);
453       hijack = true;
454     }
455     if (methods->QueryInterceptionHookPoint(
456             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
457       EchoRequest req;
458       auto* buffer = methods->GetSerializedSendMessage();
459       auto copied_buffer = *buffer;
460       EXPECT_TRUE(
461           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
462               .ok());
463       EXPECT_EQ(req.message(), "Hello");
464     }
465     if (methods->QueryInterceptionHookPoint(
466             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
467       // Got nothing to do here for now
468     }
469     if (methods->QueryInterceptionHookPoint(
470             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
471       auto* map = methods->GetRecvTrailingMetadata();
472       bool found = false;
473       // Check that we received the metadata as an echo
474       for (const auto& pair : *map) {
475         found = pair.first.starts_with("testkey") &&
476                 pair.second.starts_with("testvalue");
477         if (found) break;
478       }
479       EXPECT_EQ(found, true);
480       auto* status = methods->GetRecvStatus();
481       EXPECT_EQ(status->ok(), true);
482     }
483     if (methods->QueryInterceptionHookPoint(
484             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
485       if (++count_ > 10) {
486         methods->FailHijackedRecvMessage();
487       }
488       EchoResponse* resp =
489           static_cast<EchoResponse*>(methods->GetRecvMessage());
490       resp->set_message("Hello");
491     }
492     if (methods->QueryInterceptionHookPoint(
493             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
494       // Only the last message will be a failure
495       EXPECT_FALSE(got_failed_message_);
496       got_failed_message_ = methods->GetRecvMessage() == nullptr;
497     }
498     if (methods->QueryInterceptionHookPoint(
499             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
500       auto* map = methods->GetRecvTrailingMetadata();
501       // insert the metadata that we want
502       EXPECT_EQ(map->size(), 0);
503       map->insert(std::make_pair("testkey", "testvalue"));
504       auto* status = methods->GetRecvStatus();
505       *status = Status(StatusCode::OK, "");
506     }
507     if (hijack) {
508       methods->Hijack();
509     } else {
510       methods->Proceed();
511     }
512   }
513 
GotFailedMessage()514   static bool GotFailedMessage() { return got_failed_message_; }
515 
516  private:
517   experimental::ClientRpcInfo* info_;
518   static bool got_failed_message_;
519   int count_ = 0;
520 };
521 
522 bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
523 
524 class ServerStreamingRpcHijackingInterceptorFactory
525     : public experimental::ClientInterceptorFactoryInterface {
526  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)527   experimental::Interceptor* CreateClientInterceptor(
528       experimental::ClientRpcInfo* info) override {
529     return new ServerStreamingRpcHijackingInterceptor(info);
530   }
531 };
532 
533 class BidiStreamingRpcHijackingInterceptorFactory
534     : public experimental::ClientInterceptorFactoryInterface {
535  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)536   experimental::Interceptor* CreateClientInterceptor(
537       experimental::ClientRpcInfo* info) override {
538     return new BidiStreamingRpcHijackingInterceptor(info);
539   }
540 };
541 
542 // The logging interceptor is for testing purposes only. It is used to verify
543 // that all the appropriate hook points are invoked for an RPC. The counts are
544 // reset each time a new object of LoggingInterceptor is created, so only a
545 // single RPC should be made on the channel before calling the Verify methods.
546 class LoggingInterceptor : public experimental::Interceptor {
547  public:
LoggingInterceptor(experimental::ClientRpcInfo *)548   explicit LoggingInterceptor(experimental::ClientRpcInfo* /*info*/) {
549     pre_send_initial_metadata_ = false;
550     pre_send_message_count_ = 0;
551     pre_send_close_ = false;
552     post_recv_initial_metadata_ = false;
553     post_recv_message_count_ = 0;
554     post_recv_status_ = false;
555   }
556 
Intercept(experimental::InterceptorBatchMethods * methods)557   void Intercept(experimental::InterceptorBatchMethods* methods) override {
558     if (methods->QueryInterceptionHookPoint(
559             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
560       auto* map = methods->GetSendInitialMetadata();
561       // Check that we can see the test metadata
562       ASSERT_EQ(map->size(), 1);
563       auto iterator = map->begin();
564       EXPECT_EQ("testkey", iterator->first);
565       EXPECT_EQ("testvalue", iterator->second);
566       ASSERT_FALSE(pre_send_initial_metadata_);
567       pre_send_initial_metadata_ = true;
568     }
569     if (methods->QueryInterceptionHookPoint(
570             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
571       EchoRequest req;
572       auto* send_msg = methods->GetSendMessage();
573       if (send_msg == nullptr) {
574         // We did not get the non-serialized form of the message. Get the
575         // serialized form.
576         auto* buffer = methods->GetSerializedSendMessage();
577         auto copied_buffer = *buffer;
578         EchoRequest req;
579         EXPECT_TRUE(
580             SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
581                 .ok());
582         EXPECT_EQ(req.message(), "Hello");
583       } else {
584         EXPECT_EQ(
585             static_cast<const EchoRequest*>(send_msg)->message().find("Hello"),
586             0u);
587       }
588       auto* buffer = methods->GetSerializedSendMessage();
589       auto copied_buffer = *buffer;
590       EXPECT_TRUE(
591           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
592               .ok());
593       EXPECT_TRUE(req.message().find("Hello") == 0u);
594       pre_send_message_count_++;
595     }
596     if (methods->QueryInterceptionHookPoint(
597             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
598       // Got nothing to do here for now
599       pre_send_close_ = true;
600     }
601     if (methods->QueryInterceptionHookPoint(
602             experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
603       auto* map = methods->GetRecvInitialMetadata();
604       // Got nothing better to do here for now
605       EXPECT_EQ(map->size(), 0);
606       post_recv_initial_metadata_ = true;
607     }
608     if (methods->QueryInterceptionHookPoint(
609             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
610       EchoResponse* resp =
611           static_cast<EchoResponse*>(methods->GetRecvMessage());
612       if (resp != nullptr) {
613         EXPECT_TRUE(resp->message().find("Hello") == 0u);
614         post_recv_message_count_++;
615       }
616     }
617     if (methods->QueryInterceptionHookPoint(
618             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
619       auto* map = methods->GetRecvTrailingMetadata();
620       bool found = false;
621       // Check that we received the metadata as an echo
622       for (const auto& pair : *map) {
623         found = pair.first.starts_with("testkey") &&
624                 pair.second.starts_with("testvalue");
625         if (found) break;
626       }
627       EXPECT_EQ(found, true);
628       auto* status = methods->GetRecvStatus();
629       EXPECT_EQ(status->ok(), true);
630       post_recv_status_ = true;
631     }
632     methods->Proceed();
633   }
634 
VerifyCall(RPCType type)635   static void VerifyCall(RPCType type) {
636     switch (type) {
637       case RPCType::kSyncUnary:
638       case RPCType::kAsyncCQUnary:
639         VerifyUnaryCall();
640         break;
641       case RPCType::kSyncClientStreaming:
642       case RPCType::kAsyncCQClientStreaming:
643         VerifyClientStreamingCall();
644         break;
645       case RPCType::kSyncServerStreaming:
646       case RPCType::kAsyncCQServerStreaming:
647         VerifyServerStreamingCall();
648         break;
649       case RPCType::kSyncBidiStreaming:
650       case RPCType::kAsyncCQBidiStreaming:
651         VerifyBidiStreamingCall();
652         break;
653     }
654   }
655 
VerifyCallCommon()656   static void VerifyCallCommon() {
657     EXPECT_TRUE(pre_send_initial_metadata_);
658     EXPECT_TRUE(pre_send_close_);
659     EXPECT_TRUE(post_recv_initial_metadata_);
660     EXPECT_TRUE(post_recv_status_);
661   }
662 
VerifyUnaryCall()663   static void VerifyUnaryCall() {
664     VerifyCallCommon();
665     EXPECT_EQ(pre_send_message_count_, 1);
666     EXPECT_EQ(post_recv_message_count_, 1);
667   }
668 
VerifyClientStreamingCall()669   static void VerifyClientStreamingCall() {
670     VerifyCallCommon();
671     EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
672     EXPECT_EQ(post_recv_message_count_, 1);
673   }
674 
VerifyServerStreamingCall()675   static void VerifyServerStreamingCall() {
676     VerifyCallCommon();
677     EXPECT_EQ(pre_send_message_count_, 1);
678     EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
679   }
680 
VerifyBidiStreamingCall()681   static void VerifyBidiStreamingCall() {
682     VerifyCallCommon();
683     EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
684     EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
685   }
686 
687  private:
688   static bool pre_send_initial_metadata_;
689   static int pre_send_message_count_;
690   static bool pre_send_close_;
691   static bool post_recv_initial_metadata_;
692   static int post_recv_message_count_;
693   static bool post_recv_status_;
694 };
695 
696 bool LoggingInterceptor::pre_send_initial_metadata_;
697 int LoggingInterceptor::pre_send_message_count_;
698 bool LoggingInterceptor::pre_send_close_;
699 bool LoggingInterceptor::post_recv_initial_metadata_;
700 int LoggingInterceptor::post_recv_message_count_;
701 bool LoggingInterceptor::post_recv_status_;
702 
703 class LoggingInterceptorFactory
704     : public experimental::ClientInterceptorFactoryInterface {
705  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)706   experimental::Interceptor* CreateClientInterceptor(
707       experimental::ClientRpcInfo* info) override {
708     return new LoggingInterceptor(info);
709   }
710 };
711 
712 class TestScenario {
713  public:
TestScenario(const ChannelType & channel_type,const RPCType & rpc_type)714   explicit TestScenario(const ChannelType& channel_type,
715                         const RPCType& rpc_type)
716       : channel_type_(channel_type), rpc_type_(rpc_type) {}
717 
channel_type() const718   ChannelType channel_type() const { return channel_type_; }
719 
rpc_type() const720   RPCType rpc_type() const { return rpc_type_; }
721 
722  private:
723   const ChannelType channel_type_;
724   const RPCType rpc_type_;
725 };
726 
CreateTestScenarios()727 std::vector<TestScenario> CreateTestScenarios() {
728   std::vector<TestScenario> scenarios;
729   std::vector<RPCType> rpc_types;
730   rpc_types.emplace_back(RPCType::kSyncUnary);
731   rpc_types.emplace_back(RPCType::kSyncClientStreaming);
732   rpc_types.emplace_back(RPCType::kSyncServerStreaming);
733   rpc_types.emplace_back(RPCType::kSyncBidiStreaming);
734   rpc_types.emplace_back(RPCType::kAsyncCQUnary);
735   rpc_types.emplace_back(RPCType::kAsyncCQServerStreaming);
736   for (const auto& rpc_type : rpc_types) {
737     scenarios.emplace_back(ChannelType::kHttpChannel, rpc_type);
738 // TODO(yashykt): Maybe add support for non-posix sockets too
739 #ifdef GRPC_POSIX_SOCKET
740     scenarios.emplace_back(ChannelType::kFdChannel, rpc_type);
741 #endif  // GRPC_POSIX_SOCKET
742   }
743   return scenarios;
744 }
745 
746 class ParameterizedClientInterceptorsEnd2endTest
747     : public ::testing::TestWithParam<TestScenario> {
748  protected:
ParameterizedClientInterceptorsEnd2endTest()749   ParameterizedClientInterceptorsEnd2endTest() {
750     ServerBuilder builder;
751     builder.RegisterService(&service_);
752     if (GetParam().channel_type() == ChannelType::kHttpChannel) {
753       int port = grpc_pick_unused_port_or_die();
754       server_address_ = "localhost:" + std::to_string(port);
755       builder.AddListeningPort(server_address_, InsecureServerCredentials());
756       server_ = builder.BuildAndStart();
757     }
758 #ifdef GRPC_POSIX_SOCKET
759     else if (GetParam().channel_type() == ChannelType::kFdChannel) {
760       int flags;
761       CHECK_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, sv_), 0);
762       flags = fcntl(sv_[0], F_GETFL, 0);
763       CHECK_EQ(fcntl(sv_[0], F_SETFL, flags | O_NONBLOCK), 0);
764       flags = fcntl(sv_[1], F_GETFL, 0);
765       CHECK_EQ(fcntl(sv_[1], F_SETFL, flags | O_NONBLOCK), 0);
766       CHECK(grpc_set_socket_no_sigpipe_if_possible(sv_[0]) == absl::OkStatus());
767       CHECK(grpc_set_socket_no_sigpipe_if_possible(sv_[1]) == absl::OkStatus());
768       server_ = builder.BuildAndStart();
769       AddInsecureChannelFromFd(server_.get(), sv_[1]);
770     }
771 #endif  // GRPC_POSIX_SOCKET
772   }
773 
~ParameterizedClientInterceptorsEnd2endTest()774   ~ParameterizedClientInterceptorsEnd2endTest() override {
775     server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0));
776   }
777 
CreateClientChannel(std::vector<std::unique_ptr<grpc::experimental::ClientInterceptorFactoryInterface>> creators)778   std::shared_ptr<grpc::Channel> CreateClientChannel(
779       std::vector<std::unique_ptr<
780           grpc::experimental::ClientInterceptorFactoryInterface>>
781           creators) {
782     if (GetParam().channel_type() == ChannelType::kHttpChannel) {
783       return experimental::CreateCustomChannelWithInterceptors(
784           server_address_, InsecureChannelCredentials(), ChannelArguments(),
785           std::move(creators));
786     }
787 #ifdef GRPC_POSIX_SOCKET
788     else if (GetParam().channel_type() == ChannelType::kFdChannel) {
789       return experimental::CreateCustomInsecureChannelWithInterceptorsFromFd(
790           "", sv_[0], ChannelArguments(), std::move(creators));
791     }
792 #endif  // GRPC_POSIX_SOCKET
793     return nullptr;
794   }
795 
SendRPC(const std::shared_ptr<Channel> & channel)796   void SendRPC(const std::shared_ptr<Channel>& channel) {
797     switch (GetParam().rpc_type()) {
798       case RPCType::kSyncUnary:
799         MakeCall(channel);
800         break;
801       case RPCType::kSyncClientStreaming:
802         MakeClientStreamingCall(channel);
803         break;
804       case RPCType::kSyncServerStreaming:
805         MakeServerStreamingCall(channel);
806         break;
807       case RPCType::kSyncBidiStreaming:
808         MakeBidiStreamingCall(channel);
809         break;
810       case RPCType::kAsyncCQUnary:
811         MakeAsyncCQCall(channel);
812         break;
813       case RPCType::kAsyncCQClientStreaming:
814         // TODO(yashykt) : Fill this out
815         break;
816       case RPCType::kAsyncCQServerStreaming:
817         MakeAsyncCQServerStreamingCall(channel);
818         break;
819       case RPCType::kAsyncCQBidiStreaming:
820         // TODO(yashykt) : Fill this out
821         break;
822     }
823   }
824 
825   std::string server_address_;
826   int sv_[2];
827   EchoTestServiceStreamingImpl service_;
828   std::unique_ptr<Server> server_;
829 };
830 
// For every scenario, a logging interceptor followed by phony interceptors
// must observe the RPC, and every phony interceptor must run exactly once.
TEST_P(ParameterizedClientInterceptorsEnd2endTest,
       ClientInterceptorLoggingTest) {
  constexpr int kNumPhonyInterceptors = 20;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.reserve(kNumPhonyInterceptors + 1);
  creators.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int i = 0; i < kNumPhonyInterceptors; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  // The fixture supplies its own ChannelArguments, so none are created here.
  auto channel = CreateClientChannel(std::move(creators));
  SendRPC(channel);
  LoggingInterceptor::VerifyCall(GetParam().rpc_type());
  // Make sure all phony interceptors were run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyInterceptors);
}
848 
// Instantiates ParameterizedClientInterceptorsEnd2endTest once per scenario
// returned by CreateTestScenarios().
INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end,
                         ParameterizedClientInterceptorsEnd2endTest,
                         ::testing::ValuesIn(CreateTestScenarios()));
852 
853 class ClientInterceptorsEnd2endTest
854     : public ::testing::TestWithParam<TestScenario> {
855  protected:
ClientInterceptorsEnd2endTest()856   ClientInterceptorsEnd2endTest() {
857     int port = grpc_pick_unused_port_or_die();
858 
859     ServerBuilder builder;
860     server_address_ = "localhost:" + std::to_string(port);
861     builder.AddListeningPort(server_address_, InsecureServerCredentials());
862     builder.RegisterService(&service_);
863     server_ = builder.BuildAndStart();
864   }
865 
~ClientInterceptorsEnd2endTest()866   ~ClientInterceptorsEnd2endTest() override { server_->Shutdown(); }
867 
868   std::string server_address_;
869   TestServiceImpl service_;
870   std::unique_ptr<Server> server_;
871 };
872 
// Null channel credentials are passed on purpose: as the test name says, this
// exercises a hijacking interceptor installed on a lame channel.
TEST_F(ClientInterceptorsEnd2endTest,
       LameChannelClientInterceptorHijackingTest) {
  ChannelArguments channel_args;
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      interceptor_factories;
  interceptor_factories.push_back(
      std::make_unique<HijackingInterceptorFactory>());
  auto lame_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, /*creds=*/nullptr, channel_args,
      std::move(interceptor_factories));
  MakeCall(lame_channel);
}
883 
// Places phony interceptors on both sides of a hijacking interceptor. Only the
// phony interceptors placed before the hijacker are expected to run.
TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) {
  constexpr int kNumPhonyPerSide = 20;
  ChannelArguments args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  // Room for the phony interceptors on both sides plus the hijacker.
  creators.reserve(2 * kNumPhonyPerSide + 1);
  // Phony interceptors before the hijacking interceptor.
  for (int i = 0; i < kNumPhonyPerSide; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  creators.push_back(std::make_unique<HijackingInterceptorFactory>());
  // Phony interceptors after the hijacking interceptor.
  for (int i = 0; i < kNumPhonyPerSide; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));
  MakeCall(channel);
  // Make sure only the phony interceptors before the hijacker were run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyPerSide);
}
905 
// A logging interceptor placed ahead of a hijacking interceptor must still
// record a complete unary call.
TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) {
  ChannelArguments channel_args;
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      interceptor_factories;
  interceptor_factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  interceptor_factories.push_back(
      std::make_unique<HijackingInterceptorFactory>());
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(interceptor_factories));
  MakeCall(intercepted_channel);
  LoggingInterceptor::VerifyUnaryCall();
}
917 
// A hijacking interceptor that issues its own RPC on the intercepted
// in-process channel. Every phony interceptor, on either side of the
// hijacker, is expected to run exactly once.
TEST_F(ClientInterceptorsEnd2endTest,
       ClientInterceptorHijackingMakesAnotherCallTest) {
  constexpr int kNumPhonyBefore = 5;
  constexpr int kNumPhonyAfter = 7;
  ChannelArguments args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  // Room for all phony interceptors plus the hijacker.
  creators.reserve(kNumPhonyBefore + 1 + kNumPhonyAfter);
  // Phony interceptors before the hijacking interceptor.
  for (int i = 0; i < kNumPhonyBefore; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  creators.push_back(
      std::make_unique<HijackingInterceptorMakesAnotherCallFactory>());
  // Phony interceptors after the hijacking interceptor.
  for (int i = 0; i < kNumPhonyAfter; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = server_->experimental().InProcessChannelWithInterceptors(
      args, std::move(creators));

  MakeCall(channel, StubOptions("TestSuffixForStats"));
  // Make sure all interceptors were run once, since the hijacking interceptor
  // makes an RPC on the intercepted channel.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(),
            kNumPhonyBefore + kNumPhonyAfter);
}
944 
945 class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test {
946  protected:
ClientInterceptorsCallbackEnd2endTest()947   ClientInterceptorsCallbackEnd2endTest() {
948     int port = grpc_pick_unused_port_or_die();
949 
950     ServerBuilder builder;
951     server_address_ = "localhost:" + std::to_string(port);
952     builder.AddListeningPort(server_address_, InsecureServerCredentials());
953     builder.RegisterService(&service_);
954     server_ = builder.BuildAndStart();
955   }
956 
~ClientInterceptorsCallbackEnd2endTest()957   ~ClientInterceptorsCallbackEnd2endTest() override { server_->Shutdown(); }
958 
959   std::string server_address_;
960   TestServiceImpl service_;
961   std::unique_ptr<Server> server_;
962 };
963 
// Callback-API unary call through a logging interceptor followed by phony
// interceptors on an in-process channel. All phony interceptors must run.
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorLoggingTestWithCallback) {
  constexpr int kPhonyCount = 20;
  ChannelArguments channel_args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto in_process_channel =
      server_->experimental().InProcessChannelWithInterceptors(
          channel_args, std::move(factories));
  MakeCallbackCall(in_process_channel);
  LoggingInterceptor::VerifyUnaryCall();
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount);
}
982 
// Callback-API unary call where the hijacking interceptor is last in the
// chain. The logging interceptor and every phony interceptor ahead of it must
// still run.
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorHijackingTestWithCallback) {
  constexpr int kPhonyCount = 20;
  ChannelArguments channel_args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  factories.push_back(std::make_unique<HijackingInterceptorFactory>());
  auto in_process_channel =
      server_->experimental().InProcessChannelWithInterceptors(
          channel_args, std::move(factories));
  MakeCallbackCall(in_process_channel);
  LoggingInterceptor::VerifyUnaryCall();
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount);
}
1002 
// Interleaves phony interceptor factories with factories that return nullptr.
// The null entries must be tolerated, and every phony interceptor must still
// run.
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorFactoryAllowsNullptrReturn) {
  constexpr int kPairCount = 20;
  ChannelArguments channel_args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < kPairCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
    factories.push_back(std::make_unique<NullInterceptorFactory>());
  }
  auto in_process_channel =
      server_->experimental().InProcessChannelWithInterceptors(
          channel_args, std::move(factories));
  MakeCallbackCall(in_process_channel);
  LoggingInterceptor::VerifyUnaryCall();
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPairCount);
}
1022 
1023 class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test {
1024  protected:
ClientInterceptorsStreamingEnd2endTest()1025   ClientInterceptorsStreamingEnd2endTest() {
1026     int port = grpc_pick_unused_port_or_die();
1027 
1028     ServerBuilder builder;
1029     server_address_ = "localhost:" + std::to_string(port);
1030     builder.AddListeningPort(server_address_, InsecureServerCredentials());
1031     builder.RegisterService(&service_);
1032     server_ = builder.BuildAndStart();
1033   }
1034 
~ClientInterceptorsStreamingEnd2endTest()1035   ~ClientInterceptorsStreamingEnd2endTest() override { server_->Shutdown(); }
1036 
1037   std::string server_address_;
1038   EchoTestServiceStreamingImpl service_;
1039   std::unique_ptr<Server> server_;
1040 };
1041 
// Client-streaming call through a logging interceptor plus phony
// interceptors. The logger must see the call and all phony interceptors must
// run.
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) {
  constexpr int kPhonyCount = 20;
  ChannelArguments channel_args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeClientStreamingCall(intercepted_channel);
  LoggingInterceptor::VerifyClientStreamingCall();
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount);
}
1059 
// Server-streaming call through a logging interceptor plus phony
// interceptors. The logger must see the call and all phony interceptors must
// run.
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
  constexpr int kPhonyCount = 20;
  ChannelArguments channel_args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeServerStreamingCall(intercepted_channel);
  LoggingInterceptor::VerifyServerStreamingCall();
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount);
}
1077 
// A client-streaming RPC whose interceptor accepts a fixed number of writes
// and then rejects the next one. The stream must finish with a non-OK status,
// and the interceptor must report the failed send.
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
  // Writes the interceptor lets through before rejecting the next message.
  constexpr int kNumAcceptedWrites = 10;
  ChannelArguments args;
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.push_back(
      std::make_unique<ClientStreamingRpcHijackingInterceptorFactory>());
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));

  auto stub = grpc::testing::EchoTestService::NewStub(
      channel, StubOptions("TestSuffixForStats"));
  ClientContext ctx;
  EchoRequest req;
  EchoResponse resp;
  req.mutable_param()->set_echo_metadata(true);
  req.set_message("Hello");
  auto writer = stub->RequestStream(&ctx, &resp);
  for (int i = 0; i < kNumAcceptedWrites; ++i) {
    EXPECT_TRUE(writer->Write(req));
  }
  // The interceptor rejects this next message. The rejection is verified
  // through Finish() and GotFailedSend() below, not through Write()'s result.
  writer->Write(req);
  Status s = writer->Finish();
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
}
1106 
// Synchronous server-streaming call through the server-streaming hijacking
// interceptor. The interceptor must report that it saw a failed message.
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
  ChannelArguments channel_args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(
      std::make_unique<ServerStreamingRpcHijackingInterceptorFactory>());
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeServerStreamingCall(intercepted_channel);
  EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
1119 
// Same interceptor setup as ServerStreamingHijackingTest, but the call is
// issued through MakeAsyncCQServerStreamingCall().
TEST_F(ClientInterceptorsStreamingEnd2endTest,
       AsyncCQServerStreamingHijackingTest) {
  ChannelArguments channel_args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(
      std::make_unique<ServerStreamingRpcHijackingInterceptorFactory>());
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeAsyncCQServerStreamingCall(intercepted_channel);
  EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
1133 
// Bidirectional-streaming call through the bidi hijacking interceptor. Any
// checks are performed inside MakeBidiStreamingCall().
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
  ChannelArguments channel_args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(
      std::make_unique<BidiStreamingRpcHijackingInterceptorFactory>());
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeBidiStreamingCall(intercepted_channel);
}
1145 
// Bidirectional-streaming call through a logging interceptor plus phony
// interceptors. The logger must see the call and all phony interceptors must
// run.
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
  constexpr int kPhonyCount = 20;
  ChannelArguments channel_args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeBidiStreamingCall(intercepted_channel);
  LoggingInterceptor::VerifyBidiStreamingCall();
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount);
}
1163 
1164 class ClientGlobalInterceptorEnd2endTest : public ::testing::Test {
1165  protected:
ClientGlobalInterceptorEnd2endTest()1166   ClientGlobalInterceptorEnd2endTest() {
1167     int port = grpc_pick_unused_port_or_die();
1168 
1169     ServerBuilder builder;
1170     server_address_ = "localhost:" + std::to_string(port);
1171     builder.AddListeningPort(server_address_, InsecureServerCredentials());
1172     builder.RegisterService(&service_);
1173     server_ = builder.BuildAndStart();
1174   }
1175 
~ClientGlobalInterceptorEnd2endTest()1176   ~ClientGlobalInterceptorEnd2endTest() override { server_->Shutdown(); }
1177 
1178   std::string server_address_;
1179   TestServiceImpl service_;
1180   std::unique_ptr<Server> server_;
1181 };
1182 
// Registers a phony interceptor factory globally and adds phony per-channel
// interceptors. The global phony interceptor must run in addition to the
// per-channel ones.
TEST_F(ClientGlobalInterceptorEnd2endTest, PhonyGlobalInterceptor) {
  constexpr int kNumPhonyInterceptors = 20;
  // We should ideally be registering a global interceptor only once per
  // process, but for the purposes of testing, it should be fine to modify the
  // registered global interceptor when there are no ongoing gRPC operations
  PhonyInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
  // Clears the process-wide registration on every exit from this scope, so it
  // never outlives the stack-allocated `global_factory`. Declared after
  // `global_factory`, so it is destroyed first.
  struct GlobalFactoryResetter {
    ~GlobalFactoryResetter() {
      experimental::TestOnlyResetGlobalClientInterceptorFactory();
    }
  } reset_global_factory_on_exit;
  ChannelArguments args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.reserve(kNumPhonyInterceptors);
  for (int i = 0; i < kNumPhonyInterceptors; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));
  MakeCall(channel);
  // The per-channel phony interceptors plus the global phony interceptor.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyInterceptors + 1);
}
1205 
// Registers a logging interceptor factory globally alongside per-channel
// phony interceptors. The global logger must observe a full unary call.
TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) {
  constexpr int kNumPhonyInterceptors = 20;
  // We should ideally be registering a global interceptor only once per
  // process, but for the purposes of testing, it should be fine to modify the
  // registered global interceptor when there are no ongoing gRPC operations
  LoggingInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
  // Clears the process-wide registration on every exit from this scope, so it
  // never outlives the stack-allocated `global_factory`. Declared after
  // `global_factory`, so it is destroyed first.
  struct GlobalFactoryResetter {
    ~GlobalFactoryResetter() {
      experimental::TestOnlyResetGlobalClientInterceptorFactory();
    }
  } reset_global_factory_on_exit;
  ChannelArguments args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.reserve(kNumPhonyInterceptors);
  for (int i = 0; i < kNumPhonyInterceptors; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));
  MakeCall(channel);
  LoggingInterceptor::VerifyUnaryCall();
  // Make sure all phony interceptors were run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyInterceptors);
}
1229 
// Registers a hijacking interceptor factory globally alongside per-channel
// phony interceptors. All per-channel phony interceptors are expected to run.
TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) {
  constexpr int kNumPhonyInterceptors = 20;
  // We should ideally be registering a global interceptor only once per
  // process, but for the purposes of testing, it should be fine to modify the
  // registered global interceptor when there are no ongoing gRPC operations
  HijackingInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
  // Clears the process-wide registration on every exit from this scope, so it
  // never outlives the stack-allocated `global_factory`. Declared after
  // `global_factory`, so it is destroyed first.
  struct GlobalFactoryResetter {
    ~GlobalFactoryResetter() {
      experimental::TestOnlyResetGlobalClientInterceptorFactory();
    }
  } reset_global_factory_on_exit;
  ChannelArguments args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.reserve(kNumPhonyInterceptors);
  for (int i = 0; i < kNumPhonyInterceptors; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));
  MakeCall(channel);
  // Make sure all phony interceptors were run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyInterceptors);
}
1252 
1253 }  // namespace
1254 }  // namespace testing
1255 }  // namespace grpc
1256 
main(int argc,char ** argv)1257 int main(int argc, char** argv) {
1258   grpc::testing::TestEnvironment env(&argc, argv);
1259   ::testing::InitGoogleTest(&argc, argv);
1260   int ret = RUN_ALL_TESTS();
1261   // Make sure that gRPC shuts down cleanly
1262   CHECK(grpc_wait_until_shutdown(10));
1263   return ret;
1264 }
1265