• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *
3  * Copyright 2018 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include "test/core/util/test_lb_policies.h"
20 
21 #include <string>
22 
23 #include <grpc/support/log.h>
24 
25 #include "src/core/ext/filters/client_channel/lb_policy.h"
26 #include "src/core/ext/filters/client_channel/lb_policy_registry.h"
27 #include "src/core/lib/channel/channel_args.h"
28 #include "src/core/lib/channel/channelz.h"
29 #include "src/core/lib/debug/trace.h"
30 #include "src/core/lib/gprpp/memory.h"
31 #include "src/core/lib/gprpp/orphanable.h"
32 #include "src/core/lib/gprpp/ref_counted_ptr.h"
33 #include "src/core/lib/iomgr/closure.h"
34 #include "src/core/lib/iomgr/combiner.h"
35 #include "src/core/lib/iomgr/error.h"
36 #include "src/core/lib/iomgr/pollset_set.h"
37 #include "src/core/lib/json/json.h"
38 #include "src/core/lib/transport/connectivity_state.h"
39 
40 namespace grpc_core {
41 
42 namespace {
43 
44 //
45 // ForwardingLoadBalancingPolicy
46 //
47 
48 // A minimal forwarding class to avoid implementing a standalone test LB.
49 class ForwardingLoadBalancingPolicy : public LoadBalancingPolicy {
50  public:
ForwardingLoadBalancingPolicy(std::unique_ptr<ChannelControlHelper> delegating_helper,Args args,const std::string & delegate_policy_name,intptr_t initial_refcount=1)51   ForwardingLoadBalancingPolicy(
52       std::unique_ptr<ChannelControlHelper> delegating_helper, Args args,
53       const std::string& delegate_policy_name, intptr_t initial_refcount = 1)
54       : LoadBalancingPolicy(std::move(args), initial_refcount) {
55     Args delegate_args;
56     delegate_args.work_serializer = work_serializer();
57     delegate_args.channel_control_helper = std::move(delegating_helper);
58     delegate_args.args = args.args;
59     delegate_ = LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy(
60         delegate_policy_name.c_str(), std::move(delegate_args));
61     grpc_pollset_set_add_pollset_set(delegate_->interested_parties(),
62                                      interested_parties());
63   }
64 
65   ~ForwardingLoadBalancingPolicy() override = default;
66 
UpdateLocked(UpdateArgs args)67   void UpdateLocked(UpdateArgs args) override {
68     delegate_->UpdateLocked(std::move(args));
69   }
70 
ExitIdleLocked()71   void ExitIdleLocked() override { delegate_->ExitIdleLocked(); }
72 
ResetBackoffLocked()73   void ResetBackoffLocked() override { delegate_->ResetBackoffLocked(); }
74 
75  private:
ShutdownLocked()76   void ShutdownLocked() override { delegate_.reset(); }
77 
78   OrphanablePtr<LoadBalancingPolicy> delegate_;
79 };
80 
81 //
82 // CopyMetadataToVector()
83 //
84 
CopyMetadataToVector(LoadBalancingPolicy::MetadataInterface * metadata)85 MetadataVector CopyMetadataToVector(
86     LoadBalancingPolicy::MetadataInterface* metadata) {
87   MetadataVector result;
88   for (const auto& p : *metadata) {
89     result.push_back({std::string(p.first), std::string(p.second)});
90   }
91   return result;
92 }
93 
94 //
95 // TestPickArgsLb
96 //
97 
98 constexpr char kTestPickArgsLbPolicyName[] = "test_pick_args_lb";
99 
100 class TestPickArgsLb : public ForwardingLoadBalancingPolicy {
101  public:
TestPickArgsLb(Args args,TestPickArgsCallback cb)102   TestPickArgsLb(Args args, TestPickArgsCallback cb)
103       : ForwardingLoadBalancingPolicy(
104             absl::make_unique<Helper>(RefCountedPtr<TestPickArgsLb>(this), cb),
105             std::move(args),
106             /*delegate_lb_policy_name=*/"pick_first",
107             /*initial_refcount=*/2) {}
108 
109   ~TestPickArgsLb() override = default;
110 
name() const111   const char* name() const override { return kTestPickArgsLbPolicyName; }
112 
113  private:
114   class Picker : public SubchannelPicker {
115    public:
Picker(std::unique_ptr<SubchannelPicker> delegate_picker,TestPickArgsCallback cb)116     Picker(std::unique_ptr<SubchannelPicker> delegate_picker,
117            TestPickArgsCallback cb)
118         : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {}
119 
Pick(PickArgs args)120     PickResult Pick(PickArgs args) override {
121       // Report args seen.
122       PickArgsSeen args_seen;
123       args_seen.path = std::string(args.path);
124       args_seen.metadata = CopyMetadataToVector(args.initial_metadata);
125       cb_(args_seen);
126       // Do pick.
127       return delegate_picker_->Pick(args);
128     }
129 
130    private:
131     std::unique_ptr<SubchannelPicker> delegate_picker_;
132     TestPickArgsCallback cb_;
133   };
134 
135   class Helper : public ChannelControlHelper {
136    public:
Helper(RefCountedPtr<TestPickArgsLb> parent,TestPickArgsCallback cb)137     Helper(RefCountedPtr<TestPickArgsLb> parent, TestPickArgsCallback cb)
138         : parent_(std::move(parent)), cb_(std::move(cb)) {}
139 
CreateSubchannel(const grpc_channel_args & args)140     RefCountedPtr<SubchannelInterface> CreateSubchannel(
141         const grpc_channel_args& args) override {
142       return parent_->channel_control_helper()->CreateSubchannel(args);
143     }
144 
UpdateState(grpc_connectivity_state state,std::unique_ptr<SubchannelPicker> picker)145     void UpdateState(grpc_connectivity_state state,
146                      std::unique_ptr<SubchannelPicker> picker) override {
147       parent_->channel_control_helper()->UpdateState(
148           state, absl::make_unique<Picker>(std::move(picker), cb_));
149     }
150 
RequestReresolution()151     void RequestReresolution() override {
152       parent_->channel_control_helper()->RequestReresolution();
153     }
154 
AddTraceEvent(TraceSeverity severity,absl::string_view message)155     void AddTraceEvent(TraceSeverity severity,
156                        absl::string_view message) override {
157       parent_->channel_control_helper()->AddTraceEvent(severity, message);
158     }
159 
160    private:
161     RefCountedPtr<TestPickArgsLb> parent_;
162     TestPickArgsCallback cb_;
163   };
164 };
165 
166 class TestPickArgsLbConfig : public LoadBalancingPolicy::Config {
167  public:
name() const168   const char* name() const override { return kTestPickArgsLbPolicyName; }
169 };
170 
171 class TestPickArgsLbFactory : public LoadBalancingPolicyFactory {
172  public:
TestPickArgsLbFactory(TestPickArgsCallback cb)173   explicit TestPickArgsLbFactory(TestPickArgsCallback cb)
174       : cb_(std::move(cb)) {}
175 
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const176   OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
177       LoadBalancingPolicy::Args args) const override {
178     return MakeOrphanable<TestPickArgsLb>(std::move(args), cb_);
179   }
180 
name() const181   const char* name() const override { return kTestPickArgsLbPolicyName; }
182 
ParseLoadBalancingConfig(const Json &,grpc_error **) const183   RefCountedPtr<LoadBalancingPolicy::Config> ParseLoadBalancingConfig(
184       const Json& /*json*/, grpc_error** /*error*/) const override {
185     return MakeRefCounted<TestPickArgsLbConfig>();
186   }
187 
188  private:
189   TestPickArgsCallback cb_;
190 };
191 
192 //
193 // InterceptRecvTrailingMetadataLoadBalancingPolicy
194 //
195 
196 constexpr char kInterceptRecvTrailingMetadataLbPolicyName[] =
197     "intercept_trailing_metadata_lb";
198 
199 class InterceptRecvTrailingMetadataLoadBalancingPolicy
200     : public ForwardingLoadBalancingPolicy {
201  public:
InterceptRecvTrailingMetadataLoadBalancingPolicy(Args args,InterceptRecvTrailingMetadataCallback cb)202   InterceptRecvTrailingMetadataLoadBalancingPolicy(
203       Args args, InterceptRecvTrailingMetadataCallback cb)
204       : ForwardingLoadBalancingPolicy(
205             absl::make_unique<Helper>(
206                 RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy>(
207                     this),
208                 std::move(cb)),
209             std::move(args),
210             /*delegate_lb_policy_name=*/"pick_first",
211             /*initial_refcount=*/2) {}
212 
213   ~InterceptRecvTrailingMetadataLoadBalancingPolicy() override = default;
214 
name() const215   const char* name() const override {
216     return kInterceptRecvTrailingMetadataLbPolicyName;
217   }
218 
219  private:
220   class Picker : public SubchannelPicker {
221    public:
Picker(std::unique_ptr<SubchannelPicker> delegate_picker,InterceptRecvTrailingMetadataCallback cb)222     Picker(std::unique_ptr<SubchannelPicker> delegate_picker,
223            InterceptRecvTrailingMetadataCallback cb)
224         : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {}
225 
Pick(PickArgs args)226     PickResult Pick(PickArgs args) override {
227       // Do pick.
228       PickResult result = delegate_picker_->Pick(args);
229       // Intercept trailing metadata.
230       if (result.type == PickResult::PICK_COMPLETE &&
231           result.subchannel != nullptr) {
232         new (args.call_state->Alloc(sizeof(TrailingMetadataHandler)))
233             TrailingMetadataHandler(&result, cb_);
234       }
235       return result;
236     }
237 
238    private:
239     std::unique_ptr<SubchannelPicker> delegate_picker_;
240     InterceptRecvTrailingMetadataCallback cb_;
241   };
242 
243   class Helper : public ChannelControlHelper {
244    public:
Helper(RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy> parent,InterceptRecvTrailingMetadataCallback cb)245     Helper(
246         RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy> parent,
247         InterceptRecvTrailingMetadataCallback cb)
248         : parent_(std::move(parent)), cb_(std::move(cb)) {}
249 
CreateSubchannel(const grpc_channel_args & args)250     RefCountedPtr<SubchannelInterface> CreateSubchannel(
251         const grpc_channel_args& args) override {
252       return parent_->channel_control_helper()->CreateSubchannel(args);
253     }
254 
UpdateState(grpc_connectivity_state state,std::unique_ptr<SubchannelPicker> picker)255     void UpdateState(grpc_connectivity_state state,
256                      std::unique_ptr<SubchannelPicker> picker) override {
257       parent_->channel_control_helper()->UpdateState(
258           state, absl::make_unique<Picker>(std::move(picker), cb_));
259     }
260 
RequestReresolution()261     void RequestReresolution() override {
262       parent_->channel_control_helper()->RequestReresolution();
263     }
264 
AddTraceEvent(TraceSeverity severity,absl::string_view message)265     void AddTraceEvent(TraceSeverity severity,
266                        absl::string_view message) override {
267       parent_->channel_control_helper()->AddTraceEvent(severity, message);
268     }
269 
270    private:
271     RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy> parent_;
272     InterceptRecvTrailingMetadataCallback cb_;
273   };
274 
275   class TrailingMetadataHandler {
276    public:
TrailingMetadataHandler(PickResult * result,InterceptRecvTrailingMetadataCallback cb)277     TrailingMetadataHandler(PickResult* result,
278                             InterceptRecvTrailingMetadataCallback cb)
279         : cb_(std::move(cb)) {
280       result->recv_trailing_metadata_ready = [this](grpc_error* error,
281                                                     MetadataInterface* metadata,
282                                                     CallState* call_state) {
283         RecordRecvTrailingMetadata(error, metadata, call_state);
284       };
285     }
286 
287    private:
RecordRecvTrailingMetadata(grpc_error *,MetadataInterface * recv_trailing_metadata,CallState * call_state)288     void RecordRecvTrailingMetadata(grpc_error* /*error*/,
289                                     MetadataInterface* recv_trailing_metadata,
290                                     CallState* call_state) {
291       TrailingMetadataArgsSeen args_seen;
292       args_seen.backend_metric_data = call_state->GetBackendMetricData();
293       GPR_ASSERT(recv_trailing_metadata != nullptr);
294       args_seen.metadata = CopyMetadataToVector(recv_trailing_metadata);
295       cb_(args_seen);
296       this->~TrailingMetadataHandler();
297     }
298 
299     InterceptRecvTrailingMetadataCallback cb_;
300   };
301 };
302 
303 class InterceptTrailingConfig : public LoadBalancingPolicy::Config {
304  public:
name() const305   const char* name() const override {
306     return kInterceptRecvTrailingMetadataLbPolicyName;
307   }
308 };
309 
310 class InterceptTrailingFactory : public LoadBalancingPolicyFactory {
311  public:
InterceptTrailingFactory(InterceptRecvTrailingMetadataCallback cb)312   explicit InterceptTrailingFactory(InterceptRecvTrailingMetadataCallback cb)
313       : cb_(std::move(cb)) {}
314 
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const315   OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
316       LoadBalancingPolicy::Args args) const override {
317     return MakeOrphanable<InterceptRecvTrailingMetadataLoadBalancingPolicy>(
318         std::move(args), cb_);
319   }
320 
name() const321   const char* name() const override {
322     return kInterceptRecvTrailingMetadataLbPolicyName;
323   }
324 
ParseLoadBalancingConfig(const Json &,grpc_error **) const325   RefCountedPtr<LoadBalancingPolicy::Config> ParseLoadBalancingConfig(
326       const Json& /*json*/, grpc_error** /*error*/) const override {
327     return MakeRefCounted<InterceptTrailingConfig>();
328   }
329 
330  private:
331   InterceptRecvTrailingMetadataCallback cb_;
332 };
333 
334 }  // namespace
335 
RegisterTestPickArgsLoadBalancingPolicy(TestPickArgsCallback cb)336 void RegisterTestPickArgsLoadBalancingPolicy(TestPickArgsCallback cb) {
337   LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory(
338       absl::make_unique<TestPickArgsLbFactory>(std::move(cb)));
339 }
340 
RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy(InterceptRecvTrailingMetadataCallback cb)341 void RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy(
342     InterceptRecvTrailingMetadataCallback cb) {
343   LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory(
344       absl::make_unique<InterceptTrailingFactory>(std::move(cb)));
345 }
346 
347 }  // namespace grpc_core
348