• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/core/lib/surface/channel_init.h"
16 
17 #include <map>
18 #include <memory>
19 #include <string>
20 
21 #include "absl/strings/string_view.h"
22 #include "gtest/gtest.h"
23 #include "src/core/lib/channel/channel_stack.h"
24 #include "src/core/lib/channel/channel_stack_builder_impl.h"
25 #include "src/core/lib/channel/promise_based_filter.h"
26 #include "src/core/lib/resource_quota/resource_quota.h"
27 #include "src/core/lib/surface/channel_stack_type.h"
28 #include "src/core/lib/transport/call_arena_allocator.h"
29 #include "test/core/test_util/test_config.h"
30 
31 namespace grpc_core {
32 namespace {
33 
FilterNamed(const char * name)34 const grpc_channel_filter* FilterNamed(const char* name) {
35   static auto* filters =
36       new std::map<absl::string_view, const grpc_channel_filter*>;
37   auto it = filters->find(name);
38   if (it != filters->end()) return it->second;
39   static auto* name_factories =
40       new std::vector<std::unique_ptr<UniqueTypeName::Factory>>();
41   name_factories->emplace_back(std::make_unique<UniqueTypeName::Factory>(name));
42   auto unique_type_name = name_factories->back()->Create();
43   return filters
44       ->emplace(name,
45                 new grpc_channel_filter{nullptr, nullptr, 0, nullptr, nullptr,
46                                         nullptr, 0, nullptr, nullptr, nullptr,
47                                         nullptr, unique_type_name})
48       .first->second;
49 }
50 
GetFilterNames(const ChannelInit & init,grpc_channel_stack_type type,const ChannelArgs & args)51 std::vector<std::string> GetFilterNames(const ChannelInit& init,
52                                         grpc_channel_stack_type type,
53                                         const ChannelArgs& args) {
54   ChannelStackBuilderImpl b("test", type, args);
55   if (!init.CreateStack(&b)) return {};
56   std::vector<std::string> names;
57   for (auto f : b.stack()) {
58     names.push_back(std::string(f->name.name()));
59   }
60   EXPECT_NE(names, std::vector<std::string>());
61   return names;
62 }
63 
// A client stack consisting solely of a terminal filter builds and contains
// exactly that filter.
TEST(ChannelInitTest, Empty) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"terminator"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
71 
// A non-terminal filter registered only for the client stack appears there
// and does not leak into the server stack.
TEST(ChannelInitTest, OneClientFilter) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  builder.RegisterFilter(GRPC_SERVER_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const auto channel_init = builder.Build();

  const std::vector<std::string> client_expected = {"foo", "terminator"};
  const std::vector<std::string> server_expected = {"terminator"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            client_expected);
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_SERVER_CHANNEL, ChannelArgs()),
            server_expected);
}
83 
// Without other constraints ChannelInit orders non-terminal filters by name,
// so the result is stable from build to build regardless of registration
// order. The terminal filter still comes last.
TEST(ChannelInitTest, DefaultLexicalOrdering) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"bar", "baz", "foo", "aaa"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
96 
// An After() constraint overrides lexical ordering: "bar" would sort first,
// but it must follow "foo".
TEST(ChannelInitTest, AfterConstraintsApply) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"))
      .After({FilterNamed("foo")->name});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"baz", "foo", "bar", "aaa"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
108 
// A Before() constraint declared on "foo" forces it ahead of "bar". The
// resulting order matches the equivalent After() constraint.
TEST(ChannelInitTest, BeforeConstraintsApply) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"))
      .Before({FilterNamed("bar")->name});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"baz", "foo", "bar", "aaa"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
