1 // Copyright 2023 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/core/lib/surface/channel_init.h"
16
17 #include <map>
18 #include <memory>
19 #include <string>
20
21 #include "absl/strings/string_view.h"
22 #include "gtest/gtest.h"
23 #include "src/core/lib/channel/channel_stack.h"
24 #include "src/core/lib/channel/channel_stack_builder_impl.h"
25 #include "src/core/lib/channel/promise_based_filter.h"
26 #include "src/core/lib/resource_quota/resource_quota.h"
27 #include "src/core/lib/surface/channel_stack_type.h"
28 #include "src/core/lib/transport/call_arena_allocator.h"
29 #include "test/core/test_util/test_config.h"
30
31 namespace grpc_core {
32 namespace {
33
FilterNamed(const char * name)34 const grpc_channel_filter* FilterNamed(const char* name) {
35 static auto* filters =
36 new std::map<absl::string_view, const grpc_channel_filter*>;
37 auto it = filters->find(name);
38 if (it != filters->end()) return it->second;
39 static auto* name_factories =
40 new std::vector<std::unique_ptr<UniqueTypeName::Factory>>();
41 name_factories->emplace_back(std::make_unique<UniqueTypeName::Factory>(name));
42 auto unique_type_name = name_factories->back()->Create();
43 return filters
44 ->emplace(name,
45 new grpc_channel_filter{nullptr, nullptr, 0, nullptr, nullptr,
46 nullptr, 0, nullptr, nullptr, nullptr,
47 nullptr, unique_type_name})
48 .first->second;
49 }
50
GetFilterNames(const ChannelInit & init,grpc_channel_stack_type type,const ChannelArgs & args)51 std::vector<std::string> GetFilterNames(const ChannelInit& init,
52 grpc_channel_stack_type type,
53 const ChannelArgs& args) {
54 ChannelStackBuilderImpl b("test", type, args);
55 if (!init.CreateStack(&b)) return {};
56 std::vector<std::string> names;
57 for (auto f : b.stack()) {
58 names.push_back(std::string(f->name.name()));
59 }
60 EXPECT_NE(names, std::vector<std::string>());
61 return names;
62 }
63
// A stack whose only registration is a terminal filter builds to exactly that
// filter.
TEST(ChannelInitTest, Empty) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const ChannelInit init = builder.Build();
  const std::vector<std::string> expected = {"terminator"};
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
71
// A filter registered only for client channels appears in the client stack but
// not in the server stack.
TEST(ChannelInitTest, OneClientFilter) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  builder.RegisterFilter(GRPC_SERVER_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const ChannelInit init = builder.Build();
  const std::vector<std::string> client_names =
      GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs());
  const std::vector<std::string> server_names =
      GetFilterNames(init, GRPC_SERVER_CHANNEL, ChannelArgs());
  EXPECT_EQ(client_names, std::vector<std::string>({"foo", "terminator"}));
  EXPECT_EQ(server_names, std::vector<std::string>({"terminator"}));
}
83
TEST(ChannelInitTest, DefaultLexicalOrdering) {
  // Without other constraints, ChannelInit orders non-terminal filters
  // lexically by name, so repeated builds produce the same stack. The terminal
  // filter ("aaa") still ends up last even though it sorts first.
  ChannelInit::Builder builder;
  for (const char* filter_name : {"foo", "bar", "baz"}) {
    builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed(filter_name));
  }
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const ChannelInit init = builder.Build();
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            std::vector<std::string>({"bar", "baz", "foo", "aaa"}));
}
96
// After() overrides lexical ordering: "bar" is placed below "foo".
TEST(ChannelInitTest, AfterConstraintsApply) {
  ChannelInit::Builder builder;
  const grpc_channel_filter* foo = FilterNamed("foo");
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, foo);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"))
      .After({foo->name});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const ChannelInit init = builder.Build();
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            std::vector<std::string>({"baz", "foo", "bar", "aaa"}));
}
108
// Before() overrides lexical ordering: "foo" is placed above "bar".
TEST(ChannelInitTest, BeforeConstraintsApply) {
  ChannelInit::Builder builder;
  const grpc_channel_filter* bar = FilterNamed("bar");
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"))
      .Before({bar->name});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, bar);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const ChannelInit init = builder.Build();
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            std::vector<std::string>({"baz", "foo", "bar", "aaa"}));
}
120
// IfChannelArg(arg, default) keeps a filter only when the boolean channel arg
// is true, falling back to `default` when the arg is unset.
TEST(ChannelInitTest, PredicatesCanFilter) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"))
      .IfChannelArg("foo", true);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"))
      .IfChannelArg("bar", false);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const ChannelInit init = builder.Build();
  const ChannelArgs empty_args = ChannelArgs();
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, empty_args),
            std::vector<std::string>({"foo", "aaa"}));
  EXPECT_EQ(
      GetFilterNames(init, GRPC_CLIENT_CHANNEL, empty_args.Set("foo", false)),
      std::vector<std::string>({"aaa"}));
  EXPECT_EQ(
      GetFilterNames(init, GRPC_CLIENT_CHANNEL, empty_args.Set("bar", true)),
      std::vector<std::string>({"bar", "foo", "aaa"}));
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL,
                           empty_args.Set("bar", true).Set("foo", false)),
            std::vector<std::string>({"bar", "aaa"}));
}
141
// A filter marked Terminal() is placed at the bottom of the stack.
TEST(ChannelInitTest, CanAddTerminalFilter) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar")).Terminal();
  const ChannelInit init = builder.Build();
  const std::vector<std::string> expected = {"foo", "bar"};
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
150
// Several terminal filters may be registered, but exactly one must be active.
// With neither or both enabled, stack creation fails and GetFilterNames()
// returns an empty list.
TEST(ChannelInitTest, CanAddMultipleTerminalFilters) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"))
      .Terminal()
      .IfChannelArg("bar", false);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"))
      .Terminal()
      .IfChannelArg("baz", false);
  const ChannelInit init = builder.Build();
  const std::vector<std::string> no_stack;
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()), no_stack);
  EXPECT_EQ(
      GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs().Set("bar", true)),
      std::vector<std::string>({"foo", "bar"}));
  EXPECT_EQ(
      GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs().Set("baz", true)),
      std::vector<std::string>({"foo", "baz"}));
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL,
                           ChannelArgs().Set("bar", true).Set("baz", true)),
            no_stack);
}
173
// A single BeforeAll() filter goes above every other non-terminal filter,
// regardless of lexical order.
TEST(ChannelInitTest, CanAddBeforeAllOnce) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo")).BeforeAll();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const std::vector<std::string> expected = {"foo", "bar", "baz", "aaa"};
  EXPECT_EQ(
      GetFilterNames(builder.Build(), GRPC_CLIENT_CHANNEL, ChannelArgs()),
      expected);
}
183
// Two filters that both demand BeforeAll() cannot both be satisfied, so
// Build() is expected to crash.
TEST(ChannelInitDeathTest, CanAddBeforeAllTwice) {
  GTEST_FLAG_SET(death_test_style, "threadsafe");
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo")).BeforeAll();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar")).BeforeAll();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  EXPECT_DEATH_IF_SUPPORTED(builder.Build(),
                            "Unresolvable graph of channel filters");
}
193
// A registered post-processor is not run when the ChannelInit is built. It is
// run when a stack is created, and it can append filters after the terminal
// filter.
TEST(ChannelInitTest, CanPostProcessFilters) {
  ChannelInit::Builder b;
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  int called_post_processor = 0;
  b.RegisterPostProcessor(
      GRPC_CLIENT_CHANNEL,
      ChannelInit::PostProcessorSlot::kXdsChannelStackModifier,
      // Named `stack_builder` so it does not shadow the ChannelInit builder
      // `b` above.
      [&called_post_processor](ChannelStackBuilder& stack_builder) {
        ++called_post_processor;
        stack_builder.mutable_stack()->push_back(FilterNamed("bar"));
      });
  auto init = b.Build();
  // Building the ChannelInit must not invoke the post-processor...
  EXPECT_EQ(called_post_processor, 0);
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            std::vector<std::string>({"foo", "aaa", "bar"}));
  // ...but creating a single stack must invoke it exactly once.
  EXPECT_EQ(called_post_processor, 1);
}
211
// FloatToTop() and SinkToBottom() override lexical ordering among the
// non-terminal filters.
TEST(ChannelInitTest, OrderingConstraintsAreSatisfied) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("c")).FloatToTop();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("b"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("a")).SinkToBottom();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const ChannelInit init = builder.Build();
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            std::vector<std::string>({"c", "b", "a", "terminator"}));
}
221
// Two FloatToTop() filters with no explicit ordering between them are
// ambiguous, so Build() is expected to crash.
TEST(ChannelInitTest, AmbiguousTopCrashes) {
  // Use the "threadsafe" death-test style, matching CanAddBeforeAllTwice.
  // This suite is not named *DeathTest, so gtest does not schedule it ahead of
  // tests that may already have started threads (e.g. via the default
  // EventEngine, or under --gtest_shuffle). The default "fast" style forks
  // such a process without re-executing it.
  GTEST_FLAG_SET(death_test_style, "threadsafe");
  ChannelInit::Builder b;
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("c")).FloatToTop();
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("b")).FloatToTop();
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator")).Terminal();
  EXPECT_DEATH_IF_SUPPORTED(b.Build(), "Ambiguous");
}
229
// An explicit After() between two FloatToTop() filters removes the ambiguity.
TEST(ChannelInitTest, ExplicitOrderingBetweenTopResolvesAmbiguity) {
  ChannelInit::Builder builder;
  const grpc_channel_filter* c = FilterNamed("c");
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, c).FloatToTop();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("b"))
      .FloatToTop()
      .After({c->name});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const ChannelInit init = builder.Build();
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            std::vector<std::string>({"c", "b", "terminator"}));
}
240
// Two SinkToBottom() filters with no explicit ordering between them are
// ambiguous, so Build() is expected to crash.
TEST(ChannelInitTest, AmbiguousBottomCrashes) {
  // Use the "threadsafe" death-test style, matching CanAddBeforeAllTwice.
  // This suite is not named *DeathTest, so gtest does not schedule it ahead of
  // tests that may already have started threads (e.g. via the default
  // EventEngine, or under --gtest_shuffle). The default "fast" style forks
  // such a process without re-executing it.
  GTEST_FLAG_SET(death_test_style, "threadsafe");
  ChannelInit::Builder b;
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("c")).SinkToBottom();
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("b")).SinkToBottom();
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator")).Terminal();
  EXPECT_DEATH_IF_SUPPORTED(b.Build(), "Ambiguous");
}
248
// An explicit After() between two SinkToBottom() filters removes the
// ambiguity.
TEST(ChannelInitTest, ExplicitOrderingBetweenBottomResolvesAmbiguity) {
  ChannelInit::Builder builder;
  const grpc_channel_filter* c = FilterNamed("c");
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, c).SinkToBottom();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("b"))
      .SinkToBottom()
      .After({c->name});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const ChannelInit init = builder.Build();
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            std::vector<std::string>({"c", "b", "terminator"}));
}
259
// An explicit Before() constraint can place a SinkToBottom() filter above a
// FloatToTop() filter.
TEST(ChannelInitTest, BottomCanComeBeforeTopWithExplicitOrdering) {
  ChannelInit::Builder builder;
  const grpc_channel_filter* c = FilterNamed("c");
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, c).FloatToTop();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("b"))
      .SinkToBottom()
      .Before({c->name});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const ChannelInit init = builder.Build();
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            std::vector<std::string>({"b", "c", "terminator"}));
}
270
271 class TestFilter1 {
272 public:
TestFilter1(int * p)273 explicit TestFilter1(int* p) : p_(p) {}
274
TypeName()275 static absl::string_view TypeName() { return "TestFilter1"; }
276
Create(const ChannelArgs & args,ChannelFilter::Args)277 static absl::StatusOr<std::unique_ptr<TestFilter1>> Create(
278 const ChannelArgs& args, ChannelFilter::Args) {
279 EXPECT_EQ(args.GetInt("foo"), 1);
280 return std::make_unique<TestFilter1>(args.GetPointer<int>("p"));
281 }
282
283 static const grpc_channel_filter kFilter;
284
285 class Call {
286 public:
Call(TestFilter1 * filter)287 explicit Call(TestFilter1* filter) {
288 EXPECT_EQ(*filter->x_, 0);
289 *filter->x_ = 1;
290 ++*filter->p_;
291 }
292 static const NoInterceptor OnClientInitialMetadata;
293 static const NoInterceptor OnServerInitialMetadata;
294 static const NoInterceptor OnServerTrailingMetadata;
295 static const NoInterceptor OnClientToServerMessage;
296 static const NoInterceptor OnClientToServerHalfClose;
297 static const NoInterceptor OnServerToClientMessage;
298 static const NoInterceptor OnFinalize;
299 };
300
301 private:
302 std::unique_ptr<int> x_ = std::make_unique<int>(0);
303 int* const p_;
304 };
305
// Placeholder grpc_channel_filter for TestFilter1. Only the last field, the
// filter's UniqueTypeName ("test_filter1"), is set; every other entry is
// null/zero. The initializer is positional, so it must follow the field order
// of grpc_channel_filter.
const grpc_channel_filter TestFilter1::kFilter = {
    nullptr, nullptr, 0,       nullptr,
    nullptr, nullptr, 0,       nullptr,
    nullptr, nullptr, nullptr, GRPC_UNIQUE_TYPE_NAME_HERE("test_filter1")};
