• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *
3  * Copyright 2018 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include <memory>
20 #include <vector>
21 
22 #include <grpcpp/channel.h>
23 #include <grpcpp/client_context.h>
24 #include <grpcpp/create_channel.h>
25 #include <grpcpp/generic/generic_stub.h>
26 #include <grpcpp/impl/codegen/proto_utils.h>
27 #include <grpcpp/server.h>
28 #include <grpcpp/server_builder.h>
29 #include <grpcpp/server_context.h>
30 #include <grpcpp/support/client_interceptor.h>
31 
32 #include "src/proto/grpc/testing/echo.grpc.pb.h"
33 #include "test/core/util/port.h"
34 #include "test/core/util/test_config.h"
35 #include "test/cpp/end2end/interceptors_util.h"
36 #include "test/cpp/end2end/test_service_impl.h"
37 #include "test/cpp/util/byte_buffer_proto_helper.h"
38 #include "test/cpp/util/string_ref_helper.h"
39 
40 #include <gtest/gtest.h>
41 
42 namespace grpc {
43 namespace testing {
44 namespace {
45 
// Kinds of RPC the parameterized tests can issue. There are four call shapes
// (unary, client-streaming, server-streaming, bidi-streaming). Each can be
// issued through the synchronous stub or the completion-queue ("CQ") async
// API.
enum class RPCType {
  // Synchronous stub API.
  kSyncUnary,
  kSyncClientStreaming,
  kSyncServerStreaming,
  kSyncBidiStreaming,
  // Completion-queue async API.
  kAsyncCQUnary,
  kAsyncCQClientStreaming,
  kAsyncCQServerStreaming,
  kAsyncCQBidiStreaming,
};
56 
57 /* Hijacks Echo RPC and fills in the expected values */
58 class HijackingInterceptor : public experimental::Interceptor {
59  public:
HijackingInterceptor(experimental::ClientRpcInfo * info)60   HijackingInterceptor(experimental::ClientRpcInfo* info) {
61     info_ = info;
62     // Make sure it is the right method
63     EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
64     EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY);
65   }
66 
Intercept(experimental::InterceptorBatchMethods * methods)67   virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
68     bool hijack = false;
69     if (methods->QueryInterceptionHookPoint(
70             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
71       auto* map = methods->GetSendInitialMetadata();
72       // Check that we can see the test metadata
73       ASSERT_EQ(map->size(), static_cast<unsigned>(1));
74       auto iterator = map->begin();
75       EXPECT_EQ("testkey", iterator->first);
76       EXPECT_EQ("testvalue", iterator->second);
77       hijack = true;
78     }
79     if (methods->QueryInterceptionHookPoint(
80             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
81       EchoRequest req;
82       auto* buffer = methods->GetSerializedSendMessage();
83       auto copied_buffer = *buffer;
84       EXPECT_TRUE(
85           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
86               .ok());
87       EXPECT_EQ(req.message(), "Hello");
88     }
89     if (methods->QueryInterceptionHookPoint(
90             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
91       // Got nothing to do here for now
92     }
93     if (methods->QueryInterceptionHookPoint(
94             experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
95       auto* map = methods->GetRecvInitialMetadata();
96       // Got nothing better to do here for now
97       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
98     }
99     if (methods->QueryInterceptionHookPoint(
100             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
101       EchoResponse* resp =
102           static_cast<EchoResponse*>(methods->GetRecvMessage());
103       // Check that we got the hijacked message, and re-insert the expected
104       // message
105       EXPECT_EQ(resp->message(), "Hello1");
106       resp->set_message("Hello");
107     }
108     if (methods->QueryInterceptionHookPoint(
109             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
110       auto* map = methods->GetRecvTrailingMetadata();
111       bool found = false;
112       // Check that we received the metadata as an echo
113       for (const auto& pair : *map) {
114         found = pair.first.starts_with("testkey") &&
115                 pair.second.starts_with("testvalue");
116         if (found) break;
117       }
118       EXPECT_EQ(found, true);
119       auto* status = methods->GetRecvStatus();
120       EXPECT_EQ(status->ok(), true);
121     }
122     if (methods->QueryInterceptionHookPoint(
123             experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
124       auto* map = methods->GetRecvInitialMetadata();
125       // Got nothing better to do here at the moment
126       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
127     }
128     if (methods->QueryInterceptionHookPoint(
129             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
130       // Insert a different message than expected
131       EchoResponse* resp =
132           static_cast<EchoResponse*>(methods->GetRecvMessage());
133       resp->set_message("Hello1");
134     }
135     if (methods->QueryInterceptionHookPoint(
136             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
137       auto* map = methods->GetRecvTrailingMetadata();
138       // insert the metadata that we want
139       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
140       map->insert(std::make_pair("testkey", "testvalue"));
141       auto* status = methods->GetRecvStatus();
142       *status = Status(StatusCode::OK, "");
143     }
144     if (hijack) {
145       methods->Hijack();
146     } else {
147       methods->Proceed();
148     }
149   }
150 
151  private:
152   experimental::ClientRpcInfo* info_;
153 };
154 
155 class HijackingInterceptorFactory
156     : public experimental::ClientInterceptorFactoryInterface {
157  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)158   virtual experimental::Interceptor* CreateClientInterceptor(
159       experimental::ClientRpcInfo* info) override {
160     return new HijackingInterceptor(info);
161   }
162 };
163 
164 class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
165  public:
HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo * info)166   HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) {
167     info_ = info;
168     // Make sure it is the right method
169     EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
170   }
171 
Intercept(experimental::InterceptorBatchMethods * methods)172   virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
173     if (methods->QueryInterceptionHookPoint(
174             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
175       auto* map = methods->GetSendInitialMetadata();
176       // Check that we can see the test metadata
177       ASSERT_EQ(map->size(), static_cast<unsigned>(1));
178       auto iterator = map->begin();
179       EXPECT_EQ("testkey", iterator->first);
180       EXPECT_EQ("testvalue", iterator->second);
181       // Make a copy of the map
182       metadata_map_ = *map;
183     }
184     if (methods->QueryInterceptionHookPoint(
185             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
186       EchoRequest req;
187       auto* buffer = methods->GetSerializedSendMessage();
188       auto copied_buffer = *buffer;
189       EXPECT_TRUE(
190           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
191               .ok());
192       EXPECT_EQ(req.message(), "Hello");
193       req_ = req;
194       stub_ = grpc::testing::EchoTestService::NewStub(
195           methods->GetInterceptedChannel());
196       ctx_.AddMetadata(metadata_map_.begin()->first,
197                        metadata_map_.begin()->second);
198       stub_->experimental_async()->Echo(&ctx_, &req_, &resp_,
199                                         [this, methods](Status s) {
200                                           EXPECT_EQ(s.ok(), true);
201                                           EXPECT_EQ(resp_.message(), "Hello");
202                                           methods->Hijack();
203                                         });
204       // This is a Unary RPC and we have got nothing interesting to do in the
205       // PRE_SEND_CLOSE interception hook point for this interceptor, so let's
206       // return here. (We do not want to call methods->Proceed(). When the new
207       // RPC returns, we will call methods->Hijack() instead.)
208       return;
209     }
210     if (methods->QueryInterceptionHookPoint(
211             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
212       // Got nothing to do here for now
213     }
214     if (methods->QueryInterceptionHookPoint(
215             experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
216       auto* map = methods->GetRecvInitialMetadata();
217       // Got nothing better to do here for now
218       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
219     }
220     if (methods->QueryInterceptionHookPoint(
221             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
222       EchoResponse* resp =
223           static_cast<EchoResponse*>(methods->GetRecvMessage());
224       // Check that we got the hijacked message, and re-insert the expected
225       // message
226       EXPECT_EQ(resp->message(), "Hello");
227     }
228     if (methods->QueryInterceptionHookPoint(
229             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
230       auto* map = methods->GetRecvTrailingMetadata();
231       bool found = false;
232       // Check that we received the metadata as an echo
233       for (const auto& pair : *map) {
234         found = pair.first.starts_with("testkey") &&
235                 pair.second.starts_with("testvalue");
236         if (found) break;
237       }
238       EXPECT_EQ(found, true);
239       auto* status = methods->GetRecvStatus();
240       EXPECT_EQ(status->ok(), true);
241     }
242     if (methods->QueryInterceptionHookPoint(
243             experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
244       auto* map = methods->GetRecvInitialMetadata();
245       // Got nothing better to do here at the moment
246       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
247     }
248     if (methods->QueryInterceptionHookPoint(
249             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
250       // Insert a different message than expected
251       EchoResponse* resp =
252           static_cast<EchoResponse*>(methods->GetRecvMessage());
253       resp->set_message(resp_.message());
254     }
255     if (methods->QueryInterceptionHookPoint(
256             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
257       auto* map = methods->GetRecvTrailingMetadata();
258       // insert the metadata that we want
259       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
260       map->insert(std::make_pair("testkey", "testvalue"));
261       auto* status = methods->GetRecvStatus();
262       *status = Status(StatusCode::OK, "");
263     }
264 
265     methods->Proceed();
266   }
267 
268  private:
269   experimental::ClientRpcInfo* info_;
270   std::multimap<std::string, std::string> metadata_map_;
271   ClientContext ctx_;
272   EchoRequest req_;
273   EchoResponse resp_;
274   std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
275 };
276 
277 class HijackingInterceptorMakesAnotherCallFactory
278     : public experimental::ClientInterceptorFactoryInterface {
279  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)280   virtual experimental::Interceptor* CreateClientInterceptor(
281       experimental::ClientRpcInfo* info) override {
282     return new HijackingInterceptorMakesAnotherCall(info);
283   }
284 };
285 
286 class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
287  public:
BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)288   BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
289     info_ = info;
290   }
291 
Intercept(experimental::InterceptorBatchMethods * methods)292   virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
293     bool hijack = false;
294     if (methods->QueryInterceptionHookPoint(
295             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
296       CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
297       hijack = true;
298     }
299     if (methods->QueryInterceptionHookPoint(
300             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
301       EchoRequest req;
302       auto* buffer = methods->GetSerializedSendMessage();
303       auto copied_buffer = *buffer;
304       EXPECT_TRUE(
305           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
306               .ok());
307       EXPECT_EQ(req.message().find("Hello"), 0u);
308       msg = req.message();
309     }
310     if (methods->QueryInterceptionHookPoint(
311             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
312       // Got nothing to do here for now
313     }
314     if (methods->QueryInterceptionHookPoint(
315             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
316       CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
317                     "testvalue");
318       auto* status = methods->GetRecvStatus();
319       EXPECT_EQ(status->ok(), true);
320     }
321     if (methods->QueryInterceptionHookPoint(
322             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
323       EchoResponse* resp =
324           static_cast<EchoResponse*>(methods->GetRecvMessage());
325       resp->set_message(msg);
326     }
327     if (methods->QueryInterceptionHookPoint(
328             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
329       EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
330                     ->message()
331                     .find("Hello"),
332                 0u);
333     }
334     if (methods->QueryInterceptionHookPoint(
335             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
336       auto* map = methods->GetRecvTrailingMetadata();
337       // insert the metadata that we want
338       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
339       map->insert(std::make_pair("testkey", "testvalue"));
340       auto* status = methods->GetRecvStatus();
341       *status = Status(StatusCode::OK, "");
342     }
343     if (hijack) {
344       methods->Hijack();
345     } else {
346       methods->Proceed();
347     }
348   }
349 
350  private:
351   experimental::ClientRpcInfo* info_;
352   std::string msg;
353 };
354 
355 class ClientStreamingRpcHijackingInterceptor
356     : public experimental::Interceptor {
357  public:
ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)358   ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
359     info_ = info;
360   }
Intercept(experimental::InterceptorBatchMethods * methods)361   virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
362     bool hijack = false;
363     if (methods->QueryInterceptionHookPoint(
364             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
365       hijack = true;
366     }
367     if (methods->QueryInterceptionHookPoint(
368             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
369       if (++count_ > 10) {
370         methods->FailHijackedSendMessage();
371       }
372     }
373     if (methods->QueryInterceptionHookPoint(
374             experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
375       EXPECT_FALSE(got_failed_send_);
376       got_failed_send_ = !methods->GetSendMessageStatus();
377     }
378     if (methods->QueryInterceptionHookPoint(
379             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
380       auto* status = methods->GetRecvStatus();
381       *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
382     }
383     if (hijack) {
384       methods->Hijack();
385     } else {
386       methods->Proceed();
387     }
388   }
389 
GotFailedSend()390   static bool GotFailedSend() { return got_failed_send_; }
391 
392  private:
393   experimental::ClientRpcInfo* info_;
394   int count_ = 0;
395   static bool got_failed_send_;
396 };
397 
398 bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
399 
400 class ClientStreamingRpcHijackingInterceptorFactory
401     : public experimental::ClientInterceptorFactoryInterface {
402  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)403   virtual experimental::Interceptor* CreateClientInterceptor(
404       experimental::ClientRpcInfo* info) override {
405     return new ClientStreamingRpcHijackingInterceptor(info);
406   }
407 };
408 
409 class ServerStreamingRpcHijackingInterceptor
410     : public experimental::Interceptor {
411  public:
ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)412   ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
413     info_ = info;
414     got_failed_message_ = false;
415   }
416 
Intercept(experimental::InterceptorBatchMethods * methods)417   virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
418     bool hijack = false;
419     if (methods->QueryInterceptionHookPoint(
420             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
421       auto* map = methods->GetSendInitialMetadata();
422       // Check that we can see the test metadata
423       ASSERT_EQ(map->size(), static_cast<unsigned>(1));
424       auto iterator = map->begin();
425       EXPECT_EQ("testkey", iterator->first);
426       EXPECT_EQ("testvalue", iterator->second);
427       hijack = true;
428     }
429     if (methods->QueryInterceptionHookPoint(
430             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
431       EchoRequest req;
432       auto* buffer = methods->GetSerializedSendMessage();
433       auto copied_buffer = *buffer;
434       EXPECT_TRUE(
435           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
436               .ok());
437       EXPECT_EQ(req.message(), "Hello");
438     }
439     if (methods->QueryInterceptionHookPoint(
440             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
441       // Got nothing to do here for now
442     }
443     if (methods->QueryInterceptionHookPoint(
444             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
445       auto* map = methods->GetRecvTrailingMetadata();
446       bool found = false;
447       // Check that we received the metadata as an echo
448       for (const auto& pair : *map) {
449         found = pair.first.starts_with("testkey") &&
450                 pair.second.starts_with("testvalue");
451         if (found) break;
452       }
453       EXPECT_EQ(found, true);
454       auto* status = methods->GetRecvStatus();
455       EXPECT_EQ(status->ok(), true);
456     }
457     if (methods->QueryInterceptionHookPoint(
458             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
459       if (++count_ > 10) {
460         methods->FailHijackedRecvMessage();
461       }
462       EchoResponse* resp =
463           static_cast<EchoResponse*>(methods->GetRecvMessage());
464       resp->set_message("Hello");
465     }
466     if (methods->QueryInterceptionHookPoint(
467             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
468       // Only the last message will be a failure
469       EXPECT_FALSE(got_failed_message_);
470       got_failed_message_ = methods->GetRecvMessage() == nullptr;
471     }
472     if (methods->QueryInterceptionHookPoint(
473             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
474       auto* map = methods->GetRecvTrailingMetadata();
475       // insert the metadata that we want
476       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
477       map->insert(std::make_pair("testkey", "testvalue"));
478       auto* status = methods->GetRecvStatus();
479       *status = Status(StatusCode::OK, "");
480     }
481     if (hijack) {
482       methods->Hijack();
483     } else {
484       methods->Proceed();
485     }
486   }
487 
GotFailedMessage()488   static bool GotFailedMessage() { return got_failed_message_; }
489 
490  private:
491   experimental::ClientRpcInfo* info_;
492   static bool got_failed_message_;
493   int count_ = 0;
494 };
495 
496 bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
497 
498 class ServerStreamingRpcHijackingInterceptorFactory
499     : public experimental::ClientInterceptorFactoryInterface {
500  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)501   virtual experimental::Interceptor* CreateClientInterceptor(
502       experimental::ClientRpcInfo* info) override {
503     return new ServerStreamingRpcHijackingInterceptor(info);
504   }
505 };
506 
507 class BidiStreamingRpcHijackingInterceptorFactory
508     : public experimental::ClientInterceptorFactoryInterface {
509  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)510   virtual experimental::Interceptor* CreateClientInterceptor(
511       experimental::ClientRpcInfo* info) override {
512     return new BidiStreamingRpcHijackingInterceptor(info);
513   }
514 };
515 
516 // The logging interceptor is for testing purposes only. It is used to verify
517 // that all the appropriate hook points are invoked for an RPC. The counts are
518 // reset each time a new object of LoggingInterceptor is created, so only a
519 // single RPC should be made on the channel before calling the Verify methods.
520 class LoggingInterceptor : public experimental::Interceptor {
521  public:
LoggingInterceptor(experimental::ClientRpcInfo *)522   LoggingInterceptor(experimental::ClientRpcInfo* /*info*/) {
523     pre_send_initial_metadata_ = false;
524     pre_send_message_count_ = 0;
525     pre_send_close_ = false;
526     post_recv_initial_metadata_ = false;
527     post_recv_message_count_ = 0;
528     post_recv_status_ = false;
529   }
530 
Intercept(experimental::InterceptorBatchMethods * methods)531   virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
532     if (methods->QueryInterceptionHookPoint(
533             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
534       auto* map = methods->GetSendInitialMetadata();
535       // Check that we can see the test metadata
536       ASSERT_EQ(map->size(), static_cast<unsigned>(1));
537       auto iterator = map->begin();
538       EXPECT_EQ("testkey", iterator->first);
539       EXPECT_EQ("testvalue", iterator->second);
540       ASSERT_FALSE(pre_send_initial_metadata_);
541       pre_send_initial_metadata_ = true;
542     }
543     if (methods->QueryInterceptionHookPoint(
544             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
545       EchoRequest req;
546       auto* send_msg = methods->GetSendMessage();
547       if (send_msg == nullptr) {
548         // We did not get the non-serialized form of the message. Get the
549         // serialized form.
550         auto* buffer = methods->GetSerializedSendMessage();
551         auto copied_buffer = *buffer;
552         EchoRequest req;
553         EXPECT_TRUE(
554             SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
555                 .ok());
556         EXPECT_EQ(req.message(), "Hello");
557       } else {
558         EXPECT_EQ(
559             static_cast<const EchoRequest*>(send_msg)->message().find("Hello"),
560             0u);
561       }
562       auto* buffer = methods->GetSerializedSendMessage();
563       auto copied_buffer = *buffer;
564       EXPECT_TRUE(
565           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
566               .ok());
567       EXPECT_TRUE(req.message().find("Hello") == 0u);
568       pre_send_message_count_++;
569     }
570     if (methods->QueryInterceptionHookPoint(
571             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
572       // Got nothing to do here for now
573       pre_send_close_ = true;
574     }
575     if (methods->QueryInterceptionHookPoint(
576             experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
577       auto* map = methods->GetRecvInitialMetadata();
578       // Got nothing better to do here for now
579       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
580       post_recv_initial_metadata_ = true;
581     }
582     if (methods->QueryInterceptionHookPoint(
583             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
584       EchoResponse* resp =
585           static_cast<EchoResponse*>(methods->GetRecvMessage());
586       if (resp != nullptr) {
587         EXPECT_TRUE(resp->message().find("Hello") == 0u);
588         post_recv_message_count_++;
589       }
590     }
591     if (methods->QueryInterceptionHookPoint(
592             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
593       auto* map = methods->GetRecvTrailingMetadata();
594       bool found = false;
595       // Check that we received the metadata as an echo
596       for (const auto& pair : *map) {
597         found = pair.first.starts_with("testkey") &&
598                 pair.second.starts_with("testvalue");
599         if (found) break;
600       }
601       EXPECT_EQ(found, true);
602       auto* status = methods->GetRecvStatus();
603       EXPECT_EQ(status->ok(), true);
604       post_recv_status_ = true;
605     }
606     methods->Proceed();
607   }
608 
VerifyCall(RPCType type)609   static void VerifyCall(RPCType type) {
610     switch (type) {
611       case RPCType::kSyncUnary:
612       case RPCType::kAsyncCQUnary:
613         VerifyUnaryCall();
614         break;
615       case RPCType::kSyncClientStreaming:
616       case RPCType::kAsyncCQClientStreaming:
617         VerifyClientStreamingCall();
618         break;
619       case RPCType::kSyncServerStreaming:
620       case RPCType::kAsyncCQServerStreaming:
621         VerifyServerStreamingCall();
622         break;
623       case RPCType::kSyncBidiStreaming:
624       case RPCType::kAsyncCQBidiStreaming:
625         VerifyBidiStreamingCall();
626         break;
627     }
628   }
629 
VerifyCallCommon()630   static void VerifyCallCommon() {
631     EXPECT_TRUE(pre_send_initial_metadata_);
632     EXPECT_TRUE(pre_send_close_);
633     EXPECT_TRUE(post_recv_initial_metadata_);
634     EXPECT_TRUE(post_recv_status_);
635   }
636 
VerifyUnaryCall()637   static void VerifyUnaryCall() {
638     VerifyCallCommon();
639     EXPECT_EQ(pre_send_message_count_, 1);
640     EXPECT_EQ(post_recv_message_count_, 1);
641   }
642 
VerifyClientStreamingCall()643   static void VerifyClientStreamingCall() {
644     VerifyCallCommon();
645     EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
646     EXPECT_EQ(post_recv_message_count_, 1);
647   }
648 
VerifyServerStreamingCall()649   static void VerifyServerStreamingCall() {
650     VerifyCallCommon();
651     EXPECT_EQ(pre_send_message_count_, 1);
652     EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
653   }
654 
VerifyBidiStreamingCall()655   static void VerifyBidiStreamingCall() {
656     VerifyCallCommon();
657     EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
658     EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
659   }
660 
661  private:
662   static bool pre_send_initial_metadata_;
663   static int pre_send_message_count_;
664   static bool pre_send_close_;
665   static bool post_recv_initial_metadata_;
666   static int post_recv_message_count_;
667   static bool post_recv_status_;
668 };
669 
670 bool LoggingInterceptor::pre_send_initial_metadata_;
671 int LoggingInterceptor::pre_send_message_count_;
672 bool LoggingInterceptor::pre_send_close_;
673 bool LoggingInterceptor::post_recv_initial_metadata_;
674 int LoggingInterceptor::post_recv_message_count_;
675 bool LoggingInterceptor::post_recv_status_;
676 
677 class LoggingInterceptorFactory
678     : public experimental::ClientInterceptorFactoryInterface {
679  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)680   virtual experimental::Interceptor* CreateClientInterceptor(
681       experimental::ClientRpcInfo* info) override {
682     return new LoggingInterceptor(info);
683   }
684 };
685 
686 class TestScenario {
687  public:
TestScenario(const RPCType & type)688   explicit TestScenario(const RPCType& type) : type_(type) {}
689 
type() const690   RPCType type() const { return type_; }
691 
692  private:
693   RPCType type_;
694 };
695 
CreateTestScenarios()696 std::vector<TestScenario> CreateTestScenarios() {
697   std::vector<TestScenario> scenarios;
698   scenarios.emplace_back(RPCType::kSyncUnary);
699   scenarios.emplace_back(RPCType::kSyncClientStreaming);
700   scenarios.emplace_back(RPCType::kSyncServerStreaming);
701   scenarios.emplace_back(RPCType::kSyncBidiStreaming);
702   scenarios.emplace_back(RPCType::kAsyncCQUnary);
703   scenarios.emplace_back(RPCType::kAsyncCQServerStreaming);
704   return scenarios;
705 }
706 
707 class ParameterizedClientInterceptorsEnd2endTest
708     : public ::testing::TestWithParam<TestScenario> {
709  protected:
ParameterizedClientInterceptorsEnd2endTest()710   ParameterizedClientInterceptorsEnd2endTest() {
711     int port = grpc_pick_unused_port_or_die();
712 
713     ServerBuilder builder;
714     server_address_ = "localhost:" + std::to_string(port);
715     builder.AddListeningPort(server_address_, InsecureServerCredentials());
716     builder.RegisterService(&service_);
717     server_ = builder.BuildAndStart();
718   }
719 
~ParameterizedClientInterceptorsEnd2endTest()720   ~ParameterizedClientInterceptorsEnd2endTest() { server_->Shutdown(); }
721 
SendRPC(const std::shared_ptr<Channel> & channel)722   void SendRPC(const std::shared_ptr<Channel>& channel) {
723     switch (GetParam().type()) {
724       case RPCType::kSyncUnary:
725         MakeCall(channel);
726         break;
727       case RPCType::kSyncClientStreaming:
728         MakeClientStreamingCall(channel);
729         break;
730       case RPCType::kSyncServerStreaming:
731         MakeServerStreamingCall(channel);
732         break;
733       case RPCType::kSyncBidiStreaming:
734         MakeBidiStreamingCall(channel);
735         break;
736       case RPCType::kAsyncCQUnary:
737         MakeAsyncCQCall(channel);
738         break;
739       case RPCType::kAsyncCQClientStreaming:
740         // TODO(yashykt) : Fill this out
741         break;
742       case RPCType::kAsyncCQServerStreaming:
743         MakeAsyncCQServerStreamingCall(channel);
744         break;
745       case RPCType::kAsyncCQBidiStreaming:
746         // TODO(yashykt) : Fill this out
747         break;
748     }
749   }
750 
751   std::string server_address_;
752   EchoTestServiceStreamingImpl service_;
753   std::unique_ptr<Server> server_;
754 };
755 
// Sends one RPC through a LoggingInterceptor followed by 20 DummyInterceptors.
// It then checks the hook points the LoggingInterceptor recorded, and that
// DummyInterceptor::GetNumTimesRun() reports all 20.
TEST_P(ParameterizedClientInterceptorsEnd2endTest,
       ClientInterceptorLoggingTest) {
  constexpr int kNumDummyInterceptors = 20;
  ChannelArguments args;
  DummyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.emplace_back(new LoggingInterceptorFactory());
  for (int i = 0; i < kNumDummyInterceptors; ++i) {
    creators.emplace_back(new DummyInterceptorFactory());
  }
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));
  SendRPC(channel);
  LoggingInterceptor::VerifyCall(GetParam().type());
  // Make sure all 20 dummy interceptors were run
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummyInterceptors);
}

// One instantiation per scenario returned by CreateTestScenarios().
INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end,
                         ParameterizedClientInterceptorsEnd2endTest,
                         ::testing::ValuesIn(CreateTestScenarios()));
780 
781 class ClientInterceptorsEnd2endTest
782     : public ::testing::TestWithParam<TestScenario> {
783  protected:
ClientInterceptorsEnd2endTest()784   ClientInterceptorsEnd2endTest() {
785     int port = grpc_pick_unused_port_or_die();
786 
787     ServerBuilder builder;
788     server_address_ = "localhost:" + std::to_string(port);
789     builder.AddListeningPort(server_address_, InsecureServerCredentials());
790     builder.RegisterService(&service_);
791     server_ = builder.BuildAndStart();
792   }
793 
~ClientInterceptorsEnd2endTest()794   ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); }
795 
796   std::string server_address_;
797   TestServiceImpl service_;
798   std::unique_ptr<Server> server_;
799 };
800 
// Builds the channel with null credentials (hence "lame"), so the RPC cannot
// rely on reaching the server. The only interceptor is HijackingInterceptor,
// which hijacks the Echo RPC and fills in the expected values itself.
TEST_F(ClientInterceptorsEnd2endTest,
       LameChannelClientInterceptorHijackingTest) {
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(std::unique_ptr<HijackingInterceptorFactory>(
      new HijackingInterceptorFactory()));

  ChannelArguments channel_args;
  auto lame_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, /*creds=*/nullptr, channel_args, std::move(factories));
  MakeCall(lame_channel);
}
812 
// Places a hijacking interceptor between two groups of 20 dummy interceptors.
// The hijacker completes the RPC itself, so only the 20 dummy interceptors
// registered before it are expected to run.
TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) {
  ChannelArguments args;
  DummyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  // Reserve room for all 41 factories: 20 dummy + 1 hijacking + 20 dummy.
  // The previous reserve(20) still forced reallocations.
  creators.reserve(20 + 1 + 20);
  // Add 20 dummy interceptors before hijacking interceptor
  for (auto i = 0; i < 20; i++) {
    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }
  creators.push_back(std::unique_ptr<HijackingInterceptorFactory>(
      new HijackingInterceptorFactory()));
  // Add 20 dummy interceptors after hijacking interceptor
  for (auto i = 0; i < 20; i++) {
    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));
  MakeCall(channel);
  // Make sure only 20 dummy interceptors were run
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
837 
// Puts a logging interceptor in front of a hijacking interceptor. After a
// unary call, VerifyUnaryCall() checks what the logging interceptor recorded.
TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) {
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  // Registration order matters: logging first, then hijacking.
  factories.push_back(std::unique_ptr<LoggingInterceptorFactory>(
      new LoggingInterceptorFactory()));
  factories.push_back(std::unique_ptr<HijackingInterceptorFactory>(
      new HijackingInterceptorFactory()));

  ChannelArguments channel_args;
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(channel);
  LoggingInterceptor::VerifyUnaryCall();
}
851 
// Places a hijacking interceptor that issues its own RPC between 5 and 7
// dummy interceptors, on an in-process channel.
TEST_F(ClientInterceptorsEnd2endTest,
       ClientInterceptorHijackingMakesAnotherCallTest) {
  ChannelArguments args;
  DummyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  // Reserve room for all 13 factories: 5 dummy + 1 hijacking + 7 dummy.
  // The previous reserve(5) still forced reallocations.
  creators.reserve(5 + 1 + 7);
  // Add 5 dummy interceptors before hijacking interceptor
  for (auto i = 0; i < 5; i++) {
    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }
  creators.push_back(
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>(
          new HijackingInterceptorMakesAnotherCallFactory()));
  // Add 7 dummy interceptors after hijacking interceptor
  for (auto i = 0; i < 7; i++) {
    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }
  auto channel = server_->experimental().InProcessChannelWithInterceptors(
      args, std::move(creators));

  MakeCall(channel);
  // Make sure all interceptors were run once, since the hijacking interceptor
  // makes an RPC on the intercepted channel
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 12);
}
880 
881 class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test {
882  protected:
ClientInterceptorsCallbackEnd2endTest()883   ClientInterceptorsCallbackEnd2endTest() {
884     int port = grpc_pick_unused_port_or_die();
885 
886     ServerBuilder builder;
887     server_address_ = "localhost:" + std::to_string(port);
888     builder.AddListeningPort(server_address_, InsecureServerCredentials());
889     builder.RegisterService(&service_);
890     server_ = builder.BuildAndStart();
891   }
892 
~ClientInterceptorsCallbackEnd2endTest()893   ~ClientInterceptorsCallbackEnd2endTest() { server_->Shutdown(); }
894 
895   std::string server_address_;
896   TestServiceImpl service_;
897   std::unique_ptr<Server> server_;
898 };
899 
// Makes a callback-API unary call on an in-process channel that has a logging
// interceptor followed by 20 dummy interceptors. It verifies the logged call
// and checks that all dummy interceptors ran.
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorLoggingTestWithCallback) {
  const int kNumDummyInterceptors = 20;
  ChannelArguments channel_args;
  DummyInterceptor::Reset();

  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(std::unique_ptr<LoggingInterceptorFactory>(
      new LoggingInterceptorFactory()));
  for (int n = 0; n < kNumDummyInterceptors; ++n) {
    factories.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }

  auto channel = server_->experimental().InProcessChannelWithInterceptors(
      channel_args, std::move(factories));
  MakeCallbackCall(channel);

  LoggingInterceptor::VerifyUnaryCall();
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummyInterceptors);
}
920 
// Interleaves 20 dummy interceptor factories with 20 NullInterceptorFactory
// instances behind a logging interceptor. The call must still complete, be
// logged correctly, and run every dummy interceptor exactly as if the null
// factories were absent.
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorFactoryAllowsNullptrReturn) {
  const int kNumDummyInterceptors = 20;
  ChannelArguments channel_args;
  DummyInterceptor::Reset();

  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(std::unique_ptr<LoggingInterceptorFactory>(
      new LoggingInterceptorFactory()));
  for (int n = 0; n < kNumDummyInterceptors; ++n) {
    // Each dummy factory is immediately followed by a null one.
    factories.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
    factories.push_back(
        std::unique_ptr<NullInterceptorFactory>(new NullInterceptorFactory()));
  }

  auto channel = server_->experimental().InProcessChannelWithInterceptors(
      channel_args, std::move(factories));
  MakeCallbackCall(channel);

  LoggingInterceptor::VerifyUnaryCall();
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummyInterceptors);
}
943 
944 class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test {
945  protected:
ClientInterceptorsStreamingEnd2endTest()946   ClientInterceptorsStreamingEnd2endTest() {
947     int port = grpc_pick_unused_port_or_die();
948 
949     ServerBuilder builder;
950     server_address_ = "localhost:" + std::to_string(port);
951     builder.AddListeningPort(server_address_, InsecureServerCredentials());
952     builder.RegisterService(&service_);
953     server_ = builder.BuildAndStart();
954   }
955 
~ClientInterceptorsStreamingEnd2endTest()956   ~ClientInterceptorsStreamingEnd2endTest() { server_->Shutdown(); }
957 
958   std::string server_address_;
959   EchoTestServiceStreamingImpl service_;
960   std::unique_ptr<Server> server_;
961 };
962 
// Runs a synchronous client-streaming call through a logging interceptor and
// 20 dummy interceptors, then verifies the logged call and the dummy count.
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) {
  const int kNumDummyInterceptors = 20;
  ChannelArguments channel_args;
  DummyInterceptor::Reset();

  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(std::unique_ptr<LoggingInterceptorFactory>(
      new LoggingInterceptorFactory()));
  for (int n = 0; n < kNumDummyInterceptors; ++n) {
    factories.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeClientStreamingCall(channel);

  LoggingInterceptor::VerifyClientStreamingCall();
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummyInterceptors);
}
982 
// Runs a synchronous server-streaming call through a logging interceptor and
// 20 dummy interceptors, then verifies the logged call and the dummy count.
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
  const int kNumDummyInterceptors = 20;
  ChannelArguments channel_args;
  DummyInterceptor::Reset();

  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(std::unique_ptr<LoggingInterceptorFactory>(
      new LoggingInterceptorFactory()));
  for (int n = 0; n < kNumDummyInterceptors; ++n) {
    factories.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeServerStreamingCall(channel);

  LoggingInterceptor::VerifyServerStreamingCall();
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummyInterceptors);
}
1002 
// Streams requests through ClientStreamingRpcHijackingInterceptor. The first
// 10 writes are expected to succeed and the interceptor rejects the 11th. The
// stream must then finish with a non-OK status, and the interceptor must
// report that it observed a failed send.
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
  ChannelArguments args;
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.push_back(
      std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>(
          new ClientStreamingRpcHijackingInterceptorFactory()));
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));

  auto stub = grpc::testing::EchoTestService::NewStub(channel);
  ClientContext ctx;
  EchoRequest req;
  EchoResponse resp;
  req.mutable_param()->set_echo_metadata(true);
  req.set_message("Hello");
  // (The previous unused `expected_resp` accumulator was removed: it was
  // built up in the loop but never compared against anything.)
  auto writer = stub->RequestStream(&ctx, &resp);
  for (int i = 0; i < 10; i++) {
    EXPECT_TRUE(writer->Write(req));
  }
  // The interceptor will reject the 11th message
  writer->Write(req);
  Status s = writer->Finish();
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
}
1031 
// Drives a synchronous server-streaming call through
// ServerStreamingRpcHijackingInterceptor and checks that the interceptor
// reports a failed message.
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
  ChannelArguments channel_args;
  DummyInterceptor::Reset();

  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(
      std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
          new ServerStreamingRpcHijackingInterceptorFactory()));

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeServerStreamingCall(channel);

  const bool saw_failed_message =
      ServerStreamingRpcHijackingInterceptor::GotFailedMessage();
  EXPECT_TRUE(saw_failed_message);
}
1045 
// Same as ServerStreamingHijackingTest, but issues the server-streaming call
// through the completion-queue API (MakeAsyncCQServerStreamingCall).
TEST_F(ClientInterceptorsStreamingEnd2endTest,
       AsyncCQServerStreamingHijackingTest) {
  ChannelArguments channel_args;
  DummyInterceptor::Reset();

  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(
      std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
          new ServerStreamingRpcHijackingInterceptorFactory()));

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeAsyncCQServerStreamingCall(channel);

  const bool saw_failed_message =
      ServerStreamingRpcHijackingInterceptor::GotFailedMessage();
  EXPECT_TRUE(saw_failed_message);
}
1060 
// Runs a bidi-streaming call through BidiStreamingRpcHijackingInterceptor. Any
// checks happen inside MakeBidiStreamingCall(); this test adds none of its own.
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
  ChannelArguments channel_args;
  DummyInterceptor::Reset();

  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(
      std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>(
          new BidiStreamingRpcHijackingInterceptorFactory()));

  auto bidi_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeBidiStreamingCall(bidi_channel);
}
1073 
// Runs a bidi-streaming call through a logging interceptor and 20 dummy
// interceptors, then verifies the logged call and the dummy count.
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
  const int kNumDummyInterceptors = 20;
  ChannelArguments channel_args;
  DummyInterceptor::Reset();

  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.push_back(std::unique_ptr<LoggingInterceptorFactory>(
      new LoggingInterceptorFactory()));
  for (int n = 0; n < kNumDummyInterceptors; ++n) {
    factories.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeBidiStreamingCall(channel);

  LoggingInterceptor::VerifyBidiStreamingCall();
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummyInterceptors);
}
1093 
1094 class ClientGlobalInterceptorEnd2endTest : public ::testing::Test {
1095  protected:
ClientGlobalInterceptorEnd2endTest()1096   ClientGlobalInterceptorEnd2endTest() {
1097     int port = grpc_pick_unused_port_or_die();
1098 
1099     ServerBuilder builder;
1100     server_address_ = "localhost:" + std::to_string(port);
1101     builder.AddListeningPort(server_address_, InsecureServerCredentials());
1102     builder.RegisterService(&service_);
1103     server_ = builder.BuildAndStart();
1104   }
1105 
~ClientGlobalInterceptorEnd2endTest()1106   ~ClientGlobalInterceptorEnd2endTest() { server_->Shutdown(); }
1107 
1108   std::string server_address_;
1109   TestServiceImpl service_;
1110   std::unique_ptr<Server> server_;
1111 };
1112 
// Registers a DummyInterceptorFactory as the global client interceptor and
// adds 20 per-channel dummy interceptors. All 21 should run for one unary
// call. The global registration is cleared at the end of the test.
TEST_F(ClientGlobalInterceptorEnd2endTest, DummyGlobalInterceptor) {
  // Normally a global interceptor is registered once per process. Swapping it
  // in a test is acceptable as long as no gRPC operations are in flight.
  DummyInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);

  const int kNumDummyInterceptors = 20;
  ChannelArguments channel_args;
  DummyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.reserve(kNumDummyInterceptors);
  for (int n = 0; n < kNumDummyInterceptors; ++n) {
    factories.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(channel);

  // The per-channel dummies plus the global dummy interceptor.
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummyInterceptors + 1);
  experimental::TestOnlyResetGlobalClientInterceptorFactory();
}
1136 
// Registers a LoggingInterceptorFactory as the global client interceptor
// alongside 20 per-channel dummy interceptors. It verifies the logged unary
// call and the dummy count, then clears the global registration.
TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) {
  // Normally a global interceptor is registered once per process. Swapping it
  // in a test is acceptable as long as no gRPC operations are in flight.
  LoggingInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);

  const int kNumDummyInterceptors = 20;
  ChannelArguments channel_args;
  DummyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.reserve(kNumDummyInterceptors);
  for (int n = 0; n < kNumDummyInterceptors; ++n) {
    factories.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(channel);

  LoggingInterceptor::VerifyUnaryCall();
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummyInterceptors);
  experimental::TestOnlyResetGlobalClientInterceptorFactory();
}
1161 
// Registers a HijackingInterceptorFactory as the global client interceptor
// alongside 20 per-channel dummy interceptors. It expects all 20 dummies to
// run for one unary call, then clears the global registration.
TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) {
  // Normally a global interceptor is registered once per process. Swapping it
  // in a test is acceptable as long as no gRPC operations are in flight.
  HijackingInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);

  const int kNumDummyInterceptors = 20;
  ChannelArguments channel_args;
  DummyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      factories;
  factories.reserve(kNumDummyInterceptors);
  for (int n = 0; n < kNumDummyInterceptors; ++n) {
    factories.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }

  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(channel);

  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), kNumDummyInterceptors);
  experimental::TestOnlyResetGlobalClientInterceptorFactory();
}
1185 
1186 }  // namespace
1187 }  // namespace testing
1188 }  // namespace grpc
1189 
main(int argc,char ** argv)1190 int main(int argc, char** argv) {
1191   grpc::testing::TestEnvironment env(argc, argv);
1192   ::testing::InitGoogleTest(&argc, argv);
1193   return RUN_ALL_TESTS();
1194 }
1195