120 
// IfChannelArg(name, default) includes a filter only when the boolean channel
// arg `name` (or `default` when unset) is true. "foo" is therefore on unless
// disabled, and "bar" is off unless enabled.
TEST(ChannelInitTest, PredicatesCanFilter) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"))
      .IfChannelArg("foo", true);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"))
      .IfChannelArg("bar", false);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const auto channel_init = builder.Build();

  const ChannelArgs foo_disabled = ChannelArgs().Set("foo", false);
  const ChannelArgs bar_enabled = ChannelArgs().Set("bar", true);
  const ChannelArgs bar_enabled_foo_disabled =
      ChannelArgs().Set("bar", true).Set("foo", false);

  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            std::vector<std::string>({"foo", "aaa"}));
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, foo_disabled),
            std::vector<std::string>({"aaa"}));
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, bar_enabled),
            std::vector<std::string>({"bar", "foo", "aaa"}));
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL,
                           bar_enabled_foo_disabled),
            std::vector<std::string>({"bar", "aaa"}));
}
141 
// The terminal filter is placed after every non-terminal filter, even though
// "bar" would sort before "foo" lexically.
TEST(ChannelInitTest, CanAddTerminalFilter) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar")).Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"foo", "bar"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
150 
// Several conditional terminal filters may be registered, but a stack only
// builds when exactly one of them is enabled. With none or both enabled,
// GetFilterNames yields an empty list.
TEST(ChannelInitTest, CanAddMultipleTerminalFilters) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"))
      .Terminal()
      .IfChannelArg("bar", false);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"))
      .Terminal()
      .IfChannelArg("baz", false);
  const auto channel_init = builder.Build();

  const std::vector<std::string> no_stack{};
  const ChannelArgs only_bar = ChannelArgs().Set("bar", true);
  const ChannelArgs only_baz = ChannelArgs().Set("baz", true);
  const ChannelArgs bar_and_baz = ChannelArgs().Set("bar", true).Set("baz", true);

  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            no_stack);
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, only_bar),
            std::vector<std::string>({"foo", "bar"}));
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, only_baz),
            std::vector<std::string>({"foo", "baz"}));
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, bar_and_baz),
            no_stack);
}
173 
// A single BeforeAll() filter is hoisted to the front of the stack ahead of
// lexically smaller names.
TEST(ChannelInitTest, CanAddBeforeAllOnce) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo")).BeforeAll();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"foo", "bar", "baz", "aaa"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
183 
// Two filters that both demand BeforeAll() cannot both be first, so Build()
// is expected to abort with "Unresolvable graph of channel filters".
// The "threadsafe" death-test style re-executes the binary rather than
// forking a process that may already be running threads.
TEST(ChannelInitDeathTest, CanAddBeforeAllTwice) {
  GTEST_FLAG_SET(death_test_style, "threadsafe");
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo")).BeforeAll();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar")).BeforeAll();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  EXPECT_DEATH_IF_SUPPORTED(builder.Build(),
                            "Unresolvable graph of channel filters");
}
193 
// A post-processor registered for GRPC_CLIENT_CHANNEL may rewrite the built
// stack; here it appends "bar" after the terminal filter. Building the
// ChannelInit must not invoke it. Creating one client stack should invoke it
// exactly once, which is consistent with a single trailing "bar".
TEST(ChannelInitTest, CanPostProcessFilters) {
  ChannelInit::Builder b;
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  int called_post_processor = 0;
  b.RegisterPostProcessor(
      GRPC_CLIENT_CHANNEL,
      ChannelInit::PostProcessorSlot::kXdsChannelStackModifier,
      // Named `stack_builder` (not `b`) so it does not shadow the outer
      // ChannelInit::Builder.
      [&called_post_processor](ChannelStackBuilder& stack_builder) {
        ++called_post_processor;
        stack_builder.mutable_stack()->push_back(FilterNamed("bar"));
      });
  auto init = b.Build();
  // Registration and Build() alone must not run the post-processor.
  EXPECT_EQ(called_post_processor, 0);
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            std::vector<std::string>({"foo", "aaa", "bar"}));
  // Creating a single client stack ran the post-processor exactly once.
  EXPECT_EQ(called_post_processor, 1);
}
211 
// FloatToTop() and SinkToBottom() override lexical ordering: "c" rises to the
// front, "a" sinks just above the terminal filter, and "b" sits between them.
TEST(ChannelInitTest, OrderingConstraintsAreSatisfied) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("c")).FloatToTop();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("b"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("a")).SinkToBottom();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"c", "b", "a", "terminator"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
221 
// Two FloatToTop() filters with no ordering between them are ambiguous, and
// Build() is expected to abort with a message containing "Ambiguous".
// This is a death test. The "threadsafe" style re-executes the test binary
// instead of forking, which stays safe if gRPC threads already exist (e.g. when
// tests are shuffled). gtest restores flags after each test, so the flag must
// be set here as well.
TEST(ChannelInitTest, AmbiguousTopCrashes) {
  GTEST_FLAG_SET(death_test_style, "threadsafe");
  ChannelInit::Builder b;
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("c")).FloatToTop();
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("b")).FloatToTop();
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator")).Terminal();
  EXPECT_DEATH_IF_SUPPORTED(b.Build(), "Ambiguous");
}
229 
// Two FloatToTop() filters stop being ambiguous once an explicit After()
// constraint orders them: "b" follows "c", and both precede the terminator.
TEST(ChannelInitTest, ExplicitOrderingBetweenTopResolvesAmbiguity) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("c")).FloatToTop();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("b"))
      .FloatToTop()
      .After({FilterNamed("c")->name});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"c", "b", "terminator"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
