• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H
16 #define GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H
17 
#include <grpc/support/port_platform.h>

#include <atomic>
#include <cstddef>
#include <map>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include "src/core/lib/transport/call_destination.h"
#include "src/core/lib/transport/call_filters.h"
#include "src/core/lib/transport/call_spine.h"
#include "src/core/lib/transport/metadata.h"
#include "src/core/util/ref_counted.h"
28 
29 namespace grpc_core {
30 
31 class Blackboard;
32 class InterceptionChainBuilder;
33 
34 // One hijacked call. Using this we can get access to the CallHandler for the
35 // call object above us, the processed metadata from any filters/interceptors
36 // above us, and also create new CallInterceptor objects that will be handled
37 // below.
38 class HijackedCall final {
39  public:
HijackedCall(ClientMetadataHandle metadata,RefCountedPtr<UnstartedCallDestination> destination,CallHandler call_handler)40   HijackedCall(ClientMetadataHandle metadata,
41                RefCountedPtr<UnstartedCallDestination> destination,
42                CallHandler call_handler)
43       : metadata_(std::move(metadata)),
44         destination_(std::move(destination)),
45         call_handler_(std::move(call_handler)) {}
46 
47   // Create a new call and pass it down the stack.
48   // This can be called as many times as needed.
49   CallInitiator MakeCall();
50   // Per MakeCall(), but precludes creating further calls.
51   // Allows us to optimize by not copying initial metadata.
MakeLastCall()52   CallInitiator MakeLastCall() {
53     return MakeCallWithMetadata(std::move(metadata_));
54   }
55 
original_call_handler()56   CallHandler& original_call_handler() { return call_handler_; }
57 
client_metadata()58   ClientMetadata& client_metadata() { return *metadata_; }
59 
60  private:
61   CallInitiator MakeCallWithMetadata(ClientMetadataHandle metadata);
62 
63   ClientMetadataHandle metadata_;
64   RefCountedPtr<UnstartedCallDestination> destination_;
65   CallHandler call_handler_;
66 };
67 
68 // A delegating UnstartedCallDestination for use as a hijacking filter.
69 //
70 // This class provides the final StartCall method, and delegates to the
71 // InterceptCall() method for the actual interception. It has the same semantics
72 // as StartCall, but affords the implementation the ability to prepare the
73 // UnstartedCallHandler appropriately.
74 //
75 // Implementations may look at the unprocessed initial metadata
76 // and decide to do one of three things:
77 //
78 // 1. It can hijack the call. Returns a HijackedCall object that can
79 //    be used to start new calls with the same metadata.
80 //
81 // 2. It can consume the call by calling `Consume`.
82 //
83 // 3. It can pass the call through to the next interceptor by calling
84 //    `PassThrough`.
85 //
86 // Upon the StartCall call the UnstartedCallHandler will be from the last
87 // *Interceptor* in the call chain (without having been processed by any
88 // intervening filters) -- note that this is commonly not useful (not enough
89 // guarantees), and so it's usually better to Hijack and examine the metadata.
90 
91 class Interceptor : public UnstartedCallDestination {
92  public:
93   using UnstartedCallDestination::UnstartedCallDestination;
94 
StartCall(UnstartedCallHandler unstarted_call_handler)95   void StartCall(UnstartedCallHandler unstarted_call_handler) final {
96     unstarted_call_handler.AddCallStack(filter_stack_);
97     InterceptCall(std::move(unstarted_call_handler));
98   }
99 
100  protected:
101   virtual void InterceptCall(UnstartedCallHandler unstarted_call_handler) = 0;
102 
103   // Returns a promise that resolves to a HijackedCall instance.
104   // Hijacking is the process of taking over a call and starting one or more new
105   // ones.
Hijack(UnstartedCallHandler unstarted_call_handler)106   auto Hijack(UnstartedCallHandler unstarted_call_handler) {
107     auto call_handler = unstarted_call_handler.StartCall();
108     return Map(call_handler.PullClientInitialMetadata(),
109                [call_handler, destination = wrapped_destination_](
110                    ValueOrFailure<ClientMetadataHandle> metadata) mutable
111                    -> ValueOrFailure<HijackedCall> {
112                  if (!metadata.ok()) return Failure{};
113                  return HijackedCall(std::move(metadata.value()),
114                                      std::move(destination),
115                                      std::move(call_handler));
116                });
117   }
118 
119   // Hijack a call with custom initial metadata.
120   // TODO(ctiller): Evaluate whether this or hijack or some other in-between
121   // API is what we need here (I think we need 2 or 3 more fully worked through
122   // samples) and then reduce this surface to one API.
123   CallInitiator MakeChildCall(ClientMetadataHandle metadata,
124                               RefCountedPtr<Arena> arena);
125 
126   // Consume this call - it will not be passed on to any further filters.
Consume(UnstartedCallHandler unstarted_call_handler)127   CallHandler Consume(UnstartedCallHandler unstarted_call_handler) {
128     return unstarted_call_handler.StartCall();
129   }
130 
131   // Pass through this call to the next filter.
PassThrough(UnstartedCallHandler unstarted_call_handler)132   void PassThrough(UnstartedCallHandler unstarted_call_handler) {
133     wrapped_destination_->StartCall(std::move(unstarted_call_handler));
134   }
135 
136  private:
137   friend class InterceptionChainBuilder;
138 
139   RefCountedPtr<UnstartedCallDestination> wrapped_destination_;
140   RefCountedPtr<CallFilters::Stack> filter_stack_;
141 };
142 
143 class InterceptionChainBuilder final {
144  public:
145   // The kind of destination that the chain will eventually call.
146   // We can bottom out in various types depending on where we're intercepting:
147   // - The top half of the client channel wants to terminate on a
148   //   UnstartedCallDestination (specifically the LB call destination).
149   // - The bottom half of the client channel and the server code wants to
150   //   terminate on a ClientTransport - which unlike a
151   //   UnstartedCallDestination demands a started CallHandler.
152   // There's some adaption code that's needed to start filters just prior
153   // to the bottoming out, and some design considerations to make with that.
154   // One way (that's not chosen here) would be to have the caller of the
155   // Builder provide something that can build an adaptor
156   // UnstartedCallDestination with parameters supplied by this builder - that
157   // disperses the responsibility of building the adaptor to the caller, which
158   // is not ideal - we might want to adjust the way this construct is built in
159   // the future, and building is a builder responsibility.
160   // Instead, we declare a relatively closed set of destinations here, and
161   // hide the adaptors inside the builder at build time.
162   using FinalDestination =
163       absl::variant<RefCountedPtr<UnstartedCallDestination>,
164                     RefCountedPtr<CallDestination>>;
165 
166   explicit InterceptionChainBuilder(ChannelArgs args,
167                                     const Blackboard* old_blackboard = nullptr,
168                                     Blackboard* new_blackboard = nullptr)
args_(std::move (args))169       : args_(std::move(args)),
170         old_blackboard_(old_blackboard),
171         new_blackboard_(new_blackboard) {}
172 
173   // Add a filter with a `Call` class as an inner member.
174   // Call class must be one compatible with the filters described in
175   // call_filters.h.
176   template <typename T>
177   absl::enable_if_t<sizeof(typename T::Call) != 0, InterceptionChainBuilder&>
Add()178   Add() {
179     if (!status_.ok()) return *this;
180     auto filter = T::Create(args_, {FilterInstanceId(FilterTypeId<T>()),
181                                     old_blackboard_, new_blackboard_});
182     if (!filter.ok()) {
183       status_ = filter.status();
184       return *this;
185     }
186     auto& sb = stack_builder();
187     sb.Add(filter.value().get());
188     sb.AddOwnedObject(std::move(filter.value()));
189     return *this;
190   };
191 
192   // Add a filter that is an interceptor - one that can hijack calls.
193   template <typename T>
194   absl::enable_if_t<std::is_base_of<Interceptor, T>::value,
195                     InterceptionChainBuilder&>
Add()196   Add() {
197     AddInterceptor(T::Create(args_, {FilterInstanceId(FilterTypeId<T>()),
198                                      old_blackboard_, new_blackboard_}));
199     return *this;
200   };
201 
202   // Add a filter that just mutates client initial metadata.
203   template <typename F>
AddOnClientInitialMetadata(F f)204   void AddOnClientInitialMetadata(F f) {
205     stack_builder().AddOnClientInitialMetadata(std::move(f));
206   }
207 
208   // Add a filter that just mutates server trailing metadata.
209   template <typename F>
AddOnServerTrailingMetadata(F f)210   void AddOnServerTrailingMetadata(F f) {
211     stack_builder().AddOnServerTrailingMetadata(std::move(f));
212   }
213 
Fail(absl::Status status)214   void Fail(absl::Status status) {
215     CHECK(!status.ok()) << status;
216     if (status_.ok()) status_ = std::move(status);
217   }
218 
219   // Build this stack
220   absl::StatusOr<RefCountedPtr<UnstartedCallDestination>> Build(
221       FinalDestination final_destination);
222 
channel_args()223   const ChannelArgs& channel_args() const { return args_; }
224 
225  private:
stack_builder()226   CallFilters::StackBuilder& stack_builder() {
227     if (!stack_builder_.has_value()) stack_builder_.emplace();
228     return *stack_builder_;
229   }
230 
MakeFilterStack()231   RefCountedPtr<CallFilters::Stack> MakeFilterStack() {
232     auto stack = stack_builder().Build();
233     stack_builder_.reset();
234     return stack;
235   }
236 
237   template <typename T>
FilterTypeId()238   static size_t FilterTypeId() {
239     static const size_t id =
240         next_filter_id_.fetch_add(1, std::memory_order_relaxed);
241     return id;
242   }
243 
FilterInstanceId(size_t filter_type)244   size_t FilterInstanceId(size_t filter_type) {
245     return filter_type_counts_[filter_type]++;
246   }
247 
248   void AddInterceptor(absl::StatusOr<RefCountedPtr<Interceptor>> interceptor);
249 
250   ChannelArgs args_;
251   absl::optional<CallFilters::StackBuilder> stack_builder_;
252   RefCountedPtr<Interceptor> top_interceptor_;
253   absl::Status status_;
254   std::map<size_t, size_t> filter_type_counts_;
255   static std::atomic<size_t> next_filter_id_;
256   const Blackboard* old_blackboard_ = nullptr;
257   Blackboard* new_blackboard_ = nullptr;
258 };
259 
260 }  // namespace grpc_core
261 
262 #endif  // GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H
263