1 // Copyright 2023 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "test/core/filters/filter_test.h"
16
17 #include <algorithm>
18 #include <chrono>
19 #include <memory>
20 #include <queue>
21
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "absl/types/optional.h"
26 #include "gtest/gtest.h"
27
28 #include <grpc/grpc.h>
29
30 #include "src/core/lib/channel/call_finalization.h"
31 #include "src/core/lib/channel/context.h"
32 #include "src/core/lib/event_engine/default_event_engine.h"
33 #include "src/core/lib/gprpp/crash.h"
34 #include "src/core/lib/iomgr/timer_manager.h"
35 #include "src/core/lib/promise/activity.h"
36 #include "src/core/lib/promise/arena_promise.h"
37 #include "src/core/lib/promise/context.h"
38 #include "src/core/lib/promise/pipe.h"
39 #include "src/core/lib/promise/poll.h"
40 #include "src/core/lib/promise/seq.h"
41 #include "src/core/lib/resource_quota/arena.h"
42 #include "src/core/lib/slice/slice.h"
43 #include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h"
44
45 using grpc_event_engine::experimental::FuzzingEventEngine;
46 using grpc_event_engine::experimental::GetDefaultEventEngine;
47
48 namespace grpc_core {
49
50 ///////////////////////////////////////////////////////////////////////////////
51 // FilterTestBase::Call::Impl
52
// Per-call state behind FilterTestBase::Call.
//
// Owns the call arena, the three pipes that stand in for the rest of the call
// stack (server initial metadata, client->server messages, server->client
// messages), and the promise returned by the filter under test. The test
// queues work through the Forward*/FinishNextFilter methods; StepLoop()
// then drives that work and the filter promise, and progress is reported via
// events().
//
// Derives from enable_shared_from_this so activity contexts and wakers can be
// created from a shared_ptr/weak_ptr to this object.
class FilterTestBase::Call::Impl
    : public std::enable_shared_from_this<FilterTestBase::Call::Impl> {
 public:
  // `call` is the owning test Call (reported back through events());
  // `channel` is the test channel this call runs on.
  Impl(Call* call, std::shared_ptr<Channel::Impl> channel)
      : call_(call), channel_(std::move(channel)) {}
  ~Impl();

  Arena* arena() { return arena_.get(); }
  grpc_call_context_element* legacy_context() { return legacy_context_; }
  const std::shared_ptr<Channel::Impl>& channel() const { return channel_; }
  CallFinalization* call_finalization() { return &call_finalization_; }

  // Creates the filter's call promise from the given client initial metadata.
  void Start(ClientMetadataHandle md);
  // Queue work for the next StepOnce() pass and schedule a step.
  void ForwardServerInitialMetadata(ServerMetadataHandle md);
  void ForwardMessageClientToServer(MessageHandle msg);
  void ForwardMessageServerToClient(MessageHandle msg);
  // Supplies the result returned by the "next filter" promise.
  void FinishNextFilter(ServerMetadataHandle md);

  // Steps the call until it completes or stops requesting immediate repolls.
  void StepLoop();

  grpc_event_engine::experimental::EventEngine* event_engine() {
    return channel_->test->event_engine();
  }

  Events& events() { return channel_->test->events; }

 private:
  bool StepOnce();
  Poll<ServerMetadataHandle> PollNextFilter();
  void ForceWakeup();

  // NOTE: declaration order matters below: arena_ is initialized from
  // channel_, and the pipes are initialized from arena_, so channel_ must be
  // declared before arena_, and arena_ before the pipes.
  Call* const call_;
  std::shared_ptr<Channel::Impl> const channel_;
  ScopedArenaPtr arena_{MakeScopedArena(channel_->initial_arena_size,
                                        &channel_->memory_allocator)};
  // Checked by ~Impl before running call_finalization_.
  // NOTE(review): nothing in this file sets it to true — confirm intent.
  bool run_call_finalization_ = false;
  CallFinalization call_finalization_;
  // The filter's call promise; set by Start(), cleared once it completes.
  absl::optional<ArenaPromise<ServerMetadataHandle>> promise_;
  // Result handed out (once) by the next-filter promise; Pending until
  // FinishNextFilter() is called.
  Poll<ServerMetadataHandle> poll_next_filter_result_;
  Pipe<ServerMetadataHandle> pipe_server_initial_metadata_{arena_.get()};
  Pipe<MessageHandle> pipe_server_to_client_messages_{arena_.get()};
  Pipe<MessageHandle> pipe_client_to_server_messages_{arena_.get()};
  // Pipe endpoints the filter passed to the next-filter factory; null until
  // the filter invokes that factory.
  PipeSender<ServerMetadataHandle>* server_initial_metadata_sender_ = nullptr;
  PipeSender<MessageHandle>* server_to_client_messages_sender_ = nullptr;
  PipeReceiver<MessageHandle>* client_to_server_messages_receiver_ = nullptr;
  // In-flight pipe push/next promises, polled by StepOnce().
  absl::optional<PipeSender<ServerMetadataHandle>::PushType>
      push_server_initial_metadata_;
  absl::optional<PipeReceiverNextType<ServerMetadataHandle>>
      next_server_initial_metadata_;
  absl::optional<PipeSender<MessageHandle>::PushType>
      push_server_to_client_messages_;
  absl::optional<PipeReceiverNextType<MessageHandle>>
      next_server_to_client_messages_;
  absl::optional<PipeSender<MessageHandle>::PushType>
      push_client_to_server_messages_;
  absl::optional<PipeReceiverNextType<MessageHandle>>
      next_client_to_server_messages_;
  // Values queued by the test that StepOnce() has not started pushing yet.
  absl::optional<ServerMetadataHandle> forward_server_initial_metadata_;
  std::queue<MessageHandle> forward_client_to_server_messages_;
  std::queue<MessageHandle> forward_server_to_client_messages_;
  // Contexts for various subsystems (security, tracing, ...).
  // Elements with a destroy hook are released by ~Impl.
  grpc_call_context_element legacy_context_[GRPC_CONTEXT_COUNT] = {};
};
116
~Impl()117 FilterTestBase::Call::Impl::~Impl() {
118 if (!run_call_finalization_) {
119 call_finalization_.Run(nullptr);
120 }
121 for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) {
122 if (legacy_context_[i].destroy != nullptr) {
123 legacy_context_[i].destroy(legacy_context_[i].value);
124 }
125 }
126 }
127
Start(ClientMetadataHandle md)128 void FilterTestBase::Call::Impl::Start(ClientMetadataHandle md) {
129 EXPECT_EQ(promise_, absl::nullopt);
130 promise_ = channel_->filter->MakeCallPromise(
131 CallArgs{std::move(md), ClientInitialMetadataOutstandingToken::Empty(),
132 nullptr, &pipe_server_initial_metadata_.sender,
133 &pipe_client_to_server_messages_.receiver,
134 &pipe_server_to_client_messages_.sender},
135 [this](CallArgs args) -> ArenaPromise<ServerMetadataHandle> {
136 server_initial_metadata_sender_ = args.server_initial_metadata;
137 client_to_server_messages_receiver_ = args.client_to_server_messages;
138 server_to_client_messages_sender_ = args.server_to_client_messages;
139 next_server_initial_metadata_.emplace(
140 pipe_server_initial_metadata_.receiver.Next());
141 events().Started(call_, *args.client_initial_metadata);
142 return [this]() { return PollNextFilter(); };
143 });
144 EXPECT_NE(promise_, absl::nullopt);
145 ForceWakeup();
146 }
147
PollNextFilter()148 Poll<ServerMetadataHandle> FilterTestBase::Call::Impl::PollNextFilter() {
149 return std::exchange(poll_next_filter_result_, Pending());
150 }
151
ForwardServerInitialMetadata(ServerMetadataHandle md)152 void FilterTestBase::Call::Impl::ForwardServerInitialMetadata(
153 ServerMetadataHandle md) {
154 EXPECT_FALSE(forward_server_initial_metadata_.has_value());
155 forward_server_initial_metadata_ = std::move(md);
156 ForceWakeup();
157 }
158
ForwardMessageClientToServer(MessageHandle msg)159 void FilterTestBase::Call::Impl::ForwardMessageClientToServer(
160 MessageHandle msg) {
161 forward_client_to_server_messages_.push(std::move(msg));
162 ForceWakeup();
163 }
164
ForwardMessageServerToClient(MessageHandle msg)165 void FilterTestBase::Call::Impl::ForwardMessageServerToClient(
166 MessageHandle msg) {
167 forward_server_to_client_messages_.push(std::move(msg));
168 ForceWakeup();
169 }
170
FinishNextFilter(ServerMetadataHandle md)171 void FilterTestBase::Call::Impl::FinishNextFilter(ServerMetadataHandle md) {
172 poll_next_filter_result_ = std::move(md);
173 ForceWakeup();
174 }
175
// Performs one pass over all in-flight work for the call, in this order:
//   1. Server initial metadata: start/poll the push of queued metadata into
//      the filter's sender, and poll the read from our end of the pipe.
//   2. Server->client messages (only after server initial metadata has been
//      read): poll the pending push, poll the read from our receiver, then
//      start pushing the next queued message.
//   3. Client->server messages (once the filter has called the next-filter
//      factory): same shape as 2, reading from the filter's receiver and
//      pushing into our sender.
//   4. Poll the filter's call promise.
// Completed reads are reported through events().
//
// Returns true if there is no promise, or if the promise just completed (then
// events().Finished() is reported and the promise is discarded). Returns false
// while the promise is still pending.
// Requires an Activity in context (it calls GetContext<Activity>()); StepLoop()
// provides one.
bool FilterTestBase::Call::Impl::StepOnce() {
  if (!promise_.has_value()) return true;

  // Start pushing queued server initial metadata once no push is in flight.
  // NOTE(review): server_initial_metadata_sender_ is dereferenced without a
  // null check; this assumes ForwardServerInitialMetadata() is only used after
  // the filter has invoked the next-filter factory. Confirm.
  if (forward_server_initial_metadata_.has_value() &&
      !push_server_initial_metadata_.has_value()) {
    push_server_initial_metadata_.emplace(server_initial_metadata_sender_->Push(
        std::move(*forward_server_initial_metadata_)));
    forward_server_initial_metadata_.reset();
  }

  if (push_server_initial_metadata_.has_value()) {
    auto r = (*push_server_initial_metadata_)();
    if (r.ready()) push_server_initial_metadata_.reset();
  }

  // Read server initial metadata from our end of the pipe. The read is
  // started by the next-filter factory in Start().
  if (next_server_initial_metadata_.has_value()) {
    auto r = (*next_server_initial_metadata_)();
    if (auto* p = r.value_if_ready()) {
      if (p->has_value()) {
        events().ForwardedServerInitialMetadata(call_, *p->value());
      }
      next_server_initial_metadata_.reset();
    }
  }

  if (server_initial_metadata_sender_ != nullptr &&
      !next_server_initial_metadata_.has_value()) {
    // The read of server initial metadata has completed, so we can
    // process server-to-client messages.
    if (!next_server_to_client_messages_.has_value()) {
      next_server_to_client_messages_.emplace(
          pipe_server_to_client_messages_.receiver.Next());
    }

    if (push_server_to_client_messages_.has_value()) {
      auto r = (*push_server_to_client_messages_)();
      if (r.ready()) push_server_to_client_messages_.reset();
    }

    {
      auto r = (*next_server_to_client_messages_)();
      if (auto* p = r.value_if_ready()) {
        if (p->has_value()) {
          events().ForwardedMessageServerToClient(call_, *p->value());
        }
        // Clear the slot so the next pass starts a fresh Next(), and ask
        // StepLoop() to run that pass immediately.
        // NOTE(review): if the pipe is closed, Next() presumably completes
        // at once with no value. That would repoll continuously while the call
        // promise stays pending. Verify.
        next_server_to_client_messages_.reset();
        GetContext<Activity>()->ForceImmediateRepoll();
      }
    }

    // Start pushing the next queued message (one push in flight at a time).
    if (!push_server_to_client_messages_.has_value() &&
        !forward_server_to_client_messages_.empty()) {
      push_server_to_client_messages_.emplace(
          server_to_client_messages_sender_->Push(
              std::move(forward_server_to_client_messages_.front())));
      forward_server_to_client_messages_.pop();
      GetContext<Activity>()->ForceImmediateRepoll();
    }
  }

  // Client->server direction: we push into our sender and read from whatever
  // receiver the filter passed to the next filter.
  if (client_to_server_messages_receiver_ != nullptr) {
    if (!next_client_to_server_messages_.has_value()) {
      next_client_to_server_messages_.emplace(
          client_to_server_messages_receiver_->Next());
    }

    if (push_client_to_server_messages_.has_value()) {
      auto r = (*push_client_to_server_messages_)();
      if (r.ready()) push_client_to_server_messages_.reset();
    }

    {
      auto r = (*next_client_to_server_messages_)();
      if (auto* p = r.value_if_ready()) {
        if (p->has_value()) {
          events().ForwardedMessageClientToServer(call_, *p->value());
        }
        next_client_to_server_messages_.reset();
        GetContext<Activity>()->ForceImmediateRepoll();
      }
    }

    if (!push_client_to_server_messages_.has_value() &&
        !forward_client_to_server_messages_.empty()) {
      push_client_to_server_messages_.emplace(
          pipe_client_to_server_messages_.sender.Push(
              std::move(forward_client_to_server_messages_.front())));
      forward_client_to_server_messages_.pop();
      GetContext<Activity>()->ForceImmediateRepoll();
    }
  }

  // Finally poll the filter's call promise itself.
  auto r = (*promise_)();
  if (r.pending()) return false;
  promise_.reset();
  events().Finished(call_, *r.value());
  return true;
}
274
275 ///////////////////////////////////////////////////////////////////////////////
276 // FilterTestBase::Call::ScopedContext
277
// Activity (plus promise contexts) used while running code for one call.
//
// While alive, it is the current Activity (via scoped_activity_). It also
// installs the call's Arena, legacy context array and CallFinalization as
// promise contexts. It holds a strong reference to the Impl for its lifetime.
class FilterTestBase::Call::ScopedContext final
    : public Activity,
      public promise_detail::Context<Arena>,
      public promise_detail::Context<grpc_call_context_element>,
      public promise_detail::Context<CallFinalization> {
 private:
  // Waker target that holds only a weak reference to the call.
  // It owns itself: Wakeup() and Drop() both delete it, so each waker fires at
  // most once. Waking schedules StepLoop() on the channel's EventEngine; it
  // does nothing if the call has already been destroyed.
  class TestWakeable final : public Wakeable {
   public:
    explicit TestWakeable(ScopedContext* ctx)
        : tag_(ctx->DebugTag()), impl_(ctx->impl_) {}
    void Wakeup(WakeupMask) override {
      std::unique_ptr<TestWakeable> self(this);
      auto impl = impl_.lock();
      if (impl == nullptr) return;
      // Capture the weak pointer again: the call may go away before the
      // EventEngine runs this closure.
      impl->event_engine()->Run([weak_impl = impl_]() {
        auto impl = weak_impl.lock();
        if (impl != nullptr) impl->StepLoop();
      });
    }
    // Wakeup() already defers the work to the EventEngine.
    void WakeupAsync(WakeupMask) override { Wakeup(0); }
    void Drop(WakeupMask) override { delete this; }
    std::string ActivityDebugTag(WakeupMask) const override { return tag_; }

   private:
    const std::string tag_;
    const std::weak_ptr<Impl> impl_;
  };

 public:
  // The base contexts are initialized from `impl` before it is moved into
  // impl_ (bases are constructed before members).
  explicit ScopedContext(std::shared_ptr<Impl> impl)
      : promise_detail::Context<Arena>(impl->arena()),
        promise_detail::Context<grpc_call_context_element>(
            impl->legacy_context()),
        promise_detail::Context<CallFinalization>(impl->call_finalization()),
        impl_(std::move(impl)) {}

  void Orphan() override { Crash("Orphan called on Call::ScopedContext"); }
  // Recorded only; StepLoop() reads it via repoll() after each pass.
  void ForceImmediateRepoll(WakeupMask) override { repoll_ = true; }
  // Owning and non-owning wakers are identical here: TestWakeable only keeps
  // a weak_ptr to the Impl either way.
  Waker MakeOwningWaker() override { return Waker(new TestWakeable(this), 0); }
  Waker MakeNonOwningWaker() override {
    return Waker(new TestWakeable(this), 0);
  }
  std::string DebugTag() const override {
    return absl::StrFormat("FILTER_TEST_CALL[%p]", impl_.get());
  }

  bool repoll() const { return repoll_; }

 private:
  // Declared before impl_, so it is destroyed after impl_ is released. If this
  // context held the last reference, the Impl is torn down while this activity
  // is still current.
  ScopedActivity scoped_activity_{this};
  const std::shared_ptr<Impl> impl_;
  bool repoll_ = false;
};
331
StepLoop()332 void FilterTestBase::Call::Impl::StepLoop() {
333 for (;;) {
334 ScopedContext ctx(shared_from_this());
335 if (!StepOnce() && ctx.repoll()) continue;
336 return;
337 }
338 }
339
ForceWakeup()340 void FilterTestBase::Call::Impl::ForceWakeup() {
341 ScopedContext(shared_from_this()).MakeOwningWaker().Wakeup();
342 }
343
344 ///////////////////////////////////////////////////////////////////////////////
345 // FilterTestBase::Call
346
// Creates a call on `channel`. The per-call Impl keeps a back-pointer to this
// Call and a shared_ptr to the channel's Impl.
FilterTestBase::Call::Call(const Channel& channel)
    : impl_(std::make_unique<Impl>(this, channel.impl_)) {}
