• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "test/core/filters/filter_test.h"
16 
17 #include <algorithm>
18 #include <chrono>
19 #include <memory>
20 #include <queue>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "absl/types/optional.h"
26 #include "gtest/gtest.h"
27 
28 #include <grpc/grpc.h>
29 
30 #include "src/core/lib/channel/call_finalization.h"
31 #include "src/core/lib/channel/context.h"
32 #include "src/core/lib/event_engine/default_event_engine.h"
33 #include "src/core/lib/gprpp/crash.h"
34 #include "src/core/lib/iomgr/timer_manager.h"
35 #include "src/core/lib/promise/activity.h"
36 #include "src/core/lib/promise/arena_promise.h"
37 #include "src/core/lib/promise/context.h"
38 #include "src/core/lib/promise/pipe.h"
39 #include "src/core/lib/promise/poll.h"
40 #include "src/core/lib/promise/seq.h"
41 #include "src/core/lib/resource_quota/arena.h"
42 #include "src/core/lib/slice/slice.h"
43 #include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h"
44 
45 using grpc_event_engine::experimental::FuzzingEventEngine;
46 using grpc_event_engine::experimental::GetDefaultEventEngine;
47 
48 namespace grpc_core {
49 
50 ///////////////////////////////////////////////////////////////////////////////
51 // FilterTestBase::Call::Impl
52 
53 class FilterTestBase::Call::Impl
54     : public std::enable_shared_from_this<FilterTestBase::Call::Impl> {
55  public:
Impl(Call * call,std::shared_ptr<Channel::Impl> channel)56   Impl(Call* call, std::shared_ptr<Channel::Impl> channel)
57       : call_(call), channel_(std::move(channel)) {}
58   ~Impl();
59 
arena()60   Arena* arena() { return arena_.get(); }
legacy_context()61   grpc_call_context_element* legacy_context() { return legacy_context_; }
channel() const62   const std::shared_ptr<Channel::Impl>& channel() const { return channel_; }
call_finalization()63   CallFinalization* call_finalization() { return &call_finalization_; }
64 
65   void Start(ClientMetadataHandle md);
66   void ForwardServerInitialMetadata(ServerMetadataHandle md);
67   void ForwardMessageClientToServer(MessageHandle msg);
68   void ForwardMessageServerToClient(MessageHandle msg);
69   void FinishNextFilter(ServerMetadataHandle md);
70 
71   void StepLoop();
72 
event_engine()73   grpc_event_engine::experimental::EventEngine* event_engine() {
74     return channel_->test->event_engine();
75   }
76 
events()77   Events& events() { return channel_->test->events; }
78 
79  private:
80   bool StepOnce();
81   Poll<ServerMetadataHandle> PollNextFilter();
82   void ForceWakeup();
83 
84   Call* const call_;
85   std::shared_ptr<Channel::Impl> const channel_;
86   ScopedArenaPtr arena_{MakeScopedArena(channel_->initial_arena_size,
87                                         &channel_->memory_allocator)};
88   bool run_call_finalization_ = false;
89   CallFinalization call_finalization_;
90   absl::optional<ArenaPromise<ServerMetadataHandle>> promise_;
91   Poll<ServerMetadataHandle> poll_next_filter_result_;
92   Pipe<ServerMetadataHandle> pipe_server_initial_metadata_{arena_.get()};
93   Pipe<MessageHandle> pipe_server_to_client_messages_{arena_.get()};
94   Pipe<MessageHandle> pipe_client_to_server_messages_{arena_.get()};
95   PipeSender<ServerMetadataHandle>* server_initial_metadata_sender_ = nullptr;
96   PipeSender<MessageHandle>* server_to_client_messages_sender_ = nullptr;
97   PipeReceiver<MessageHandle>* client_to_server_messages_receiver_ = nullptr;
98   absl::optional<PipeSender<ServerMetadataHandle>::PushType>
99       push_server_initial_metadata_;
100   absl::optional<PipeReceiverNextType<ServerMetadataHandle>>
101       next_server_initial_metadata_;
102   absl::optional<PipeSender<MessageHandle>::PushType>
103       push_server_to_client_messages_;
104   absl::optional<PipeReceiverNextType<MessageHandle>>
105       next_server_to_client_messages_;
106   absl::optional<PipeSender<MessageHandle>::PushType>
107       push_client_to_server_messages_;
108   absl::optional<PipeReceiverNextType<MessageHandle>>
109       next_client_to_server_messages_;
110   absl::optional<ServerMetadataHandle> forward_server_initial_metadata_;
111   std::queue<MessageHandle> forward_client_to_server_messages_;
112   std::queue<MessageHandle> forward_server_to_client_messages_;
113   // Contexts for various subsystems (security, tracing, ...).
114   grpc_call_context_element legacy_context_[GRPC_CONTEXT_COUNT] = {};
115 };
116 
~Impl()117 FilterTestBase::Call::Impl::~Impl() {
118   if (!run_call_finalization_) {
119     call_finalization_.Run(nullptr);
120   }
121   for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) {
122     if (legacy_context_[i].destroy != nullptr) {
123       legacy_context_[i].destroy(legacy_context_[i].value);
124     }
125   }
126 }
127 
Start(ClientMetadataHandle md)128 void FilterTestBase::Call::Impl::Start(ClientMetadataHandle md) {
129   EXPECT_EQ(promise_, absl::nullopt);
130   promise_ = channel_->filter->MakeCallPromise(
131       CallArgs{std::move(md), ClientInitialMetadataOutstandingToken::Empty(),
132                nullptr, &pipe_server_initial_metadata_.sender,
133                &pipe_client_to_server_messages_.receiver,
134                &pipe_server_to_client_messages_.sender},
135       [this](CallArgs args) -> ArenaPromise<ServerMetadataHandle> {
136         server_initial_metadata_sender_ = args.server_initial_metadata;
137         client_to_server_messages_receiver_ = args.client_to_server_messages;
138         server_to_client_messages_sender_ = args.server_to_client_messages;
139         next_server_initial_metadata_.emplace(
140             pipe_server_initial_metadata_.receiver.Next());
141         events().Started(call_, *args.client_initial_metadata);
142         return [this]() { return PollNextFilter(); };
143       });
144   EXPECT_NE(promise_, absl::nullopt);
145   ForceWakeup();
146 }
147 
PollNextFilter()148 Poll<ServerMetadataHandle> FilterTestBase::Call::Impl::PollNextFilter() {
149   return std::exchange(poll_next_filter_result_, Pending());
150 }
151 
ForwardServerInitialMetadata(ServerMetadataHandle md)152 void FilterTestBase::Call::Impl::ForwardServerInitialMetadata(
153     ServerMetadataHandle md) {
154   EXPECT_FALSE(forward_server_initial_metadata_.has_value());
155   forward_server_initial_metadata_ = std::move(md);
156   ForceWakeup();
157 }
158 
ForwardMessageClientToServer(MessageHandle msg)159 void FilterTestBase::Call::Impl::ForwardMessageClientToServer(
160     MessageHandle msg) {
161   forward_client_to_server_messages_.push(std::move(msg));
162   ForceWakeup();
163 }
164 
ForwardMessageServerToClient(MessageHandle msg)165 void FilterTestBase::Call::Impl::ForwardMessageServerToClient(
166     MessageHandle msg) {
167   forward_server_to_client_messages_.push(std::move(msg));
168   ForceWakeup();
169 }
170 
FinishNextFilter(ServerMetadataHandle md)171 void FilterTestBase::Call::Impl::FinishNextFilter(ServerMetadataHandle md) {
172   poll_next_filter_result_ = std::move(md);
173   ForceWakeup();
174 }
175 
StepOnce()176 bool FilterTestBase::Call::Impl::StepOnce() {
177   if (!promise_.has_value()) return true;
178 
179   if (forward_server_initial_metadata_.has_value() &&
180       !push_server_initial_metadata_.has_value()) {
181     push_server_initial_metadata_.emplace(server_initial_metadata_sender_->Push(
182         std::move(*forward_server_initial_metadata_)));
183     forward_server_initial_metadata_.reset();
184   }
185 
186   if (push_server_initial_metadata_.has_value()) {
187     auto r = (*push_server_initial_metadata_)();
188     if (r.ready()) push_server_initial_metadata_.reset();
189   }
190 
191   if (next_server_initial_metadata_.has_value()) {
192     auto r = (*next_server_initial_metadata_)();
193     if (auto* p = r.value_if_ready()) {
194       if (p->has_value()) {
195         events().ForwardedServerInitialMetadata(call_, *p->value());
196       }
197       next_server_initial_metadata_.reset();
198     }
199   }
200 
201   if (server_initial_metadata_sender_ != nullptr &&
202       !next_server_initial_metadata_.has_value()) {
203     // We've finished sending server initial metadata, so we can
204     // process server-to-client messages.
205     if (!next_server_to_client_messages_.has_value()) {
206       next_server_to_client_messages_.emplace(
207           pipe_server_to_client_messages_.receiver.Next());
208     }
209 
210     if (push_server_to_client_messages_.has_value()) {
211       auto r = (*push_server_to_client_messages_)();
212       if (r.ready()) push_server_to_client_messages_.reset();
213     }
214 
215     {
216       auto r = (*next_server_to_client_messages_)();
217       if (auto* p = r.value_if_ready()) {
218         if (p->has_value()) {
219           events().ForwardedMessageServerToClient(call_, *p->value());
220         }
221         next_server_to_client_messages_.reset();
222         GetContext<Activity>()->ForceImmediateRepoll();
223       }
224     }
225 
226     if (!push_server_to_client_messages_.has_value() &&
227         !forward_server_to_client_messages_.empty()) {
228       push_server_to_client_messages_.emplace(
229           server_to_client_messages_sender_->Push(
230               std::move(forward_server_to_client_messages_.front())));
231       forward_server_to_client_messages_.pop();
232       GetContext<Activity>()->ForceImmediateRepoll();
233     }
234   }
235 
236   if (client_to_server_messages_receiver_ != nullptr) {
237     if (!next_client_to_server_messages_.has_value()) {
238       next_client_to_server_messages_.emplace(
239           client_to_server_messages_receiver_->Next());
240     }
241 
242     if (push_client_to_server_messages_.has_value()) {
243       auto r = (*push_client_to_server_messages_)();
244       if (r.ready()) push_client_to_server_messages_.reset();
245     }
246 
247     {
248       auto r = (*next_client_to_server_messages_)();
249       if (auto* p = r.value_if_ready()) {
250         if (p->has_value()) {
251           events().ForwardedMessageClientToServer(call_, *p->value());
252         }
253         next_client_to_server_messages_.reset();
254         GetContext<Activity>()->ForceImmediateRepoll();
255       }
256     }
257 
258     if (!push_client_to_server_messages_.has_value() &&
259         !forward_client_to_server_messages_.empty()) {
260       push_client_to_server_messages_.emplace(
261           pipe_client_to_server_messages_.sender.Push(
262               std::move(forward_client_to_server_messages_.front())));
263       forward_client_to_server_messages_.pop();
264       GetContext<Activity>()->ForceImmediateRepoll();
265     }
266   }
267 
268   auto r = (*promise_)();
269   if (r.pending()) return false;
270   promise_.reset();
271   events().Finished(call_, *r.value());
272   return true;
273 }
274 
275 ///////////////////////////////////////////////////////////////////////////////
276 // FilterTestBase::Call::ScopedContext
277 
278 class FilterTestBase::Call::ScopedContext final
279     : public Activity,
280       public promise_detail::Context<Arena>,
281       public promise_detail::Context<grpc_call_context_element>,
282       public promise_detail::Context<CallFinalization> {
283  private:
284   class TestWakeable final : public Wakeable {
285    public:
TestWakeable(ScopedContext * ctx)286     explicit TestWakeable(ScopedContext* ctx)
287         : tag_(ctx->DebugTag()), impl_(ctx->impl_) {}
Wakeup(WakeupMask)288     void Wakeup(WakeupMask) override {
289       std::unique_ptr<TestWakeable> self(this);
290       auto impl = impl_.lock();
291       if (impl == nullptr) return;
292       impl->event_engine()->Run([weak_impl = impl_]() {
293         auto impl = weak_impl.lock();
294         if (impl != nullptr) impl->StepLoop();
295       });
296     }
WakeupAsync(WakeupMask)297     void WakeupAsync(WakeupMask) override { Wakeup(0); }
Drop(WakeupMask)298     void Drop(WakeupMask) override { delete this; }
ActivityDebugTag(WakeupMask) const299     std::string ActivityDebugTag(WakeupMask) const override { return tag_; }
300 
301    private:
302     const std::string tag_;
303     const std::weak_ptr<Impl> impl_;
304   };
305 
306  public:
ScopedContext(std::shared_ptr<Impl> impl)307   explicit ScopedContext(std::shared_ptr<Impl> impl)
308       : promise_detail::Context<Arena>(impl->arena()),
309         promise_detail::Context<grpc_call_context_element>(
310             impl->legacy_context()),
311         promise_detail::Context<CallFinalization>(impl->call_finalization()),
312         impl_(std::move(impl)) {}
313 
Orphan()314   void Orphan() override { Crash("Orphan called on Call::ScopedContext"); }
ForceImmediateRepoll(WakeupMask)315   void ForceImmediateRepoll(WakeupMask) override { repoll_ = true; }
MakeOwningWaker()316   Waker MakeOwningWaker() override { return Waker(new TestWakeable(this), 0); }
MakeNonOwningWaker()317   Waker MakeNonOwningWaker() override {
318     return Waker(new TestWakeable(this), 0);
319   }
DebugTag() const320   std::string DebugTag() const override {
321     return absl::StrFormat("FILTER_TEST_CALL[%p]", impl_.get());
322   }
323 
repoll() const324   bool repoll() const { return repoll_; }
325 
326  private:
327   ScopedActivity scoped_activity_{this};
328   const std::shared_ptr<Impl> impl_;
329   bool repoll_ = false;
330 };
331 
StepLoop()332 void FilterTestBase::Call::Impl::StepLoop() {
333   for (;;) {
334     ScopedContext ctx(shared_from_this());
335     if (!StepOnce() && ctx.repoll()) continue;
336     return;
337   }
338 }
339 
ForceWakeup()340 void FilterTestBase::Call::Impl::ForceWakeup() {
341   ScopedContext(shared_from_this()).MakeOwningWaker().Wakeup();
342 }
343 
344 ///////////////////////////////////////////////////////////////////////////////
345 // FilterTestBase::Call
346 
Call(const Channel & channel)347 FilterTestBase::Call::Call(const Channel& channel)
348     : impl_(std::make_unique<Impl>(this, channel.impl_)) {}
349 
~Call()350 FilterTestBase::Call::~Call() { ScopedContext x(std::move(impl_)); }
351 
arena()352 Arena* FilterTestBase::Call::arena() { return impl_->arena(); }
353 
NewClientMetadata(std::initializer_list<std::pair<absl::string_view,absl::string_view>> init)354 ClientMetadataHandle FilterTestBase::Call::NewClientMetadata(
355     std::initializer_list<std::pair<absl::string_view, absl::string_view>>
356         init) {
357   auto md = impl_->arena()->MakePooled<ClientMetadata>();
358   for (auto& p : init) {
359     auto parsed = ClientMetadata::Parse(
360         p.first, Slice::FromCopiedString(p.second), false,
361         p.first.length() + p.second.length() + 32,
362         [p](absl::string_view, const Slice&) {
363           Crash(absl::StrCat("Illegal metadata value: ", p.first, ": ",
364                              p.second));
365         });
366     md->Set(parsed);
367   }
368   return md;
369 }
370 
NewServerMetadata(std::initializer_list<std::pair<absl::string_view,absl::string_view>> init)371 ServerMetadataHandle FilterTestBase::Call::NewServerMetadata(
372     std::initializer_list<std::pair<absl::string_view, absl::string_view>>
373         init) {
374   auto md = impl_->arena()->MakePooled<ClientMetadata>();
375   for (auto& p : init) {
376     auto parsed = ServerMetadata::Parse(
377         p.first, Slice::FromCopiedString(p.second), false,
378         p.first.length() + p.second.length() + 32,
379         [p](absl::string_view, const Slice&) {
380           Crash(absl::StrCat("Illegal metadata value: ", p.first, ": ",
381                              p.second));
382         });
383     md->Set(parsed);
384   }
385   return md;
386 }
387 
NewMessage(absl::string_view payload,uint32_t flags)388 MessageHandle FilterTestBase::Call::NewMessage(absl::string_view payload,
389                                                uint32_t flags) {
390   SliceBuffer buffer;
391   if (!payload.empty()) buffer.Append(Slice::FromCopiedString(payload));
392   return impl_->arena()->MakePooled<Message>(std::move(buffer), flags);
393 }
394 
Start(ClientMetadataHandle md)395 void FilterTestBase::Call::Start(ClientMetadataHandle md) {
396   ScopedContext ctx(impl_);
397   impl_->Start(std::move(md));
398 }
399 
Cancel()400 void FilterTestBase::Call::Cancel() {
401   ScopedContext ctx(impl_);
402   impl_ = absl::make_unique<Impl>(this, impl_->channel());
403 }
404 
ForwardServerInitialMetadata(ServerMetadataHandle md)405 void FilterTestBase::Call::ForwardServerInitialMetadata(
406     ServerMetadataHandle md) {
407   impl_->ForwardServerInitialMetadata(std::move(md));
408 }
409 
ForwardMessageClientToServer(MessageHandle msg)410 void FilterTestBase::Call::ForwardMessageClientToServer(MessageHandle msg) {
411   impl_->ForwardMessageClientToServer(std::move(msg));
412 }
413 
ForwardMessageServerToClient(MessageHandle msg)414 void FilterTestBase::Call::ForwardMessageServerToClient(MessageHandle msg) {
415   impl_->ForwardMessageServerToClient(std::move(msg));
416 }
417 
FinishNextFilter(ServerMetadataHandle md)418 void FilterTestBase::Call::FinishNextFilter(ServerMetadataHandle md) {
419   impl_->FinishNextFilter(std::move(md));
420 }
421 
422 ///////////////////////////////////////////////////////////////////////////////
423 // FilterTestBase
424 
FilterTestBase()425 FilterTestBase::FilterTestBase() {
426   grpc_event_engine::experimental::SetEventEngineFactory([]() {
427     FuzzingEventEngine::Options options;
428     options.max_delay_run_after = std::chrono::milliseconds(500);
429     options.max_delay_write = std::chrono::milliseconds(50);
430     return std::make_unique<FuzzingEventEngine>(
431         options, fuzzing_event_engine::Actions());
432   });
433   event_engine_ =
434       std::dynamic_pointer_cast<FuzzingEventEngine>(GetDefaultEventEngine());
435   grpc_timer_manager_set_start_threaded(false);
436   grpc_init();
437 }
438 
~FilterTestBase()439 FilterTestBase::~FilterTestBase() {
440   grpc_shutdown();
441   event_engine_->UnsetGlobalHooks();
442 }
443 
Step()444 void FilterTestBase::Step() {
445   event_engine_->TickUntilIdle();
446   ::testing::Mock::VerifyAndClearExpectations(&events);
447 }
448 
449 }  // namespace grpc_core
450