• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/core/lib/transport/interception_chain.h"
16 
17 #include <grpc/grpc.h>
18 #include <grpc/support/log.h>
19 
20 #include <memory>
21 
22 #include "absl/log/log.h"
23 #include "gmock/gmock.h"
24 #include "gtest/gtest.h"
25 #include "src/core/lib/channel/promise_based_filter.h"
26 #include "src/core/lib/resource_quota/resource_quota.h"
27 #include "test/core/promise/poll_matcher.h"
28 
29 namespace grpc_core {
30 namespace {
31 
32 ///////////////////////////////////////////////////////////////////////////////
33 // Mutate metadata by annotating that it passed through a filter "x"
34 
AnnotatePassedThrough(ClientMetadata & md,int x)35 void AnnotatePassedThrough(ClientMetadata& md, int x) {
36   md.Append(absl::StrCat("passed-through-", x), Slice::FromCopiedString("true"),
37             [](absl::string_view, const Slice&) { Crash("unreachable"); });
38 }
39 
40 ///////////////////////////////////////////////////////////////////////////////
41 // CreationLog helps us reason about filter creation order by logging a small
42 // record of each filter's creation.
43 
// One record in a CreationLog: which instance (per type) was created, and the
// numeric type tag of the filter/interceptor that created it.
struct CreationLogEntry {
  size_t filter_instance_id;
  size_t type_tag;

  // Two entries match when both fields are equal.
  bool operator==(const CreationLogEntry& rhs) const {
    if (filter_instance_id != rhs.filter_instance_id) return false;
    return type_tag == rhs.type_tag;
  }

  // Renders as "{filter_instance_id=N, type_tag=M}" (used by gmock output).
  friend std::ostream& operator<<(std::ostream& out,
                                  const CreationLogEntry& e) {
    out << "{filter_instance_id=" << e.filter_instance_id;
    out << ", type_tag=" << e.type_tag << "}";
    return out;
  }
};
59 
60 struct CreationLog {
61   struct RawPointerChannelArgTag {};
ChannelArgNamegrpc_core::__anon9d967e2b0111::CreationLog62   static absl::string_view ChannelArgName() { return "creation_log"; }
63   std::vector<CreationLogEntry> entries;
64 };
65 
MaybeLogCreation(const ChannelArgs & channel_args,ChannelFilter::Args filter_args,size_t type_tag)66 void MaybeLogCreation(const ChannelArgs& channel_args,
67                       ChannelFilter::Args filter_args, size_t type_tag) {
68   auto* log = channel_args.GetPointer<CreationLog>("creation_log");
69   if (log == nullptr) return;
70   log->entries.push_back(CreationLogEntry{filter_args.instance_id(), type_tag});
71 }
72 
73 ///////////////////////////////////////////////////////////////////////////////
74 // Test call filter
75 
76 template <int I>
77 class TestFilter {
78  public:
79   class Call {
80    public:
OnClientInitialMetadata(ClientMetadata & md)81     void OnClientInitialMetadata(ClientMetadata& md) {
82       AnnotatePassedThrough(md, I);
83     }
84     static const NoInterceptor OnServerInitialMetadata;
85     static const NoInterceptor OnClientToServerMessage;
86     static const NoInterceptor OnClientToServerHalfClose;
87     static const NoInterceptor OnServerToClientMessage;
88     static const NoInterceptor OnServerTrailingMetadata;
89     static const NoInterceptor OnFinalize;
90   };
91 
Create(const ChannelArgs & channel_args,ChannelFilter::Args filter_args)92   static absl::StatusOr<std::unique_ptr<TestFilter<I>>> Create(
93       const ChannelArgs& channel_args, ChannelFilter::Args filter_args) {
94     MaybeLogCreation(channel_args, filter_args, I);
95     return std::make_unique<TestFilter<I>>();
96   }
97 
98  private:
99   std::unique_ptr<int> i_ = std::make_unique<int>(I);
100 };
101 
102 template <int I>
103 const NoInterceptor TestFilter<I>::Call::OnServerInitialMetadata;
104 template <int I>
105 const NoInterceptor TestFilter<I>::Call::OnClientToServerMessage;
106 template <int I>
107 const NoInterceptor TestFilter<I>::Call::OnClientToServerHalfClose;
108 template <int I>
109 const NoInterceptor TestFilter<I>::Call::OnServerToClientMessage;
110 template <int I>
111 const NoInterceptor TestFilter<I>::Call::OnServerTrailingMetadata;
112 template <int I>
113 const NoInterceptor TestFilter<I>::Call::OnFinalize;
114 
115 ///////////////////////////////////////////////////////////////////////////////
116 // Test call filter that fails to instantiate
117 
118 template <int I>
119 class FailsToInstantiateFilter {
120  public:
121   class Call {
122    public:
123     static const NoInterceptor OnClientInitialMetadata;
124     static const NoInterceptor OnServerInitialMetadata;
125     static const NoInterceptor OnClientToServerMessage;
126     static const NoInterceptor OnClientToServerHalfClose;
127     static const NoInterceptor OnServerToClientMessage;
128     static const NoInterceptor OnServerTrailingMetadata;
129     static const NoInterceptor OnFinalize;
130   };
131 
Create(const ChannelArgs & channel_args,ChannelFilter::Args filter_args)132   static absl::StatusOr<std::unique_ptr<FailsToInstantiateFilter<I>>> Create(
133       const ChannelArgs& channel_args, ChannelFilter::Args filter_args) {
134     MaybeLogCreation(channel_args, filter_args, I);
135     return absl::InternalError(absl::StrCat("�� failed to instantiate ", I));
136   }
137 };
138 
139 template <int I>
140 const NoInterceptor FailsToInstantiateFilter<I>::Call::OnClientInitialMetadata;
141 template <int I>
142 const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerInitialMetadata;
143 template <int I>
144 const NoInterceptor FailsToInstantiateFilter<I>::Call::OnClientToServerMessage;
145 template <int I>
146 const NoInterceptor
147     FailsToInstantiateFilter<I>::Call::OnClientToServerHalfClose;
148 template <int I>
149 const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerToClientMessage;
150 template <int I>
151 const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerTrailingMetadata;
152 template <int I>
153 const NoInterceptor FailsToInstantiateFilter<I>::Call::OnFinalize;
154 
155 ///////////////////////////////////////////////////////////////////////////////
156 // Test call interceptor - consumes calls
157 
158 template <int I>
159 class TestConsumingInterceptor final : public Interceptor {
160  public:
InterceptCall(UnstartedCallHandler unstarted_call_handler)161   void InterceptCall(UnstartedCallHandler unstarted_call_handler) override {
162     Consume(std::move(unstarted_call_handler))
163         .PushServerTrailingMetadata(
164             ServerMetadataFromStatus(GRPC_STATUS_INTERNAL, "�� consumed"));
165   }
Orphaned()166   void Orphaned() override {}
Create(const ChannelArgs & channel_args,ChannelFilter::Args filter_args)167   static absl::StatusOr<RefCountedPtr<TestConsumingInterceptor<I>>> Create(
168       const ChannelArgs& channel_args, ChannelFilter::Args filter_args) {
169     MaybeLogCreation(channel_args, filter_args, I);
170     return MakeRefCounted<TestConsumingInterceptor<I>>();
171   }
172 };
173 
174 ///////////////////////////////////////////////////////////////////////////////
175 // Test call interceptor - passes through calls
176 
177 template <int I>
178 class TestPassThroughInterceptor final : public Interceptor {
179  public:
InterceptCall(UnstartedCallHandler unstarted_call_handler)180   void InterceptCall(UnstartedCallHandler unstarted_call_handler) override {
181     PassThrough(std::move(unstarted_call_handler));
182   }
Orphaned()183   void Orphaned() override {}
Create(const ChannelArgs & channel_args,ChannelFilter::Args filter_args)184   static absl::StatusOr<RefCountedPtr<TestPassThroughInterceptor<I>>> Create(
185       const ChannelArgs& channel_args, ChannelFilter::Args filter_args) {
186     MaybeLogCreation(channel_args, filter_args, I);
187     return MakeRefCounted<TestPassThroughInterceptor<I>>();
188   }
189 };
190 
191 ///////////////////////////////////////////////////////////////////////////////
192 // Test call interceptor - fails to instantiate
193 
194 template <int I>
195 class TestFailingInterceptor final : public Interceptor {
196  public:
InterceptCall(UnstartedCallHandler unstarted_call_handler)197   void InterceptCall(UnstartedCallHandler unstarted_call_handler) override {
198     Crash("unreachable");
199   }
Orphaned()200   void Orphaned() override {}
Create(const ChannelArgs & channel_args,ChannelFilter::Args filter_args)201   static absl::StatusOr<RefCountedPtr<TestFailingInterceptor<I>>> Create(
202       const ChannelArgs& channel_args, ChannelFilter::Args filter_args) {
203     MaybeLogCreation(channel_args, filter_args, I);
204     return absl::InternalError(absl::StrCat("�� failed to instantiate ", I));
205   }
206 };
207 
208 ///////////////////////////////////////////////////////////////////////////////
209 // Test call interceptor - hijacks calls
210 
211 template <int I>
212 class TestHijackingInterceptor final : public Interceptor {
213  public:
InterceptCall(UnstartedCallHandler unstarted_call_handler)214   void InterceptCall(UnstartedCallHandler unstarted_call_handler) override {
215     unstarted_call_handler.SpawnInfallible(
216         "hijack", [this, unstarted_call_handler]() mutable {
217           return Map(Hijack(std::move(unstarted_call_handler)),
218                      [](ValueOrFailure<HijackedCall> hijacked_call) {
219                        ForwardCall(
220                            hijacked_call.value().original_call_handler(),
221                            hijacked_call.value().MakeCall());
222                      });
223         });
224   }
Orphaned()225   void Orphaned() override {}
Create(const ChannelArgs & channel_args,ChannelFilter::Args filter_args)226   static absl::StatusOr<RefCountedPtr<TestHijackingInterceptor<I>>> Create(
227       const ChannelArgs& channel_args, ChannelFilter::Args filter_args) {
228     MaybeLogCreation(channel_args, filter_args, I);
229     return MakeRefCounted<TestHijackingInterceptor<I>>();
230   }
231 };
232 
233 ///////////////////////////////////////////////////////////////////////////////
234 // Test fixture
235 
236 class InterceptionChainTest : public ::testing::Test {
237  protected:
InterceptionChainTest()238   InterceptionChainTest() {}
~InterceptionChainTest()239   ~InterceptionChainTest() override {}
240 
destination()241   RefCountedPtr<UnstartedCallDestination> destination() { return destination_; }
242 
243   struct FinishedCall {
244     CallInitiator call;
245     ClientMetadataHandle client_metadata;
246     ServerMetadataHandle server_metadata;
247   };
248 
249   // Run a call through a UnstartedCallDestination until it's complete.
RunCall(UnstartedCallDestination * destination)250   FinishedCall RunCall(UnstartedCallDestination* destination) {
251     auto arena = call_arena_allocator_->MakeArena();
252     arena->SetContext<grpc_event_engine::experimental::EventEngine>(
253         event_engine_.get());
254     auto call = MakeCallPair(Arena::MakePooledForOverwrite<ClientMetadata>(),
255                              std::move(arena));
256     Poll<ServerMetadataHandle> trailing_md;
257     call.initiator.SpawnInfallible(
258         "run_call", [destination, &call, &trailing_md]() mutable {
259           LOG(INFO) << "�� start call";
260           destination->StartCall(std::move(call.handler));
261           return Map(call.initiator.PullServerTrailingMetadata(),
262                      [&trailing_md](ServerMetadataHandle md) {
263                        trailing_md = std::move(md);
264                      });
265         });
266     EXPECT_THAT(trailing_md, IsReady());
267     return FinishedCall{std::move(call.initiator), destination_->TakeMetadata(),
268                         std::move(trailing_md.value())};
269   }
270 
271  private:
272   class Destination final : public UnstartedCallDestination {
273    public:
StartCall(UnstartedCallHandler unstarted_call_handler)274     void StartCall(UnstartedCallHandler unstarted_call_handler) override {
275       LOG(INFO) << "�� started call: metadata="
276                 << unstarted_call_handler.UnprocessedClientInitialMetadata()
277                        .DebugString();
278       EXPECT_EQ(metadata_.get(), nullptr);
279       metadata_ = Arena::MakePooledForOverwrite<ClientMetadata>();
280       *metadata_ =
281           unstarted_call_handler.UnprocessedClientInitialMetadata().Copy();
282       unstarted_call_handler.PushServerTrailingMetadata(
283           ServerMetadataFromStatus(GRPC_STATUS_INTERNAL, "�� cancelled"));
284     }
285 
Orphaned()286     void Orphaned() override {}
287 
TakeMetadata()288     ClientMetadataHandle TakeMetadata() { return std::move(metadata_); }
289 
290    private:
291     ClientMetadataHandle metadata_;
292   };
293   std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine_ =
294       grpc_event_engine::experimental::GetDefaultEventEngine();
295   RefCountedPtr<Destination> destination_ = MakeRefCounted<Destination>();
296   RefCountedPtr<CallArenaAllocator> call_arena_allocator_ =
297       MakeRefCounted<CallArenaAllocator>(
298           ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator(
299               "test"),
300           1024);
301 };
302 
303 ///////////////////////////////////////////////////////////////////////////////
304 // Tests begin
305 
// With no stages, the call goes straight to the destination, which records
// the client metadata and fails the call.
TEST_F(InterceptionChainTest, Empty) {
  auto chain = InterceptionChainBuilder(ChannelArgs()).Build(destination());
  ASSERT_TRUE(chain.ok()) << chain.status();
  auto finished_call = RunCall(chain.value().get());
  auto& trailing = *finished_call.server_metadata;
  EXPECT_EQ(trailing.get(GrpcStatusMetadata()), GRPC_STATUS_INTERNAL);
  EXPECT_EQ(trailing.get_pointer(GrpcMessageMetadata())->as_string_view(),
            "�� cancelled");
  EXPECT_NE(finished_call.client_metadata, nullptr);
}
317 
// A pass-through interceptor must leave the destination's behavior visible.
TEST_F(InterceptionChainTest, PassThrough) {
  auto chain = InterceptionChainBuilder(ChannelArgs())
                   .Add<TestPassThroughInterceptor<1>>()
                   .Build(destination());
  ASSERT_TRUE(chain.ok()) << chain.status();
  auto finished_call = RunCall(chain.value().get());
  auto& trailing = *finished_call.server_metadata;
  EXPECT_EQ(trailing.get(GrpcStatusMetadata()), GRPC_STATUS_INTERNAL);
  EXPECT_EQ(trailing.get_pointer(GrpcMessageMetadata())->as_string_view(),
            "�� cancelled");
  EXPECT_NE(finished_call.client_metadata, nullptr);
}
331 
// A consuming interceptor finishes the call itself; the destination never
// sees it, so no client metadata is recorded.
TEST_F(InterceptionChainTest, Consumed) {
  auto chain = InterceptionChainBuilder(ChannelArgs())
                   .Add<TestConsumingInterceptor<1>>()
                   .Build(destination());
  ASSERT_TRUE(chain.ok()) << chain.status();
  auto finished_call = RunCall(chain.value().get());
  auto& trailing = *finished_call.server_metadata;
  EXPECT_EQ(trailing.get(GrpcStatusMetadata()), GRPC_STATUS_INTERNAL);
  EXPECT_EQ(trailing.get_pointer(GrpcMessageMetadata())->as_string_view(),
            "�� consumed");
  EXPECT_EQ(finished_call.client_metadata, nullptr);
}
345 
// A hijacking interceptor that forwards the call still reaches the destination.
TEST_F(InterceptionChainTest, Hijacked) {
  auto chain = InterceptionChainBuilder(ChannelArgs())
                   .Add<TestHijackingInterceptor<1>>()
                   .Build(destination());
  ASSERT_TRUE(chain.ok()) << chain.status();
  auto finished_call = RunCall(chain.value().get());
  auto& trailing = *finished_call.server_metadata;
  EXPECT_EQ(trailing.get(GrpcStatusMetadata()), GRPC_STATUS_INTERNAL);
  EXPECT_EQ(trailing.get_pointer(GrpcMessageMetadata())->as_string_view(),
            "�� cancelled");
  EXPECT_NE(finished_call.client_metadata, nullptr);
}
359 
// A filter placed before a hijacking interceptor must still have run: its
// annotation shows up in the metadata the destination captured.
TEST_F(InterceptionChainTest, FiltersThenHijacked) {
  auto chain = InterceptionChainBuilder(ChannelArgs())
                   .Add<TestFilter<1>>()
                   .Add<TestHijackingInterceptor<2>>()
                   .Build(destination());
  ASSERT_TRUE(chain.ok()) << chain.status();
  auto finished_call = RunCall(chain.value().get());
  auto& trailing = *finished_call.server_metadata;
  EXPECT_EQ(trailing.get(GrpcStatusMetadata()), GRPC_STATUS_INTERNAL);
  EXPECT_EQ(trailing.get_pointer(GrpcMessageMetadata())->as_string_view(),
            "�� cancelled");
  EXPECT_NE(finished_call.client_metadata, nullptr);
  std::string backing;
  auto annotation =
      finished_call.client_metadata->GetStringValue("passed-through-1",
                                                    &backing);
  EXPECT_EQ(annotation, "true");
}
378 
// An interceptor factory error must surface unchanged from Build().
TEST_F(InterceptionChainTest, FailsToInstantiateInterceptor) {
  auto result = InterceptionChainBuilder(ChannelArgs())
                    .Add<TestFailingInterceptor<1>>()
                    .Build(destination());
  EXPECT_FALSE(result.ok());
  const absl::Status& status = result.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
  EXPECT_EQ(status.message(), "�� failed to instantiate 1");
}
387 
// A failing interceptor after a healthy filter still fails the whole build.
TEST_F(InterceptionChainTest, FailsToInstantiateInterceptor2) {
  auto result = InterceptionChainBuilder(ChannelArgs())
                    .Add<TestFilter<1>>()
                    .Add<TestFailingInterceptor<2>>()
                    .Build(destination());
  EXPECT_FALSE(result.ok());
  const absl::Status& status = result.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
  EXPECT_EQ(status.message(), "�� failed to instantiate 2");
}
397 
// A filter factory error must surface unchanged from Build().
TEST_F(InterceptionChainTest, FailsToInstantiateFilter) {
  auto result = InterceptionChainBuilder(ChannelArgs())
                    .Add<FailsToInstantiateFilter<1>>()
                    .Build(destination());
  EXPECT_FALSE(result.ok());
  const absl::Status& status = result.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
  EXPECT_EQ(status.message(), "�� failed to instantiate 1");
}
406 
// A failing filter after a healthy filter still fails the whole build.
TEST_F(InterceptionChainTest, FailsToInstantiateFilter2) {
  auto result = InterceptionChainBuilder(ChannelArgs())
                    .Add<TestFilter<1>>()
                    .Add<FailsToInstantiateFilter<2>>()
                    .Build(destination());
  EXPECT_FALSE(result.ok());
  const absl::Status& status = result.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
  EXPECT_EQ(status.message(), "�� failed to instantiate 2");
}
416 
// Stages must be created in the order they were added. The expected log
// shows that instance ids are counted per type, starting from zero.
TEST_F(InterceptionChainTest, CreationOrderCorrect) {
  CreationLog log;
  auto r = InterceptionChainBuilder(ChannelArgs().SetObject(&log))
               .Add<TestFilter<1>>()
               .Add<TestFilter<2>>()
               .Add<TestFilter<3>>()
               .Add<TestConsumingInterceptor<4>>()
               .Add<TestFilter<1>>()
               .Add<TestFilter<2>>()
               .Add<TestFilter<3>>()
               .Add<TestConsumingInterceptor<4>>()
               .Add<TestFilter<1>>()
               .Build(destination());
  // Every stage here instantiates successfully, so the build must succeed.
  // EXPECT (not ASSERT) so the creation log is still checked on failure.
  EXPECT_TRUE(r.ok()) << r.status();
  EXPECT_THAT(log.entries, ::testing::ElementsAre(
                               CreationLogEntry{0, 1}, CreationLogEntry{0, 2},
                               CreationLogEntry{0, 3}, CreationLogEntry{0, 4},
                               CreationLogEntry{1, 1}, CreationLogEntry{1, 2},
                               CreationLogEntry{1, 3}, CreationLogEntry{1, 4},
                               CreationLogEntry{2, 1}));
}
437 
438 }  // namespace
439 }  // namespace grpc_core
440 
main(int argc,char ** argv)441 int main(int argc, char** argv) {
442   ::testing::InitGoogleTest(&argc, argv);
443   grpc_tracer_init();
444   gpr_log_verbosity_init();
445   grpc_init();
446   auto r = RUN_ALL_TESTS();
447   grpc_shutdown();
448   return r;
449 }
450