240 
// Two SinkToBottom() filters with no ordering between them are ambiguous, and
// Build() is expected to abort with a message containing "Ambiguous".
// This is a death test. The "threadsafe" style re-executes the test binary
// instead of forking, which stays safe if gRPC threads already exist (e.g. when
// tests are shuffled). gtest restores flags after each test, so the flag must
// be set here as well.
TEST(ChannelInitTest, AmbiguousBottomCrashes) {
  GTEST_FLAG_SET(death_test_style, "threadsafe");
  ChannelInit::Builder b;
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("c")).SinkToBottom();
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("b")).SinkToBottom();
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator")).Terminal();
  EXPECT_DEATH_IF_SUPPORTED(b.Build(), "Ambiguous");
}
248 
// Two SinkToBottom() filters stop being ambiguous once an explicit After()
// constraint orders them: "b" follows "c", and both precede the terminator.
TEST(ChannelInitTest, ExplicitOrderingBetweenBottomResolvesAmbiguity) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("c")).SinkToBottom();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("b"))
      .SinkToBottom()
      .After({FilterNamed("c")->name});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"c", "b", "terminator"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
259 
// An explicit Before() constraint wins over the float/sink hints. The
// SinkToBottom() filter "b" ends up ahead of the FloatToTop() filter "c".
TEST(ChannelInitTest, BottomCanComeBeforeTopWithExplicitOrdering) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("c")).FloatToTop();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("b"))
      .SinkToBottom()
      .Before({FilterNamed("c")->name});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"b", "c", "terminator"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