// Out-of-class definitions for the NoInterceptor markers declared (non-inline)
// inside TestFilter1::Call.
const NoInterceptor TestFilter1::Call::OnClientInitialMetadata;
const NoInterceptor TestFilter1::Call::OnServerInitialMetadata;
const NoInterceptor TestFilter1::Call::OnServerTrailingMetadata;
const NoInterceptor TestFilter1::Call::OnClientToServerMessage;
const NoInterceptor TestFilter1::Call::OnClientToServerHalfClose;
const NoInterceptor TestFilter1::Call::OnServerToClientMessage;
const NoInterceptor TestFilter1::Call::OnFinalize;
317
// Registers TestFilter1 by type and builds an interception chain from the
// resulting ChannelInit. It then starts one call through that chain and checks
// two things: TestFilter1's per-call object was constructed, and the terminal
// handler ran.
TEST(ChannelInitTest, CanCreateFilterWithCall) {
  grpc::testing::TestGrpcScope g;
  ChannelInit::Builder b;
  b.RegisterFilter<TestFilter1>(GRPC_CLIENT_CHANNEL);
  auto init = b.Build();
  // `p` is handed to TestFilter1 through the "p" channel arg. TestFilter1::Call
  // increments it. TestFilter1::Create also expects "foo" == 1.
  int p = 0;
  InterceptionChainBuilder chain_builder{
      ChannelArgs().Set("foo", 1).Set("p", ChannelArgs::UnownedPointer(&p))};
  init.AddToInterceptionChainBuilder(GRPC_CLIENT_CHANNEL, chain_builder);
  // Counts invocations of the call destination's handler.
  int handled = 0;
  auto stack = chain_builder.Build(MakeCallDestinationFromHandlerFunction(
      [&handled](CallHandler) { ++handled; }));
  ASSERT_TRUE(stack.ok()) << stack.status();
  RefCountedPtr<CallArenaAllocator> allocator =
      MakeRefCounted<CallArenaAllocator>(
          ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator(
              "test"),
          1024);
  // The arena only receives a raw pointer (event_engine.get()). `event_engine`
  // is declared before `arena` and `call`, so it is destroyed after them.
  auto event_engine = grpc_event_engine::experimental::GetDefaultEventEngine();
  auto arena = allocator->MakeArena();
  arena->SetContext<grpc_event_engine::experimental::EventEngine>(
      event_engine.get());
  auto call = MakeCallPair(Arena::MakePooledForOverwrite<ClientMetadata>(),
                           std::move(arena));
  (*stack)->StartCall(std::move(call.handler));
  // These checks expect StartCall() to have completed synchronously: one
  // TestFilter1::Call was constructed, and the handler ran once.
  EXPECT_EQ(p, 1);
  EXPECT_EQ(handled, 1);
}
346
347 } // namespace
348 } // namespace grpc_core
349
main(int argc,char ** argv)350 int main(int argc, char** argv) {
351 grpc::testing::TestEnvironment env(&argc, argv);
352 ::testing::InitGoogleTest(&argc, argv);
353 return RUN_ALL_TESTS();
354 }
355