1 //
2 //
3 // Copyright 2023 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include "test/cpp/interop/backend_metrics_lb_policy.h"
20
21 #include <grpc/support/port_platform.h>
22
23 #include "absl/log/check.h"
24 #include "absl/log/log.h"
25 #include "absl/strings/str_format.h"
26 #include "src/core/lib/iomgr/pollset_set.h"
27 #include "src/core/load_balancing/delegating_helper.h"
28 #include "src/core/load_balancing/oob_backend_metric.h"
29
30 namespace grpc {
31 namespace testing {
32
33 namespace {
34
35 using grpc_core::CoreConfiguration;
36 using grpc_core::LoadBalancingPolicy;
37 using grpc_core::MakeRefCounted;
38 using grpc_core::OrphanablePtr;
39 using grpc_core::RefCountedPtr;
40
41 constexpr absl::string_view kBackendMetricsLbPolicyName =
42 "test_backend_metrics_load_balancer";
43 constexpr absl::string_view kMetricsTrackerArgument = "orca_metrics_tracker";
44
BackendMetricDataToOrcaLoadReport(const grpc_core::BackendMetricData * backend_metric_data)45 LoadReportTracker::LoadReportEntry BackendMetricDataToOrcaLoadReport(
46 const grpc_core::BackendMetricData* backend_metric_data) {
47 if (backend_metric_data == nullptr) {
48 return absl::nullopt;
49 }
50 TestOrcaReport load_report;
51 load_report.set_cpu_utilization(backend_metric_data->cpu_utilization);
52 load_report.set_memory_utilization(backend_metric_data->mem_utilization);
53 for (const auto& p : backend_metric_data->request_cost) {
54 std::string name(p.first);
55 (*load_report.mutable_request_cost())[name] = p.second;
56 }
57 for (const auto& p : backend_metric_data->utilization) {
58 std::string name(p.first);
59 (*load_report.mutable_utilization())[name] = p.second;
60 }
61 return load_report;
62 }
63
64 class BackendMetricsLbPolicy : public LoadBalancingPolicy {
65 public:
BackendMetricsLbPolicy(Args args)66 explicit BackendMetricsLbPolicy(Args args)
67 : LoadBalancingPolicy(std::move(args), /*initial_refcount=*/2) {
68 load_report_tracker_ =
69 channel_args().GetPointer<LoadReportTracker>(kMetricsTrackerArgument);
70 CHECK_NE(load_report_tracker_, nullptr);
71 Args delegate_args;
72 delegate_args.work_serializer = work_serializer();
73 delegate_args.args = channel_args();
74 delegate_args.channel_control_helper =
75 std::make_unique<Helper>(RefCountedPtr<BackendMetricsLbPolicy>(this));
76 delegate_ =
77 CoreConfiguration::Get().lb_policy_registry().CreateLoadBalancingPolicy(
78 "pick_first", std::move(delegate_args));
79 grpc_pollset_set_add_pollset_set(delegate_->interested_parties(),
80 interested_parties());
81 }
82
83 ~BackendMetricsLbPolicy() override = default;
84
name() const85 absl::string_view name() const override {
86 return kBackendMetricsLbPolicyName;
87 }
88
UpdateLocked(UpdateArgs args)89 absl::Status UpdateLocked(UpdateArgs args) override {
90 auto config =
91 CoreConfiguration::Get().lb_policy_registry().ParseLoadBalancingConfig(
92 grpc_core::Json::FromArray({grpc_core::Json::FromObject(
93 {{"pick_first", grpc_core::Json::FromObject({})}})}));
94 args.config = std::move(config.value());
95 return delegate_->UpdateLocked(std::move(args));
96 }
97
ExitIdleLocked()98 void ExitIdleLocked() override { delegate_->ExitIdleLocked(); }
99
ResetBackoffLocked()100 void ResetBackoffLocked() override { delegate_->ResetBackoffLocked(); }
101
102 private:
103 class Picker : public SubchannelPicker {
104 public:
Picker(RefCountedPtr<SubchannelPicker> delegate_picker,LoadReportTracker * load_report_tracker)105 Picker(RefCountedPtr<SubchannelPicker> delegate_picker,
106 LoadReportTracker* load_report_tracker)
107 : delegate_picker_(std::move(delegate_picker)),
108 load_report_tracker_(load_report_tracker) {}
109
Pick(PickArgs args)110 PickResult Pick(PickArgs args) override {
111 // Do pick.
112 PickResult result = delegate_picker_->Pick(args);
113 // Intercept trailing metadata.
114 auto* complete_pick = absl::get_if<PickResult::Complete>(&result.result);
115 if (complete_pick != nullptr) {
116 complete_pick->subchannel_call_tracker =
117 std::make_unique<SubchannelCallTracker>(load_report_tracker_);
118 }
119 return result;
120 }
121
122 private:
123 RefCountedPtr<SubchannelPicker> delegate_picker_;
124 LoadReportTracker* load_report_tracker_;
125 };
126
127 class OobMetricWatcher : public grpc_core::OobBackendMetricWatcher {
128 public:
OobMetricWatcher(LoadReportTracker * load_report_tracker)129 explicit OobMetricWatcher(LoadReportTracker* load_report_tracker)
130 : load_report_tracker_(load_report_tracker) {}
131
132 private:
OnBackendMetricReport(const grpc_core::BackendMetricData & backend_metric_data)133 void OnBackendMetricReport(
134 const grpc_core::BackendMetricData& backend_metric_data) override {
135 load_report_tracker_->RecordOobLoadReport(backend_metric_data);
136 }
137
138 LoadReportTracker* load_report_tracker_;
139 };
140
141 class Helper : public ParentOwningDelegatingChannelControlHelper<
142 BackendMetricsLbPolicy> {
143 public:
Helper(RefCountedPtr<BackendMetricsLbPolicy> parent)144 explicit Helper(RefCountedPtr<BackendMetricsLbPolicy> parent)
145 : ParentOwningDelegatingChannelControlHelper(std::move(parent)) {}
146
CreateSubchannel(const grpc_resolved_address & address,const grpc_core::ChannelArgs & per_address_args,const grpc_core::ChannelArgs & args)147 RefCountedPtr<grpc_core::SubchannelInterface> CreateSubchannel(
148 const grpc_resolved_address& address,
149 const grpc_core::ChannelArgs& per_address_args,
150 const grpc_core::ChannelArgs& args) override {
151 auto subchannel =
152 parent_helper()->CreateSubchannel(address, per_address_args, args);
153 subchannel->AddDataWatcher(MakeOobBackendMetricWatcher(
154 grpc_core::Duration::Seconds(1),
155 std::make_unique<OobMetricWatcher>(parent()->load_report_tracker_)));
156 return subchannel;
157 }
158
UpdateState(grpc_connectivity_state state,const absl::Status & status,RefCountedPtr<SubchannelPicker> picker)159 void UpdateState(grpc_connectivity_state state, const absl::Status& status,
160 RefCountedPtr<SubchannelPicker> picker) override {
161 parent_helper()->UpdateState(
162 state, status,
163 MakeRefCounted<Picker>(std::move(picker),
164 parent()->load_report_tracker_));
165 }
166 };
167
168 class SubchannelCallTracker : public SubchannelCallTrackerInterface {
169 public:
SubchannelCallTracker(LoadReportTracker * load_report_tracker)170 explicit SubchannelCallTracker(LoadReportTracker* load_report_tracker)
171 : load_report_tracker_(load_report_tracker) {}
172
Start()173 void Start() override {}
174
Finish(FinishArgs args)175 void Finish(FinishArgs args) override {
176 load_report_tracker_->RecordPerRpcLoadReport(
177 args.backend_metric_accessor->GetBackendMetricData());
178 }
179
180 private:
181 LoadReportTracker* load_report_tracker_;
182 };
183
ShutdownLocked()184 void ShutdownLocked() override {
185 grpc_pollset_set_del_pollset_set(delegate_->interested_parties(),
186 interested_parties());
187 delegate_.reset();
188 }
189
190 OrphanablePtr<LoadBalancingPolicy> delegate_;
191 LoadReportTracker* load_report_tracker_;
192 };
193
194 class BackendMetricsLbPolicyFactory
195 : public grpc_core::LoadBalancingPolicyFactory {
196 private:
197 class BackendMetricsLbPolicyFactoryConfig
198 : public LoadBalancingPolicy::Config {
199 private:
name() const200 absl::string_view name() const override {
201 return kBackendMetricsLbPolicyName;
202 }
203 };
204
name() const205 absl::string_view name() const override {
206 return kBackendMetricsLbPolicyName;
207 }
208
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const209 OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
210 LoadBalancingPolicy::Args args) const override {
211 return grpc_core::MakeOrphanable<BackendMetricsLbPolicy>(std::move(args));
212 }
213
214 absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const grpc_core::Json &) const215 ParseLoadBalancingConfig(const grpc_core::Json& /*json*/) const override {
216 return MakeRefCounted<BackendMetricsLbPolicyFactoryConfig>();
217 }
218 };
219 } // namespace
220
RegisterBackendMetricsLbPolicy(CoreConfiguration::Builder * builder)221 void RegisterBackendMetricsLbPolicy(CoreConfiguration::Builder* builder) {
222 builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
223 std::make_unique<BackendMetricsLbPolicyFactory>());
224 }
225
RecordPerRpcLoadReport(const grpc_core::BackendMetricData * backend_metric_data)226 void LoadReportTracker::RecordPerRpcLoadReport(
227 const grpc_core::BackendMetricData* backend_metric_data) {
228 grpc_core::MutexLock lock(&load_reports_mu_);
229 per_rpc_load_reports_.emplace_back(
230 BackendMetricDataToOrcaLoadReport(backend_metric_data));
231 }
232
RecordOobLoadReport(const grpc_core::BackendMetricData & oob_metric_data)233 void LoadReportTracker::RecordOobLoadReport(
234 const grpc_core::BackendMetricData& oob_metric_data) {
235 grpc_core::MutexLock lock(&load_reports_mu_);
236 oob_load_reports_.emplace_back(
237 *BackendMetricDataToOrcaLoadReport(&oob_metric_data));
238 load_reports_cv_.Signal();
239 }
240
241 absl::optional<LoadReportTracker::LoadReportEntry>
GetNextLoadReport()242 LoadReportTracker::GetNextLoadReport() {
243 grpc_core::MutexLock lock(&load_reports_mu_);
244 if (per_rpc_load_reports_.empty()) {
245 return absl::nullopt;
246 }
247 auto report = std::move(per_rpc_load_reports_.front());
248 per_rpc_load_reports_.pop_front();
249 return report;
250 }
251
WaitForOobLoadReport(const std::function<bool (const TestOrcaReport &)> & predicate,absl::Duration poll_timeout,size_t max_attempts)252 LoadReportTracker::LoadReportEntry LoadReportTracker::WaitForOobLoadReport(
253 const std::function<bool(const TestOrcaReport&)>& predicate,
254 absl::Duration poll_timeout, size_t max_attempts) {
255 grpc_core::MutexLock lock(&load_reports_mu_);
256 // This condition will be called under lock
257 for (size_t i = 0; i < max_attempts; i++) {
258 if (oob_load_reports_.empty()) {
259 load_reports_cv_.WaitWithTimeout(&load_reports_mu_, poll_timeout);
260 if (oob_load_reports_.empty()) {
261 return absl::nullopt;
262 }
263 }
264 auto report = std::move(oob_load_reports_.front());
265 oob_load_reports_.pop_front();
266 if (predicate(report)) {
267 VLOG(2) << "Report #" << (i + 1) << " matched";
268 return report;
269 }
270 }
271 return absl::nullopt;
272 }
273
ResetCollectedLoadReports()274 void LoadReportTracker::ResetCollectedLoadReports() {
275 grpc_core::MutexLock lock(&load_reports_mu_);
276 per_rpc_load_reports_.clear();
277 oob_load_reports_.clear();
278 }
279
GetChannelArguments()280 ChannelArguments LoadReportTracker::GetChannelArguments() {
281 ChannelArguments arguments;
282 arguments.SetPointer(std::string(kMetricsTrackerArgument), this);
283 return arguments;
284 }
285
286 } // namespace testing
287 } // namespace grpc
288