349
// Moves the Impl into a ScopedContext. If no other strong reference remains,
// the Impl is therefore destroyed with this call's activity and promise
// contexts installed.
FilterTestBase::Call::~Call() { ScopedContext x(std::move(impl_)); }
351
// Returns the arena owned by the call's current Impl.
// NOTE: Cancel() replaces the Impl. A pointer obtained before Cancel() refers
// to the old Impl's arena, which may no longer be alive.
Arena* FilterTestBase::Call::arena() { return impl_->arena(); }
353
NewClientMetadata(std::initializer_list<std::pair<absl::string_view,absl::string_view>> init)354 ClientMetadataHandle FilterTestBase::Call::NewClientMetadata(
355 std::initializer_list<std::pair<absl::string_view, absl::string_view>>
356 init) {
357 auto md = impl_->arena()->MakePooled<ClientMetadata>();
358 for (auto& p : init) {
359 auto parsed = ClientMetadata::Parse(
360 p.first, Slice::FromCopiedString(p.second), false,
361 p.first.length() + p.second.length() + 32,
362 [p](absl::string_view, const Slice&) {
363 Crash(absl::StrCat("Illegal metadata value: ", p.first, ": ",
364 p.second));
365 });
366 md->Set(parsed);
367 }
368 return md;
369 }
370
NewServerMetadata(std::initializer_list<std::pair<absl::string_view,absl::string_view>> init)371 ServerMetadataHandle FilterTestBase::Call::NewServerMetadata(
372 std::initializer_list<std::pair<absl::string_view, absl::string_view>>
373 init) {
374 auto md = impl_->arena()->MakePooled<ClientMetadata>();
375 for (auto& p : init) {
376 auto parsed = ServerMetadata::Parse(
377 p.first, Slice::FromCopiedString(p.second), false,
378 p.first.length() + p.second.length() + 32,
379 [p](absl::string_view, const Slice&) {
380 Crash(absl::StrCat("Illegal metadata value: ", p.first, ": ",
381 p.second));
382 });
383 md->Set(parsed);
384 }
385 return md;
386 }
387
NewMessage(absl::string_view payload,uint32_t flags)388 MessageHandle FilterTestBase::Call::NewMessage(absl::string_view payload,
389 uint32_t flags) {
390 SliceBuffer buffer;
391 if (!payload.empty()) buffer.Append(Slice::FromCopiedString(payload));
392 return impl_->arena()->MakePooled<Message>(std::move(buffer), flags);
393 }
394
Start(ClientMetadataHandle md)395 void FilterTestBase::Call::Start(ClientMetadataHandle md) {
396 ScopedContext ctx(impl_);
397 impl_->Start(std::move(md));
398 }
399
Cancel()400 void FilterTestBase::Call::Cancel() {
401 ScopedContext ctx(impl_);
402 impl_ = absl::make_unique<Impl>(this, impl_->channel());
403 }
404
ForwardServerInitialMetadata(ServerMetadataHandle md)405 void FilterTestBase::Call::ForwardServerInitialMetadata(
406 ServerMetadataHandle md) {
407 impl_->ForwardServerInitialMetadata(std::move(md));
408 }
409
ForwardMessageClientToServer(MessageHandle msg)410 void FilterTestBase::Call::ForwardMessageClientToServer(MessageHandle msg) {
411 impl_->ForwardMessageClientToServer(std::move(msg));
412 }
413
ForwardMessageServerToClient(MessageHandle msg)414 void FilterTestBase::Call::ForwardMessageServerToClient(MessageHandle msg) {
415 impl_->ForwardMessageServerToClient(std::move(msg));
416 }
417
FinishNextFilter(ServerMetadataHandle md)418 void FilterTestBase::Call::FinishNextFilter(ServerMetadataHandle md) {
419 impl_->FinishNextFilter(std::move(md));
420 }
421
422 ///////////////////////////////////////////////////////////////////////////////
423 // FilterTestBase
424
FilterTestBase()425 FilterTestBase::FilterTestBase() {
426 grpc_event_engine::experimental::SetEventEngineFactory([]() {
427 FuzzingEventEngine::Options options;
428 options.max_delay_run_after = std::chrono::milliseconds(500);
429 options.max_delay_write = std::chrono::milliseconds(50);
430 return std::make_unique<FuzzingEventEngine>(
431 options, fuzzing_event_engine::Actions());
432 });
433 event_engine_ =
434 std::dynamic_pointer_cast<FuzzingEventEngine>(GetDefaultEventEngine());
435 grpc_timer_manager_set_start_threaded(false);
436 grpc_init();
437 }
438
// Test fixture teardown: shuts gRPC down first, then calls UnsetGlobalHooks()
// on the fuzzing engine. Presumably this undoes global hooks the engine
// installed; confirm against FuzzingEventEngine.
FilterTestBase::~FilterTestBase() {
  grpc_shutdown();
  event_engine_->UnsetGlobalHooks();
}
443
// Advances the test: ticks the fuzzing EventEngine until idle, which
// presumably runs any StepLoop() callbacks that wakers posted. It then
// verifies and clears all mock expectations set on `events`.
void FilterTestBase::Step() {
  event_engine_->TickUntilIdle();
  ::testing::Mock::VerifyAndClearExpectations(&events);
}
448
449 } // namespace grpc_core
450