• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/core/lib/transport/interception_chain.h"
16 
17 #include <grpc/support/port_platform.h>
18 
19 #include <cstddef>
20 
21 #include "src/core/lib/debug/trace.h"
22 #include "src/core/lib/transport/call_destination.h"
23 #include "src/core/lib/transport/call_filters.h"
24 #include "src/core/lib/transport/call_spine.h"
25 #include "src/core/lib/transport/metadata.h"
26 #include "src/core/util/match.h"
27 
28 namespace grpc_core {
29 
30 std::atomic<size_t> InterceptionChainBuilder::next_filter_id_{0};
31 
32 ///////////////////////////////////////////////////////////////////////////////
33 // HijackedCall
34 
MakeCall()35 CallInitiator HijackedCall::MakeCall() {
36   auto metadata = Arena::MakePooledForOverwrite<ClientMetadata>();
37   *metadata = metadata_->Copy();
38   return MakeCallWithMetadata(std::move(metadata));
39 }
40 
MakeCallWithMetadata(ClientMetadataHandle metadata)41 CallInitiator HijackedCall::MakeCallWithMetadata(
42     ClientMetadataHandle metadata) {
43   auto call = MakeCallPair(std::move(metadata), call_handler_.arena()->Ref());
44   destination_->StartCall(std::move(call.handler));
45   return std::move(call.initiator);
46 }
47 
MakeChildCall(ClientMetadataHandle metadata,RefCountedPtr<Arena> arena)48 CallInitiator Interceptor::MakeChildCall(ClientMetadataHandle metadata,
49                                          RefCountedPtr<Arena> arena) {
50   auto call = MakeCallPair(std::move(metadata), arena);
51   wrapped_destination_->StartCall(std::move(call.handler));
52   return std::move(call.initiator);
53 }
54 
55 namespace {
56 class CallStarter final : public UnstartedCallDestination {
57  public:
CallStarter(RefCountedPtr<CallFilters::Stack> stack,RefCountedPtr<CallDestination> destination)58   CallStarter(RefCountedPtr<CallFilters::Stack> stack,
59               RefCountedPtr<CallDestination> destination)
60       : stack_(std::move(stack)), destination_(std::move(destination)) {}
61 
Orphaned()62   void Orphaned() override {
63     stack_.reset();
64     destination_.reset();
65   }
66 
StartCall(UnstartedCallHandler unstarted_call_handler)67   void StartCall(UnstartedCallHandler unstarted_call_handler) override {
68     unstarted_call_handler.AddCallStack(stack_);
69     destination_->HandleCall(unstarted_call_handler.StartCall());
70   }
71 
72  private:
73   RefCountedPtr<CallFilters::Stack> stack_;
74   RefCountedPtr<CallDestination> destination_;
75 };
76 
77 class TerminalInterceptor final : public UnstartedCallDestination {
78  public:
TerminalInterceptor(RefCountedPtr<CallFilters::Stack> stack,RefCountedPtr<UnstartedCallDestination> destination)79   explicit TerminalInterceptor(
80       RefCountedPtr<CallFilters::Stack> stack,
81       RefCountedPtr<UnstartedCallDestination> destination)
82       : stack_(std::move(stack)), destination_(std::move(destination)) {}
83 
Orphaned()84   void Orphaned() override {
85     stack_.reset();
86     destination_.reset();
87   }
88 
StartCall(UnstartedCallHandler unstarted_call_handler)89   void StartCall(UnstartedCallHandler unstarted_call_handler) override {
90     unstarted_call_handler.AddCallStack(stack_);
91     destination_->StartCall(unstarted_call_handler);
92   }
93 
94  private:
95   RefCountedPtr<CallFilters::Stack> stack_;
96   RefCountedPtr<UnstartedCallDestination> destination_;
97 };
98 }  // namespace
99 
100 ///////////////////////////////////////////////////////////////////////////////
101 // InterceptionChain::Builder
102 
AddInterceptor(absl::StatusOr<RefCountedPtr<Interceptor>> interceptor)103 void InterceptionChainBuilder::AddInterceptor(
104     absl::StatusOr<RefCountedPtr<Interceptor>> interceptor) {
105   if (!status_.ok()) return;
106   if (!interceptor.ok()) {
107     status_ = interceptor.status();
108     return;
109   }
110   (*interceptor)->filter_stack_ = MakeFilterStack();
111   if (top_interceptor_ == nullptr) {
112     top_interceptor_ = std::move(*interceptor);
113   } else {
114     Interceptor* previous = top_interceptor_.get();
115     while (previous->wrapped_destination_ != nullptr) {
116       previous = DownCast<Interceptor*>(previous->wrapped_destination_.get());
117     }
118     previous->wrapped_destination_ = std::move(*interceptor);
119   }
120 }
121 
122 absl::StatusOr<RefCountedPtr<UnstartedCallDestination>>
Build(FinalDestination final_destination)123 InterceptionChainBuilder::Build(FinalDestination final_destination) {
124   if (!status_.ok()) return status_;
125   // Build the final UnstartedCallDestination in the chain - what we do here
126   // depends on both the type of the final destination and the filters we have
127   // that haven't been captured into an Interceptor yet.
128   RefCountedPtr<UnstartedCallDestination> terminator = Match(
129       final_destination,
130       [this](RefCountedPtr<UnstartedCallDestination> final_destination)
131           -> RefCountedPtr<UnstartedCallDestination> {
132         if (stack_builder_.has_value()) {
133           return MakeRefCounted<TerminalInterceptor>(MakeFilterStack(),
134                                                      final_destination);
135         }
136         return final_destination;
137       },
138       [this](RefCountedPtr<CallDestination> final_destination)
139           -> RefCountedPtr<UnstartedCallDestination> {
140         return MakeRefCounted<CallStarter>(MakeFilterStack(),
141                                            std::move(final_destination));
142       });
143   // Now append the terminator to the interceptor chain.
144   if (top_interceptor_ == nullptr) {
145     return std::move(terminator);
146   }
147   Interceptor* previous = top_interceptor_.get();
148   while (previous->wrapped_destination_ != nullptr) {
149     previous = DownCast<Interceptor*>(previous->wrapped_destination_.get());
150   }
151   previous->wrapped_destination_ = std::move(terminator);
152   return std::move(top_interceptor_);
153 }
154 
155 }  // namespace grpc_core
156