1 /*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 #include "test/core/util/test_lb_policies.h"
20
21 #include <string>
22
23 #include <grpc/support/log.h>
24
25 #include "src/core/ext/filters/client_channel/lb_policy.h"
26 #include "src/core/ext/filters/client_channel/lb_policy_registry.h"
27 #include "src/core/lib/channel/channel_args.h"
28 #include "src/core/lib/channel/channelz.h"
29 #include "src/core/lib/debug/trace.h"
30 #include "src/core/lib/gprpp/memory.h"
31 #include "src/core/lib/gprpp/orphanable.h"
32 #include "src/core/lib/gprpp/ref_counted_ptr.h"
33 #include "src/core/lib/iomgr/closure.h"
34 #include "src/core/lib/iomgr/combiner.h"
35 #include "src/core/lib/iomgr/error.h"
36 #include "src/core/lib/iomgr/pollset_set.h"
37 #include "src/core/lib/json/json.h"
38 #include "src/core/lib/transport/connectivity_state.h"
39
40 namespace grpc_core {
41
42 namespace {
43
44 //
45 // ForwardingLoadBalancingPolicy
46 //
47
48 // A minimal forwarding class to avoid implementing a standalone test LB.
49 class ForwardingLoadBalancingPolicy : public LoadBalancingPolicy {
50 public:
ForwardingLoadBalancingPolicy(std::unique_ptr<ChannelControlHelper> delegating_helper,Args args,const std::string & delegate_policy_name,intptr_t initial_refcount=1)51 ForwardingLoadBalancingPolicy(
52 std::unique_ptr<ChannelControlHelper> delegating_helper, Args args,
53 const std::string& delegate_policy_name, intptr_t initial_refcount = 1)
54 : LoadBalancingPolicy(std::move(args), initial_refcount) {
55 Args delegate_args;
56 delegate_args.work_serializer = work_serializer();
57 delegate_args.channel_control_helper = std::move(delegating_helper);
58 delegate_args.args = args.args;
59 delegate_ = LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy(
60 delegate_policy_name.c_str(), std::move(delegate_args));
61 grpc_pollset_set_add_pollset_set(delegate_->interested_parties(),
62 interested_parties());
63 }
64
65 ~ForwardingLoadBalancingPolicy() override = default;
66
UpdateLocked(UpdateArgs args)67 void UpdateLocked(UpdateArgs args) override {
68 delegate_->UpdateLocked(std::move(args));
69 }
70
ExitIdleLocked()71 void ExitIdleLocked() override { delegate_->ExitIdleLocked(); }
72
ResetBackoffLocked()73 void ResetBackoffLocked() override { delegate_->ResetBackoffLocked(); }
74
75 private:
ShutdownLocked()76 void ShutdownLocked() override { delegate_.reset(); }
77
78 OrphanablePtr<LoadBalancingPolicy> delegate_;
79 };
80
81 //
82 // CopyMetadataToVector()
83 //
84
CopyMetadataToVector(LoadBalancingPolicy::MetadataInterface * metadata)85 MetadataVector CopyMetadataToVector(
86 LoadBalancingPolicy::MetadataInterface* metadata) {
87 MetadataVector result;
88 for (const auto& p : *metadata) {
89 result.push_back({std::string(p.first), std::string(p.second)});
90 }
91 return result;
92 }
93
94 //
95 // TestPickArgsLb
96 //
97
// Name reported by TestPickArgsLb, its config and its factory.
constexpr char kTestPickArgsLbPolicyName[] = "test_pick_args_lb";
99
100 class TestPickArgsLb : public ForwardingLoadBalancingPolicy {
101 public:
TestPickArgsLb(Args args,TestPickArgsCallback cb)102 TestPickArgsLb(Args args, TestPickArgsCallback cb)
103 : ForwardingLoadBalancingPolicy(
104 absl::make_unique<Helper>(RefCountedPtr<TestPickArgsLb>(this), cb),
105 std::move(args),
106 /*delegate_lb_policy_name=*/"pick_first",
107 /*initial_refcount=*/2) {}
108
109 ~TestPickArgsLb() override = default;
110
name() const111 const char* name() const override { return kTestPickArgsLbPolicyName; }
112
113 private:
114 class Picker : public SubchannelPicker {
115 public:
Picker(std::unique_ptr<SubchannelPicker> delegate_picker,TestPickArgsCallback cb)116 Picker(std::unique_ptr<SubchannelPicker> delegate_picker,
117 TestPickArgsCallback cb)
118 : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {}
119
Pick(PickArgs args)120 PickResult Pick(PickArgs args) override {
121 // Report args seen.
122 PickArgsSeen args_seen;
123 args_seen.path = std::string(args.path);
124 args_seen.metadata = CopyMetadataToVector(args.initial_metadata);
125 cb_(args_seen);
126 // Do pick.
127 return delegate_picker_->Pick(args);
128 }
129
130 private:
131 std::unique_ptr<SubchannelPicker> delegate_picker_;
132 TestPickArgsCallback cb_;
133 };
134
135 class Helper : public ChannelControlHelper {
136 public:
Helper(RefCountedPtr<TestPickArgsLb> parent,TestPickArgsCallback cb)137 Helper(RefCountedPtr<TestPickArgsLb> parent, TestPickArgsCallback cb)
138 : parent_(std::move(parent)), cb_(std::move(cb)) {}
139
CreateSubchannel(const grpc_channel_args & args)140 RefCountedPtr<SubchannelInterface> CreateSubchannel(
141 const grpc_channel_args& args) override {
142 return parent_->channel_control_helper()->CreateSubchannel(args);
143 }
144
UpdateState(grpc_connectivity_state state,std::unique_ptr<SubchannelPicker> picker)145 void UpdateState(grpc_connectivity_state state,
146 std::unique_ptr<SubchannelPicker> picker) override {
147 parent_->channel_control_helper()->UpdateState(
148 state, absl::make_unique<Picker>(std::move(picker), cb_));
149 }
150
RequestReresolution()151 void RequestReresolution() override {
152 parent_->channel_control_helper()->RequestReresolution();
153 }
154
AddTraceEvent(TraceSeverity severity,absl::string_view message)155 void AddTraceEvent(TraceSeverity severity,
156 absl::string_view message) override {
157 parent_->channel_control_helper()->AddTraceEvent(severity, message);
158 }
159
160 private:
161 RefCountedPtr<TestPickArgsLb> parent_;
162 TestPickArgsCallback cb_;
163 };
164 };
165
166 class TestPickArgsLbConfig : public LoadBalancingPolicy::Config {
167 public:
name() const168 const char* name() const override { return kTestPickArgsLbPolicyName; }
169 };
170
171 class TestPickArgsLbFactory : public LoadBalancingPolicyFactory {
172 public:
TestPickArgsLbFactory(TestPickArgsCallback cb)173 explicit TestPickArgsLbFactory(TestPickArgsCallback cb)
174 : cb_(std::move(cb)) {}
175
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const176 OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
177 LoadBalancingPolicy::Args args) const override {
178 return MakeOrphanable<TestPickArgsLb>(std::move(args), cb_);
179 }
180
name() const181 const char* name() const override { return kTestPickArgsLbPolicyName; }
182
ParseLoadBalancingConfig(const Json &,grpc_error **) const183 RefCountedPtr<LoadBalancingPolicy::Config> ParseLoadBalancingConfig(
184 const Json& /*json*/, grpc_error** /*error*/) const override {
185 return MakeRefCounted<TestPickArgsLbConfig>();
186 }
187
188 private:
189 TestPickArgsCallback cb_;
190 };
191
192 //
193 // InterceptRecvTrailingMetadataLoadBalancingPolicy
194 //
195
// Name reported by InterceptRecvTrailingMetadataLoadBalancingPolicy, its
// config and its factory.
constexpr char kInterceptRecvTrailingMetadataLbPolicyName[] =
    "intercept_trailing_metadata_lb";
