1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/core/lib/transport/interception_chain.h"
16
17 #include <grpc/support/port_platform.h>
18
19 #include <cstddef>
20
21 #include "src/core/lib/debug/trace.h"
22 #include "src/core/lib/transport/call_destination.h"
23 #include "src/core/lib/transport/call_filters.h"
24 #include "src/core/lib/transport/call_spine.h"
25 #include "src/core/lib/transport/metadata.h"
26 #include "src/core/util/match.h"
27
28 namespace grpc_core {
29
30 std::atomic<size_t> InterceptionChainBuilder::next_filter_id_{0};
31
32 ///////////////////////////////////////////////////////////////////////////////
33 // HijackedCall
34
MakeCall()35 CallInitiator HijackedCall::MakeCall() {
36 auto metadata = Arena::MakePooledForOverwrite<ClientMetadata>();
37 *metadata = metadata_->Copy();
38 return MakeCallWithMetadata(std::move(metadata));
39 }
40
MakeCallWithMetadata(ClientMetadataHandle metadata)41 CallInitiator HijackedCall::MakeCallWithMetadata(
42 ClientMetadataHandle metadata) {
43 auto call = MakeCallPair(std::move(metadata), call_handler_.arena()->Ref());
44 destination_->StartCall(std::move(call.handler));
45 return std::move(call.initiator);
46 }
47
MakeChildCall(ClientMetadataHandle metadata,RefCountedPtr<Arena> arena)48 CallInitiator Interceptor::MakeChildCall(ClientMetadataHandle metadata,
49 RefCountedPtr<Arena> arena) {
50 auto call = MakeCallPair(std::move(metadata), arena);
51 wrapped_destination_->StartCall(std::move(call.handler));
52 return std::move(call.initiator);
53 }
54
55 namespace {
56 class CallStarter final : public UnstartedCallDestination {
57 public:
CallStarter(RefCountedPtr<CallFilters::Stack> stack,RefCountedPtr<CallDestination> destination)58 CallStarter(RefCountedPtr<CallFilters::Stack> stack,
59 RefCountedPtr<CallDestination> destination)
60 : stack_(std::move(stack)), destination_(std::move(destination)) {}
61
Orphaned()62 void Orphaned() override {
63 stack_.reset();
64 destination_.reset();
65 }
66
StartCall(UnstartedCallHandler unstarted_call_handler)67 void StartCall(UnstartedCallHandler unstarted_call_handler) override {
68 unstarted_call_handler.AddCallStack(stack_);
69 destination_->HandleCall(unstarted_call_handler.StartCall());
70 }
71
72 private:
73 RefCountedPtr<CallFilters::Stack> stack_;
74 RefCountedPtr<CallDestination> destination_;
75 };
76
77 class TerminalInterceptor final : public UnstartedCallDestination {
78 public:
TerminalInterceptor(RefCountedPtr<CallFilters::Stack> stack,RefCountedPtr<UnstartedCallDestination> destination)79 explicit TerminalInterceptor(
80 RefCountedPtr<CallFilters::Stack> stack,
81 RefCountedPtr<UnstartedCallDestination> destination)
82 : stack_(std::move(stack)), destination_(std::move(destination)) {}
83
Orphaned()84 void Orphaned() override {
85 stack_.reset();
86 destination_.reset();
87 }
88
StartCall(UnstartedCallHandler unstarted_call_handler)89 void StartCall(UnstartedCallHandler unstarted_call_handler) override {
90 unstarted_call_handler.AddCallStack(stack_);
91 destination_->StartCall(unstarted_call_handler);
92 }
93
94 private:
95 RefCountedPtr<CallFilters::Stack> stack_;
96 RefCountedPtr<UnstartedCallDestination> destination_;
97 };
98 } // namespace
99
100 ///////////////////////////////////////////////////////////////////////////////
// InterceptionChainBuilder
102
AddInterceptor(absl::StatusOr<RefCountedPtr<Interceptor>> interceptor)103 void InterceptionChainBuilder::AddInterceptor(
104 absl::StatusOr<RefCountedPtr<Interceptor>> interceptor) {
105 if (!status_.ok()) return;
106 if (!interceptor.ok()) {
107 status_ = interceptor.status();
108 return;
109 }
110 (*interceptor)->filter_stack_ = MakeFilterStack();
111 if (top_interceptor_ == nullptr) {
112 top_interceptor_ = std::move(*interceptor);
113 } else {
114 Interceptor* previous = top_interceptor_.get();
115 while (previous->wrapped_destination_ != nullptr) {
116 previous = DownCast<Interceptor*>(previous->wrapped_destination_.get());
117 }
118 previous->wrapped_destination_ = std::move(*interceptor);
119 }
120 }
121
122 absl::StatusOr<RefCountedPtr<UnstartedCallDestination>>
Build(FinalDestination final_destination)123 InterceptionChainBuilder::Build(FinalDestination final_destination) {
124 if (!status_.ok()) return status_;
125 // Build the final UnstartedCallDestination in the chain - what we do here
126 // depends on both the type of the final destination and the filters we have
127 // that haven't been captured into an Interceptor yet.
128 RefCountedPtr<UnstartedCallDestination> terminator = Match(
129 final_destination,
130 [this](RefCountedPtr<UnstartedCallDestination> final_destination)
131 -> RefCountedPtr<UnstartedCallDestination> {
132 if (stack_builder_.has_value()) {
133 return MakeRefCounted<TerminalInterceptor>(MakeFilterStack(),
134 final_destination);
135 }
136 return final_destination;
137 },
138 [this](RefCountedPtr<CallDestination> final_destination)
139 -> RefCountedPtr<UnstartedCallDestination> {
140 return MakeRefCounted<CallStarter>(MakeFilterStack(),
141 std::move(final_destination));
142 });
143 // Now append the terminator to the interceptor chain.
144 if (top_interceptor_ == nullptr) {
145 return std::move(terminator);
146 }
147 Interceptor* previous = top_interceptor_.get();
148 while (previous->wrapped_destination_ != nullptr) {
149 previous = DownCast<Interceptor*>(previous->wrapped_destination_.get());
150 }
151 previous->wrapped_destination_ = std::move(terminator);
152 return std::move(top_interceptor_);
153 }
154
155 } // namespace grpc_core
156