1 // Copyright 2024 gRPC authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H 16 #define GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H 17 18 #include <grpc/support/port_platform.h> 19 20 #include <memory> 21 #include <vector> 22 23 #include "src/core/lib/transport/call_destination.h" 24 #include "src/core/lib/transport/call_filters.h" 25 #include "src/core/lib/transport/call_spine.h" 26 #include "src/core/lib/transport/metadata.h" 27 #include "src/core/util/ref_counted.h" 28 29 namespace grpc_core { 30 31 class Blackboard; 32 class InterceptionChainBuilder; 33 34 // One hijacked call. Using this we can get access to the CallHandler for the 35 // call object above us, the processed metadata from any filters/interceptors 36 // above us, and also create new CallInterceptor objects that will be handled 37 // below. 38 class HijackedCall final { 39 public: HijackedCall(ClientMetadataHandle metadata,RefCountedPtr<UnstartedCallDestination> destination,CallHandler call_handler)40 HijackedCall(ClientMetadataHandle metadata, 41 RefCountedPtr<UnstartedCallDestination> destination, 42 CallHandler call_handler) 43 : metadata_(std::move(metadata)), 44 destination_(std::move(destination)), 45 call_handler_(std::move(call_handler)) {} 46 47 // Create a new call and pass it down the stack. 48 // This can be called as many times as needed. 49 CallInitiator MakeCall(); 50 // Per MakeCall(), but precludes creating further calls. 51 // Allows us to optimize by not copying initial metadata. MakeLastCall()52 CallInitiator MakeLastCall() { 53 return MakeCallWithMetadata(std::move(metadata_)); 54 } 55 original_call_handler()56 CallHandler& original_call_handler() { return call_handler_; } 57 client_metadata()58 ClientMetadata& client_metadata() { return *metadata_; } 59 60 private: 61 CallInitiator MakeCallWithMetadata(ClientMetadataHandle metadata); 62 63 ClientMetadataHandle metadata_; 64 RefCountedPtr<UnstartedCallDestination> destination_; 65 CallHandler call_handler_; 66 }; 67 68 // A delegating UnstartedCallDestination for use as a hijacking filter. 69 // 70 // This class provides the final StartCall method, and delegates to the 71 // InterceptCall() method for the actual interception. It has the same semantics 72 // as StartCall, but affords the implementation the ability to prepare the 73 // UnstartedCallHandler appropriately. 74 // 75 // Implementations may look at the unprocessed initial metadata 76 // and decide to do one of three things: 77 // 78 // 1. It can hijack the call. Returns a HijackedCall object that can 79 // be used to start new calls with the same metadata. 80 // 81 // 2. It can consume the call by calling `Consume`. 82 // 83 // 3. It can pass the call through to the next interceptor by calling 84 // `PassThrough`. 85 // 86 // Upon the StartCall call the UnstartedCallHandler will be from the last 87 // *Interceptor* in the call chain (without having been processed by any 88 // intervening filters) -- note that this is commonly not useful (not enough 89 // guarantees), and so it's usually better to Hijack and examine the metadata. 90 91 class Interceptor : public UnstartedCallDestination { 92 public: 93 using UnstartedCallDestination::UnstartedCallDestination; 94 StartCall(UnstartedCallHandler unstarted_call_handler)95 void StartCall(UnstartedCallHandler unstarted_call_handler) final { 96 unstarted_call_handler.AddCallStack(filter_stack_); 97 InterceptCall(std::move(unstarted_call_handler)); 98 } 99 100 protected: 101 virtual void InterceptCall(UnstartedCallHandler unstarted_call_handler) = 0; 102 103 // Returns a promise that resolves to a HijackedCall instance. 104 // Hijacking is the process of taking over a call and starting one or more new 105 // ones. Hijack(UnstartedCallHandler unstarted_call_handler)106 auto Hijack(UnstartedCallHandler unstarted_call_handler) { 107 auto call_handler = unstarted_call_handler.StartCall(); 108 return Map(call_handler.PullClientInitialMetadata(), 109 [call_handler, destination = wrapped_destination_]( 110 ValueOrFailure<ClientMetadataHandle> metadata) mutable 111 -> ValueOrFailure<HijackedCall> { 112 if (!metadata.ok()) return Failure{}; 113 return HijackedCall(std::move(metadata.value()), 114 std::move(destination), 115 std::move(call_handler)); 116 }); 117 } 118 119 // Hijack a call with custom initial metadata. 120 // TODO(ctiller): Evaluate whether this or hijack or some other in-between 121 // API is what we need here (I think we need 2 or 3 more fully worked through 122 // samples) and then reduce this surface to one API. 123 CallInitiator MakeChildCall(ClientMetadataHandle metadata, 124 RefCountedPtr<Arena> arena); 125 126 // Consume this call - it will not be passed on to any further filters. Consume(UnstartedCallHandler unstarted_call_handler)127 CallHandler Consume(UnstartedCallHandler unstarted_call_handler) { 128 return unstarted_call_handler.StartCall(); 129 } 130 131 // Pass through this call to the next filter. PassThrough(UnstartedCallHandler unstarted_call_handler)132 void PassThrough(UnstartedCallHandler unstarted_call_handler) { 133 wrapped_destination_->StartCall(std::move(unstarted_call_handler)); 134 } 135 136 private: 137 friend class InterceptionChainBuilder; 138 139 RefCountedPtr<UnstartedCallDestination> wrapped_destination_; 140 RefCountedPtr<CallFilters::Stack> filter_stack_; 141 }; 142 143 class InterceptionChainBuilder final { 144 public: 145 // The kind of destination that the chain will eventually call. 146 // We can bottom out in various types depending on where we're intercepting: 147 // - The top half of the client channel wants to terminate on a 148 // UnstartedCallDestination (specifically the LB call destination). 149 // - The bottom half of the client channel and the server code wants to 150 // terminate on a ClientTransport - which unlike a 151 // UnstartedCallDestination demands a started CallHandler. 152 // There's some adaption code that's needed to start filters just prior 153 // to the bottoming out, and some design considerations to make with that. 154 // One way (that's not chosen here) would be to have the caller of the 155 // Builder provide something that can build an adaptor 156 // UnstartedCallDestination with parameters supplied by this builder - that 157 // disperses the responsibility of building the adaptor to the caller, which 158 // is not ideal - we might want to adjust the way this construct is built in 159 // the future, and building is a builder responsibility. 160 // Instead, we declare a relatively closed set of destinations here, and 161 // hide the adaptors inside the builder at build time. 162 using FinalDestination = 163 absl::variant<RefCountedPtr<UnstartedCallDestination>, 164 RefCountedPtr<CallDestination>>; 165 166 explicit InterceptionChainBuilder(ChannelArgs args, 167 const Blackboard* old_blackboard = nullptr, 168 Blackboard* new_blackboard = nullptr) args_(std::move (args))169 : args_(std::move(args)), 170 old_blackboard_(old_blackboard), 171 new_blackboard_(new_blackboard) {} 172 173 // Add a filter with a `Call` class as an inner member. 174 // Call class must be one compatible with the filters described in 175 // call_filters.h. 176 template <typename T> 177 absl::enable_if_t<sizeof(typename T::Call) != 0, InterceptionChainBuilder&> Add()178 Add() { 179 if (!status_.ok()) return *this; 180 auto filter = T::Create(args_, {FilterInstanceId(FilterTypeId<T>()), 181 old_blackboard_, new_blackboard_}); 182 if (!filter.ok()) { 183 status_ = filter.status(); 184 return *this; 185 } 186 auto& sb = stack_builder(); 187 sb.Add(filter.value().get()); 188 sb.AddOwnedObject(std::move(filter.value())); 189 return *this; 190 }; 191 192 // Add a filter that is an interceptor - one that can hijack calls. 193 template <typename T> 194 absl::enable_if_t<std::is_base_of<Interceptor, T>::value, 195 InterceptionChainBuilder&> Add()196 Add() { 197 AddInterceptor(T::Create(args_, {FilterInstanceId(FilterTypeId<T>()), 198 old_blackboard_, new_blackboard_})); 199 return *this; 200 }; 201 202 // Add a filter that just mutates client initial metadata. 203 template <typename F> AddOnClientInitialMetadata(F f)204 void AddOnClientInitialMetadata(F f) { 205 stack_builder().AddOnClientInitialMetadata(std::move(f)); 206 } 207 208 // Add a filter that just mutates server trailing metadata. 209 template <typename F> AddOnServerTrailingMetadata(F f)210 void AddOnServerTrailingMetadata(F f) { 211 stack_builder().AddOnServerTrailingMetadata(std::move(f)); 212 } 213 Fail(absl::Status status)214 void Fail(absl::Status status) { 215 CHECK(!status.ok()) << status; 216 if (status_.ok()) status_ = std::move(status); 217 } 218 219 // Build this stack 220 absl::StatusOr<RefCountedPtr<UnstartedCallDestination>> Build( 221 FinalDestination final_destination); 222 channel_args()223 const ChannelArgs& channel_args() const { return args_; } 224 225 private: stack_builder()226 CallFilters::StackBuilder& stack_builder() { 227 if (!stack_builder_.has_value()) stack_builder_.emplace(); 228 return *stack_builder_; 229 } 230 MakeFilterStack()231 RefCountedPtr<CallFilters::Stack> MakeFilterStack() { 232 auto stack = stack_builder().Build(); 233 stack_builder_.reset(); 234 return stack; 235 } 236 237 template <typename T> FilterTypeId()238 static size_t FilterTypeId() { 239 static const size_t id = 240 next_filter_id_.fetch_add(1, std::memory_order_relaxed); 241 return id; 242 } 243 FilterInstanceId(size_t filter_type)244 size_t FilterInstanceId(size_t filter_type) { 245 return filter_type_counts_[filter_type]++; 246 } 247 248 void AddInterceptor(absl::StatusOr<RefCountedPtr<Interceptor>> interceptor); 249 250 ChannelArgs args_; 251 absl::optional<CallFilters::StackBuilder> stack_builder_; 252 RefCountedPtr<Interceptor> top_interceptor_; 253 absl::Status status_; 254 std::map<size_t, size_t> filter_type_counts_; 255 static std::atomic<size_t> next_filter_id_; 256 const Blackboard* old_blackboard_ = nullptr; 257 Blackboard* new_blackboard_ = nullptr; 258 }; 259 260 } // namespace grpc_core 261 262 #endif // GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H 263