1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/core/lib/transport/interception_chain.h"
16
17 #include <grpc/grpc.h>
18 #include <grpc/support/log.h>
19
20 #include <memory>
21
22 #include "absl/log/log.h"
23 #include "gmock/gmock.h"
24 #include "gtest/gtest.h"
25 #include "src/core/lib/channel/promise_based_filter.h"
26 #include "src/core/lib/resource_quota/resource_quota.h"
27 #include "test/core/promise/poll_matcher.h"
28
29 namespace grpc_core {
30 namespace {
31
32 ///////////////////////////////////////////////////////////////////////////////
33 // Mutate metadata by annotating that it passed through a filter "x"
34
AnnotatePassedThrough(ClientMetadata & md,int x)35 void AnnotatePassedThrough(ClientMetadata& md, int x) {
36 md.Append(absl::StrCat("passed-through-", x), Slice::FromCopiedString("true"),
37 [](absl::string_view, const Slice&) { Crash("unreachable"); });
38 }
39
40 ///////////////////////////////////////////////////////////////////////////////
41 // CreationLog helps us reason about filter creation order by logging a small
42 // record of each filter's creation.
43
// One record in a CreationLog: which filter instance was created and which
// test type tag (the template argument I of the test filters) created it.
struct CreationLogEntry {
  size_t filter_instance_id;
  size_t type_tag;

  // Entries are equal when both fields match.
  bool operator==(const CreationLogEntry& other) const {
    if (filter_instance_id != other.filter_instance_id) return false;
    return type_tag == other.type_tag;
  }

  // Renders as "{filter_instance_id=N, type_tag=M}" so gmock failure messages
  // are readable.
  friend std::ostream& operator<<(std::ostream& os,
                                  const CreationLogEntry& entry) {
    os << "{filter_instance_id=" << entry.filter_instance_id;
    os << ", type_tag=" << entry.type_tag << "}";
    return os;
  }
};
59
60 struct CreationLog {
61 struct RawPointerChannelArgTag {};
ChannelArgNamegrpc_core::__anon9d967e2b0111::CreationLog62 static absl::string_view ChannelArgName() { return "creation_log"; }
63 std::vector<CreationLogEntry> entries;
64 };
65
MaybeLogCreation(const ChannelArgs & channel_args,ChannelFilter::Args filter_args,size_t type_tag)66 void MaybeLogCreation(const ChannelArgs& channel_args,
67 ChannelFilter::Args filter_args, size_t type_tag) {
68 auto* log = channel_args.GetPointer<CreationLog>("creation_log");
69 if (log == nullptr) return;
70 log->entries.push_back(CreationLogEntry{filter_args.instance_id(), type_tag});
71 }
72
73 ///////////////////////////////////////////////////////////////////////////////
74 // Test call filter
75
76 template <int I>
77 class TestFilter {
78 public:
79 class Call {
80 public:
OnClientInitialMetadata(ClientMetadata & md)81 void OnClientInitialMetadata(ClientMetadata& md) {
82 AnnotatePassedThrough(md, I);
83 }
84 static const NoInterceptor OnServerInitialMetadata;
85 static const NoInterceptor OnClientToServerMessage;
86 static const NoInterceptor OnClientToServerHalfClose;
87 static const NoInterceptor OnServerToClientMessage;
88 static const NoInterceptor OnServerTrailingMetadata;
89 static const NoInterceptor OnFinalize;
90 };
91
Create(const ChannelArgs & channel_args,ChannelFilter::Args filter_args)92 static absl::StatusOr<std::unique_ptr<TestFilter<I>>> Create(
93 const ChannelArgs& channel_args, ChannelFilter::Args filter_args) {
94 MaybeLogCreation(channel_args, filter_args, I);
95 return std::make_unique<TestFilter<I>>();
96 }
97
98 private:
99 std::unique_ptr<int> i_ = std::make_unique<int>(I);
100 };
101
102 template <int I>
103 const NoInterceptor TestFilter<I>::Call::OnServerInitialMetadata;
104 template <int I>
105 const NoInterceptor TestFilter<I>::Call::OnClientToServerMessage;
106 template <int I>
107 const NoInterceptor TestFilter<I>::Call::OnClientToServerHalfClose;
108 template <int I>
109 const NoInterceptor TestFilter<I>::Call::OnServerToClientMessage;
110 template <int I>
111 const NoInterceptor TestFilter<I>::Call::OnServerTrailingMetadata;
112 template <int I>
113 const NoInterceptor TestFilter<I>::Call::OnFinalize;
114
115 ///////////////////////////////////////////////////////////////////////////////
116 // Test call filter that fails to instantiate
117
118 template <int I>
119 class FailsToInstantiateFilter {
120 public:
121 class Call {
122 public:
123 static const NoInterceptor OnClientInitialMetadata;
124 static const NoInterceptor OnServerInitialMetadata;
125 static const NoInterceptor OnClientToServerMessage;
126 static const NoInterceptor OnClientToServerHalfClose;
127 static const NoInterceptor OnServerToClientMessage;
128 static const NoInterceptor OnServerTrailingMetadata;
129 static const NoInterceptor OnFinalize;
130 };
131
Create(const ChannelArgs & channel_args,ChannelFilter::Args filter_args)132 static absl::StatusOr<std::unique_ptr<FailsToInstantiateFilter<I>>> Create(
133 const ChannelArgs& channel_args, ChannelFilter::Args filter_args) {
134 MaybeLogCreation(channel_args, filter_args, I);
135 return absl::InternalError(absl::StrCat(" failed to instantiate ", I));
136 }
137 };
138
139 template <int I>
140 const NoInterceptor FailsToInstantiateFilter<I>::Call::OnClientInitialMetadata;
141 template <int I>
142 const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerInitialMetadata;
143 template <int I>
144 const NoInterceptor FailsToInstantiateFilter<I>::Call::OnClientToServerMessage;
145 template <int I>
146 const NoInterceptor
147 FailsToInstantiateFilter<I>::Call::OnClientToServerHalfClose;
148 template <int I>
149 const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerToClientMessage;
150 template <int I>
151 const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerTrailingMetadata;
152 template <int I>
153 const NoInterceptor FailsToInstantiateFilter<I>::Call::OnFinalize;
154
155 ///////////////////////////////////////////////////////////////////////////////
156 // Test call interceptor - consumes calls
157
158 template <int I>
159 class TestConsumingInterceptor final : public Interceptor {
160 public:
InterceptCall(UnstartedCallHandler unstarted_call_handler)161 void InterceptCall(UnstartedCallHandler unstarted_call_handler) override {
162 Consume(std::move(unstarted_call_handler))
163 .PushServerTrailingMetadata(
164 ServerMetadataFromStatus(GRPC_STATUS_INTERNAL, " consumed"));
165 }
Orphaned()166 void Orphaned() override {}
Create(const ChannelArgs & channel_args,ChannelFilter::Args filter_args)167 static absl::StatusOr<RefCountedPtr<TestConsumingInterceptor<I>>> Create(
168 const ChannelArgs& channel_args, ChannelFilter::Args filter_args) {
169 MaybeLogCreation(channel_args, filter_args, I);
170 return MakeRefCounted<TestConsumingInterceptor<I>>();
171 }
172 };
173
174 ///////////////////////////////////////////////////////////////////////////////
175 // Test call interceptor - passes through calls
176
177 template <int I>
178 class TestPassThroughInterceptor final : public Interceptor {
179 public:
InterceptCall(UnstartedCallHandler unstarted_call_handler)180 void InterceptCall(UnstartedCallHandler unstarted_call_handler) override {
181 PassThrough(std::move(unstarted_call_handler));
182 }
Orphaned()183 void Orphaned() override {}
Create(const ChannelArgs & channel_args,ChannelFilter::Args filter_args)184 static absl::StatusOr<RefCountedPtr<TestPassThroughInterceptor<I>>> Create(
185 const ChannelArgs& channel_args, ChannelFilter::Args filter_args) {
186 MaybeLogCreation(channel_args, filter_args, I);
187 return MakeRefCounted<TestPassThroughInterceptor<I>>();
188 }
189 };
190
191 ///////////////////////////////////////////////////////////////////////////////
192 // Test call interceptor - fails to instantiate
193
194 template <int I>
195 class TestFailingInterceptor final : public Interceptor {
196 public:
InterceptCall(UnstartedCallHandler unstarted_call_handler)197 void InterceptCall(UnstartedCallHandler unstarted_call_handler) override {
198 Crash("unreachable");
199 }
Orphaned()200 void Orphaned() override {}
Create(const ChannelArgs & channel_args,ChannelFilter::Args filter_args)201 static absl::StatusOr<RefCountedPtr<TestFailingInterceptor<I>>> Create(
202 const ChannelArgs& channel_args, ChannelFilter::Args filter_args) {
203 MaybeLogCreation(channel_args, filter_args, I);
204 return absl::InternalError(absl::StrCat(" failed to instantiate ", I));
205 }
206 };
207
208 ///////////////////////////////////////////////////////////////////////////////
209 // Test call interceptor - hijacks calls
210
211 template <int I>
212 class TestHijackingInterceptor final : public Interceptor {
213 public:
InterceptCall(UnstartedCallHandler unstarted_call_handler)214 void InterceptCall(UnstartedCallHandler unstarted_call_handler) override {
215 unstarted_call_handler.SpawnInfallible(
216 "hijack", [this, unstarted_call_handler]() mutable {
217 return Map(Hijack(std::move(unstarted_call_handler)),
218 [](ValueOrFailure<HijackedCall> hijacked_call) {
219 ForwardCall(
220 hijacked_call.value().original_call_handler(),
221 hijacked_call.value().MakeCall());
222 });
223 });
224 }
Orphaned()225 void Orphaned() override {}
Create(const ChannelArgs & channel_args,ChannelFilter::Args filter_args)226 static absl::StatusOr<RefCountedPtr<TestHijackingInterceptor<I>>> Create(
227 const ChannelArgs& channel_args, ChannelFilter::Args filter_args) {
228 MaybeLogCreation(channel_args, filter_args, I);
229 return MakeRefCounted<TestHijackingInterceptor<I>>();
230 }
231 };
232
233 ///////////////////////////////////////////////////////////////////////////////
234 // Test fixture
235
236 class InterceptionChainTest : public ::testing::Test {
237 protected:
InterceptionChainTest()238 InterceptionChainTest() {}
~InterceptionChainTest()239 ~InterceptionChainTest() override {}
240
destination()241 RefCountedPtr<UnstartedCallDestination> destination() { return destination_; }
242
243 struct FinishedCall {
244 CallInitiator call;
245 ClientMetadataHandle client_metadata;
246 ServerMetadataHandle server_metadata;
247 };
248
249 // Run a call through a UnstartedCallDestination until it's complete.
RunCall(UnstartedCallDestination * destination)250 FinishedCall RunCall(UnstartedCallDestination* destination) {
251 auto arena = call_arena_allocator_->MakeArena();
252 arena->SetContext<grpc_event_engine::experimental::EventEngine>(
253 event_engine_.get());
254 auto call = MakeCallPair(Arena::MakePooledForOverwrite<ClientMetadata>(),
255 std::move(arena));
256 Poll<ServerMetadataHandle> trailing_md;
257 call.initiator.SpawnInfallible(
258 "run_call", [destination, &call, &trailing_md]() mutable {
259 LOG(INFO) << " start call";
260 destination->StartCall(std::move(call.handler));
261 return Map(call.initiator.PullServerTrailingMetadata(),
262 [&trailing_md](ServerMetadataHandle md) {
263 trailing_md = std::move(md);
264 });
265 });
266 EXPECT_THAT(trailing_md, IsReady());
267 return FinishedCall{std::move(call.initiator), destination_->TakeMetadata(),
268 std::move(trailing_md.value())};
269 }
270
271 private:
272 class Destination final : public UnstartedCallDestination {
273 public:
StartCall(UnstartedCallHandler unstarted_call_handler)274 void StartCall(UnstartedCallHandler unstarted_call_handler) override {
275 LOG(INFO) << " started call: metadata="
276 << unstarted_call_handler.UnprocessedClientInitialMetadata()
277 .DebugString();
278 EXPECT_EQ(metadata_.get(), nullptr);
279 metadata_ = Arena::MakePooledForOverwrite<ClientMetadata>();
280 *metadata_ =
281 unstarted_call_handler.UnprocessedClientInitialMetadata().Copy();
282 unstarted_call_handler.PushServerTrailingMetadata(
283 ServerMetadataFromStatus(GRPC_STATUS_INTERNAL, " cancelled"));
284 }
285
Orphaned()286 void Orphaned() override {}
287
TakeMetadata()288 ClientMetadataHandle TakeMetadata() { return std::move(metadata_); }
289
290 private:
291 ClientMetadataHandle metadata_;
292 };
293 std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine_ =
294 grpc_event_engine::experimental::GetDefaultEventEngine();
295 RefCountedPtr<Destination> destination_ = MakeRefCounted<Destination>();
296 RefCountedPtr<CallArenaAllocator> call_arena_allocator_ =
297 MakeRefCounted<CallArenaAllocator>(
298 ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator(
299 "test"),
300 1024);
301 };
302
303 ///////////////////////////////////////////////////////////////////////////////
304 // Tests begin
305
TEST_F(InterceptionChainTest, Empty) {
  // With no interceptors the call goes straight to the terminal destination,
  // which records the metadata and finishes the call with " cancelled".
  auto r = InterceptionChainBuilder(ChannelArgs()).Build(destination());
  ASSERT_TRUE(r.ok()) << r.status();
  auto finished_call = RunCall(r.value().get());
  EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()),
            GRPC_STATUS_INTERNAL);
  // Check presence before dereferencing so a missing message fails the test
  // instead of crashing the binary.
  const auto* message =
      finished_call.server_metadata->get_pointer(GrpcMessageMetadata());
  ASSERT_NE(message, nullptr);
  EXPECT_EQ(message->as_string_view(), " cancelled");
  EXPECT_NE(finished_call.client_metadata, nullptr);
}
317
TEST_F(InterceptionChainTest, PassThrough) {
  // A pass-through interceptor must leave the outcome identical to an empty
  // chain: the terminal destination sees the call and cancels it.
  auto r = InterceptionChainBuilder(ChannelArgs())
               .Add<TestPassThroughInterceptor<1>>()
               .Build(destination());
  ASSERT_TRUE(r.ok()) << r.status();
  auto finished_call = RunCall(r.value().get());
  EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()),
            GRPC_STATUS_INTERNAL);
  // Check presence before dereferencing so a missing message fails the test
  // instead of crashing the binary.
  const auto* message =
      finished_call.server_metadata->get_pointer(GrpcMessageMetadata());
  ASSERT_NE(message, nullptr);
  EXPECT_EQ(message->as_string_view(), " cancelled");
  EXPECT_NE(finished_call.client_metadata, nullptr);
}
331
TEST_F(InterceptionChainTest, Consumed) {
  // The consuming interceptor finishes the call itself with " consumed", so
  // the terminal destination never sees it and records no metadata.
  auto r = InterceptionChainBuilder(ChannelArgs())
               .Add<TestConsumingInterceptor<1>>()
               .Build(destination());
  ASSERT_TRUE(r.ok()) << r.status();
  auto finished_call = RunCall(r.value().get());
  EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()),
            GRPC_STATUS_INTERNAL);
  // Check presence before dereferencing so a missing message fails the test
  // instead of crashing the binary.
  const auto* message =
      finished_call.server_metadata->get_pointer(GrpcMessageMetadata());
  ASSERT_NE(message, nullptr);
  EXPECT_EQ(message->as_string_view(), " consumed");
  EXPECT_EQ(finished_call.client_metadata, nullptr);
}
345
TEST_F(InterceptionChainTest, Hijacked) {
  // The hijacking interceptor forwards the call on to the terminal
  // destination, so the initiator sees the destination's " cancelled" status.
  auto r = InterceptionChainBuilder(ChannelArgs())
               .Add<TestHijackingInterceptor<1>>()
               .Build(destination());
  ASSERT_TRUE(r.ok()) << r.status();
  auto finished_call = RunCall(r.value().get());
  EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()),
            GRPC_STATUS_INTERNAL);
  // Check presence before dereferencing so a missing message fails the test
  // instead of crashing the binary.
  const auto* message =
      finished_call.server_metadata->get_pointer(GrpcMessageMetadata());
  ASSERT_NE(message, nullptr);
  EXPECT_EQ(message->as_string_view(), " cancelled");
  EXPECT_NE(finished_call.client_metadata, nullptr);
}
359
TEST_F(InterceptionChainTest, FiltersThenHijacked) {
  // TestFilter<1> runs before the hijack, so its annotation must appear in
  // the metadata the terminal destination recorded.
  auto r = InterceptionChainBuilder(ChannelArgs())
               .Add<TestFilter<1>>()
               .Add<TestHijackingInterceptor<2>>()
               .Build(destination());
  ASSERT_TRUE(r.ok()) << r.status();
  auto finished_call = RunCall(r.value().get());
  EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()),
            GRPC_STATUS_INTERNAL);
  // Check presence before dereferencing so a missing message fails the test
  // instead of crashing the binary.
  const auto* message =
      finished_call.server_metadata->get_pointer(GrpcMessageMetadata());
  ASSERT_NE(message, nullptr);
  EXPECT_EQ(message->as_string_view(), " cancelled");
  // client_metadata is dereferenced below, so a null value must stop the test.
  ASSERT_NE(finished_call.client_metadata, nullptr);
  std::string backing;
  EXPECT_EQ(finished_call.client_metadata->GetStringValue("passed-through-1",
                                                          &backing),
            "true");
}
378
TEST_F(InterceptionChainTest, FailsToInstantiateInterceptor) {
  // An interceptor whose Create() fails must make Build() fail with exactly
  // that error.
  auto result = InterceptionChainBuilder(ChannelArgs())
                    .Add<TestFailingInterceptor<1>>()
                    .Build(destination());
  EXPECT_FALSE(result.ok());
  const auto& status = result.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
  EXPECT_EQ(status.message(), " failed to instantiate 1");
}
387
TEST_F(InterceptionChainTest, FailsToInstantiateInterceptor2) {
  // A failing interceptor placed after a working filter still fails Build(),
  // and the reported error is the interceptor's.
  auto result = InterceptionChainBuilder(ChannelArgs())
                    .Add<TestFilter<1>>()
                    .Add<TestFailingInterceptor<2>>()
                    .Build(destination());
  EXPECT_FALSE(result.ok());
  const auto& status = result.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
  EXPECT_EQ(status.message(), " failed to instantiate 2");
}
397
TEST_F(InterceptionChainTest, FailsToInstantiateFilter) {
  // A filter whose Create() fails must make Build() fail with exactly that
  // error.
  auto result = InterceptionChainBuilder(ChannelArgs())
                    .Add<FailsToInstantiateFilter<1>>()
                    .Build(destination());
  EXPECT_FALSE(result.ok());
  const auto& status = result.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
  EXPECT_EQ(status.message(), " failed to instantiate 1");
}
406
TEST_F(InterceptionChainTest, FailsToInstantiateFilter2) {
  // A failing filter placed after a working filter still fails Build(), and
  // the reported error is the failing filter's.
  auto result = InterceptionChainBuilder(ChannelArgs())
                    .Add<TestFilter<1>>()
                    .Add<FailsToInstantiateFilter<2>>()
                    .Build(destination());
  EXPECT_FALSE(result.ok());
  const auto& status = result.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
  EXPECT_EQ(status.message(), " failed to instantiate 2");
}
416
TEST_F(InterceptionChainTest, CreationOrderCorrect) {
  // Every Create() call is appended to `log`. The expectations below encode
  // that filters are created in chain order, and that an instance id counts
  // earlier occurrences of the same filter type in the chain.
  CreationLog log;
  auto r = InterceptionChainBuilder(ChannelArgs().SetObject(&log))
               .Add<TestFilter<1>>()
               .Add<TestFilter<2>>()
               .Add<TestFilter<3>>()
               .Add<TestConsumingInterceptor<4>>()
               .Add<TestFilter<1>>()
               .Add<TestFilter<2>>()
               .Add<TestFilter<3>>()
               .Add<TestConsumingInterceptor<4>>()
               .Add<TestFilter<1>>()
               .Build(destination());
  // All filters here are constructible, so the build itself must succeed;
  // without this check a failed build could go unnoticed.
  EXPECT_TRUE(r.ok()) << r.status();
  EXPECT_THAT(log.entries, ::testing::ElementsAre(
                               CreationLogEntry{0, 1}, CreationLogEntry{0, 2},
                               CreationLogEntry{0, 3}, CreationLogEntry{0, 4},
                               CreationLogEntry{1, 1}, CreationLogEntry{1, 2},
                               CreationLogEntry{1, 3}, CreationLogEntry{1, 4},
                               CreationLogEntry{2, 1}));
}
437
438 } // namespace
439 } // namespace grpc_core
440
main(int argc,char ** argv)441 int main(int argc, char** argv) {
442 ::testing::InitGoogleTest(&argc, argv);
443 grpc_tracer_init();
444 gpr_log_verbosity_init();
445 grpc_init();
446 auto r = RUN_ALL_TESTS();
447 grpc_shutdown();
448 return r;
449 }
450