198
199 class InterceptRecvTrailingMetadataLoadBalancingPolicy
200 : public ForwardingLoadBalancingPolicy {
201 public:
InterceptRecvTrailingMetadataLoadBalancingPolicy(Args args,InterceptRecvTrailingMetadataCallback cb)202 InterceptRecvTrailingMetadataLoadBalancingPolicy(
203 Args args, InterceptRecvTrailingMetadataCallback cb)
204 : ForwardingLoadBalancingPolicy(
205 absl::make_unique<Helper>(
206 RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy>(
207 this),
208 std::move(cb)),
209 std::move(args),
210 /*delegate_lb_policy_name=*/"pick_first",
211 /*initial_refcount=*/2) {}
212
213 ~InterceptRecvTrailingMetadataLoadBalancingPolicy() override = default;
214
name() const215 const char* name() const override {
216 return kInterceptRecvTrailingMetadataLbPolicyName;
217 }
218
219 private:
220 class Picker : public SubchannelPicker {
221 public:
Picker(std::unique_ptr<SubchannelPicker> delegate_picker,InterceptRecvTrailingMetadataCallback cb)222 Picker(std::unique_ptr<SubchannelPicker> delegate_picker,
223 InterceptRecvTrailingMetadataCallback cb)
224 : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {}
225
Pick(PickArgs args)226 PickResult Pick(PickArgs args) override {
227 // Do pick.
228 PickResult result = delegate_picker_->Pick(args);
229 // Intercept trailing metadata.
230 if (result.type == PickResult::PICK_COMPLETE &&
231 result.subchannel != nullptr) {
232 new (args.call_state->Alloc(sizeof(TrailingMetadataHandler)))
233 TrailingMetadataHandler(&result, cb_);
234 }
235 return result;
236 }
237
238 private:
239 std::unique_ptr<SubchannelPicker> delegate_picker_;
240 InterceptRecvTrailingMetadataCallback cb_;
241 };
242
243 class Helper : public ChannelControlHelper {
244 public:
Helper(RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy> parent,InterceptRecvTrailingMetadataCallback cb)245 Helper(
246 RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy> parent,
247 InterceptRecvTrailingMetadataCallback cb)
248 : parent_(std::move(parent)), cb_(std::move(cb)) {}
249
CreateSubchannel(const grpc_channel_args & args)250 RefCountedPtr<SubchannelInterface> CreateSubchannel(
251 const grpc_channel_args& args) override {
252 return parent_->channel_control_helper()->CreateSubchannel(args);
253 }
254
UpdateState(grpc_connectivity_state state,std::unique_ptr<SubchannelPicker> picker)255 void UpdateState(grpc_connectivity_state state,
256 std::unique_ptr<SubchannelPicker> picker) override {
257 parent_->channel_control_helper()->UpdateState(
258 state, absl::make_unique<Picker>(std::move(picker), cb_));
259 }
260
RequestReresolution()261 void RequestReresolution() override {
262 parent_->channel_control_helper()->RequestReresolution();
263 }
264
AddTraceEvent(TraceSeverity severity,absl::string_view message)265 void AddTraceEvent(TraceSeverity severity,
266 absl::string_view message) override {
267 parent_->channel_control_helper()->AddTraceEvent(severity, message);
268 }
269
270 private:
271 RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy> parent_;
272 InterceptRecvTrailingMetadataCallback cb_;
273 };
274
275 class TrailingMetadataHandler {
276 public:
TrailingMetadataHandler(PickResult * result,InterceptRecvTrailingMetadataCallback cb)277 TrailingMetadataHandler(PickResult* result,
278 InterceptRecvTrailingMetadataCallback cb)
279 : cb_(std::move(cb)) {
280 result->recv_trailing_metadata_ready = [this](grpc_error* error,
281 MetadataInterface* metadata,
282 CallState* call_state) {
283 RecordRecvTrailingMetadata(error, metadata, call_state);
284 };
285 }
286
287 private:
RecordRecvTrailingMetadata(grpc_error *,MetadataInterface * recv_trailing_metadata,CallState * call_state)288 void RecordRecvTrailingMetadata(grpc_error* /*error*/,
289 MetadataInterface* recv_trailing_metadata,
290 CallState* call_state) {
291 TrailingMetadataArgsSeen args_seen;
292 args_seen.backend_metric_data = call_state->GetBackendMetricData();
293 GPR_ASSERT(recv_trailing_metadata != nullptr);
294 args_seen.metadata = CopyMetadataToVector(recv_trailing_metadata);
295 cb_(args_seen);
296 this->~TrailingMetadataHandler();
297 }
298
299 InterceptRecvTrailingMetadataCallback cb_;
300 };
301 };
302
303 class InterceptTrailingConfig : public LoadBalancingPolicy::Config {
304 public:
name() const305 const char* name() const override {
306 return kInterceptRecvTrailingMetadataLbPolicyName;
307 }
308 };
309
310 class InterceptTrailingFactory : public LoadBalancingPolicyFactory {
311 public:
InterceptTrailingFactory(InterceptRecvTrailingMetadataCallback cb)312 explicit InterceptTrailingFactory(InterceptRecvTrailingMetadataCallback cb)
313 : cb_(std::move(cb)) {}
314
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const315 OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
316 LoadBalancingPolicy::Args args) const override {
317 return MakeOrphanable<InterceptRecvTrailingMetadataLoadBalancingPolicy>(
318 std::move(args), cb_);
319 }
320
name() const321 const char* name() const override {
322 return kInterceptRecvTrailingMetadataLbPolicyName;
323 }
324
ParseLoadBalancingConfig(const Json &,grpc_error **) const325 RefCountedPtr<LoadBalancingPolicy::Config> ParseLoadBalancingConfig(
326 const Json& /*json*/, grpc_error** /*error*/) const override {
327 return MakeRefCounted<InterceptTrailingConfig>();
328 }
329
330 private:
331 InterceptRecvTrailingMetadataCallback cb_;
332 };
333
334 } // namespace
335
RegisterTestPickArgsLoadBalancingPolicy(TestPickArgsCallback cb)336 void RegisterTestPickArgsLoadBalancingPolicy(TestPickArgsCallback cb) {
337 LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory(
338 absl::make_unique<TestPickArgsLbFactory>(std::move(cb)));
339 }
340
RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy(InterceptRecvTrailingMetadataCallback cb)341 void RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy(
342 InterceptRecvTrailingMetadataCallback cb) {
343 LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory(
344 absl::make_unique<InterceptTrailingFactory>(std::move(cb)));
345 }
346
347 } // namespace grpc_core
348