270 
271 class TestFilter1 {
272  public:
TestFilter1(int * p)273   explicit TestFilter1(int* p) : p_(p) {}
274 
TypeName()275   static absl::string_view TypeName() { return "TestFilter1"; }
276 
Create(const ChannelArgs & args,ChannelFilter::Args)277   static absl::StatusOr<std::unique_ptr<TestFilter1>> Create(
278       const ChannelArgs& args, ChannelFilter::Args) {
279     EXPECT_EQ(args.GetInt("foo"), 1);
280     return std::make_unique<TestFilter1>(args.GetPointer<int>("p"));
281   }
282 
283   static const grpc_channel_filter kFilter;
284 
285   class Call {
286    public:
Call(TestFilter1 * filter)287     explicit Call(TestFilter1* filter) {
288       EXPECT_EQ(*filter->x_, 0);
289       *filter->x_ = 1;
290       ++*filter->p_;
291     }
292     static const NoInterceptor OnClientInitialMetadata;
293     static const NoInterceptor OnServerInitialMetadata;
294     static const NoInterceptor OnServerTrailingMetadata;
295     static const NoInterceptor OnClientToServerMessage;
296     static const NoInterceptor OnClientToServerHalfClose;
297     static const NoInterceptor OnServerToClientMessage;
298     static const NoInterceptor OnFinalize;
299   };
300 
301  private:
302   std::unique_ptr<int> x_ = std::make_unique<int>(0);
303   int* const p_;
304 };
305 
306 const grpc_channel_filter TestFilter1::kFilter = {
307     nullptr, nullptr, 0,       nullptr,
308     nullptr, nullptr, 0,       nullptr,
309     nullptr, nullptr, nullptr, GRPC_UNIQUE_TYPE_NAME_HERE("test_filter1")};
310 const NoInterceptor TestFilter1::Call::OnClientInitialMetadata;
311 const NoInterceptor TestFilter1::Call::OnServerInitialMetadata;
312 const NoInterceptor TestFilter1::Call::OnServerTrailingMetadata;
313 const NoInterceptor TestFilter1::Call::OnClientToServerMessage;
314 const NoInterceptor TestFilter1::Call::OnClientToServerHalfClose;
315 const NoInterceptor TestFilter1::Call::OnServerToClientMessage;
316 const NoInterceptor TestFilter1::Call::OnFinalize;
317 
// End-to-end check of a filter registered by type. ChannelInit adds
// TestFilter1 to an interception chain built from channel args. Starting a
// call constructs TestFilter1::Call once, which bumps `call_count`, and the
// call then reaches the terminal handler once.
TEST(ChannelInitTest, CanCreateFilterWithCall) {
  grpc::testing::TestGrpcScope grpc_scope;
  ChannelInit::Builder builder;
  builder.RegisterFilter<TestFilter1>(GRPC_CLIENT_CHANNEL);
  auto channel_init = builder.Build();

  // TestFilter1::Create expects "foo" == 1 and reads its counter from "p".
  int call_count = 0;
  InterceptionChainBuilder chain_builder{ChannelArgs().Set("foo", 1).Set(
      "p", ChannelArgs::UnownedPointer(&call_count))};
  channel_init.AddToInterceptionChainBuilder(GRPC_CLIENT_CHANNEL,
                                             chain_builder);

  int handler_invocations = 0;
  auto destination = chain_builder.Build(MakeCallDestinationFromHandlerFunction(
      [&handler_invocations](CallHandler) { ++handler_invocations; }));
  ASSERT_TRUE(destination.ok()) << destination.status();

  auto arena_allocator = MakeRefCounted<CallArenaAllocator>(
      ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator("test"),
      1024);
  // The arena stores only a raw EventEngine pointer, so the owning handle is
  // declared first so that it is destroyed after the arena and the call.
  auto event_engine = grpc_event_engine::experimental::GetDefaultEventEngine();
  auto arena = arena_allocator->MakeArena();
  arena->SetContext<grpc_event_engine::experimental::EventEngine>(
      event_engine.get());

  auto call_pair = MakeCallPair(
      Arena::MakePooledForOverwrite<ClientMetadata>(), std::move(arena));
  (*destination)->StartCall(std::move(call_pair.handler));
  EXPECT_EQ(call_count, 1);
  EXPECT_EQ(handler_invocations, 1);
}
346 
347 }  // namespace
348 }  // namespace grpc_core
349 
main(int argc,char ** argv)350 int main(int argc, char** argv) {
351   grpc::testing::TestEnvironment env(&argc, argv);
352   ::testing::InitGoogleTest(&argc, argv);
353   return RUN_ALL_TESTS();
354 }
355