1 // 2 // Copyright 2022 gRPC authors. 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // 16 17 #ifndef GRPC_TEST_CORE_LOAD_BALANCING_LB_POLICY_TEST_LIB_H 18 #define GRPC_TEST_CORE_LOAD_BALANCING_LB_POLICY_TEST_LIB_H 19 20 #include <grpc/event_engine/event_engine.h> 21 #include <grpc/grpc.h> 22 #include <grpc/support/alloc.h> 23 #include <grpc/support/port_platform.h> 24 #include <inttypes.h> 25 #include <stddef.h> 26 27 #include <algorithm> 28 #include <chrono> 29 #include <deque> 30 #include <functional> 31 #include <map> 32 #include <memory> 33 #include <set> 34 #include <string> 35 #include <tuple> 36 #include <type_traits> 37 #include <utility> 38 #include <vector> 39 40 #include "absl/base/thread_annotations.h" 41 #include "absl/functional/any_invocable.h" 42 #include "absl/log/check.h" 43 #include "absl/log/log.h" 44 #include "absl/status/status.h" 45 #include "absl/status/statusor.h" 46 #include "absl/strings/str_format.h" 47 #include "absl/strings/str_join.h" 48 #include "absl/strings/string_view.h" 49 #include "absl/synchronization/notification.h" 50 #include "absl/types/optional.h" 51 #include "absl/types/span.h" 52 #include "absl/types/variant.h" 53 #include "gmock/gmock.h" 54 #include "gtest/gtest.h" 55 #include "src/core/client_channel/client_channel_internal.h" 56 #include "src/core/client_channel/subchannel_interface_internal.h" 57 #include "src/core/client_channel/subchannel_pool_interface.h" 58 #include "src/core/config/core_configuration.h" 59 #include "src/core/lib/address_utils/parse_address.h" 60 #include "src/core/lib/address_utils/sockaddr_utils.h" 61 #include "src/core/lib/channel/channel_args.h" 62 #include "src/core/lib/event_engine/default_event_engine.h" 63 #include "src/core/lib/experiments/experiments.h" 64 #include "src/core/lib/iomgr/exec_ctx.h" 65 #include "src/core/lib/iomgr/resolved_address.h" 66 #include "src/core/lib/iomgr/timer_manager.h" 67 #include "src/core/lib/security/credentials/credentials.h" 68 #include "src/core/lib/transport/connectivity_state.h" 69 #include "src/core/load_balancing/backend_metric_data.h" 70 #include "src/core/load_balancing/health_check_client_internal.h" 71 #include "src/core/load_balancing/lb_policy.h" 72 #include "src/core/load_balancing/lb_policy_registry.h" 73 #include "src/core/load_balancing/oob_backend_metric.h" 74 #include "src/core/load_balancing/oob_backend_metric_internal.h" 75 #include "src/core/load_balancing/subchannel_interface.h" 76 #include "src/core/resolver/endpoint_addresses.h" 77 #include "src/core/service_config/service_config_call_data.h" 78 #include "src/core/util/debug_location.h" 79 #include "src/core/util/json/json.h" 80 #include "src/core/util/match.h" 81 #include "src/core/util/orphanable.h" 82 #include "src/core/util/ref_counted_ptr.h" 83 #include "src/core/util/sync.h" 84 #include "src/core/util/time.h" 85 #include "src/core/util/unique_type_name.h" 86 #include "src/core/util/uri.h" 87 #include "src/core/util/work_serializer.h" 88 #include "test/core/event_engine/event_engine_test_utils.h" 89 #include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h" 90 #include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h" 91 92 namespace grpc_core { 93 namespace testing { 94 95 class LoadBalancingPolicyTest : public ::testing::Test { 96 protected: 97 using EventEngine = grpc_event_engine::experimental::EventEngine; 98 using FuzzingEventEngine = 99 grpc_event_engine::experimental::FuzzingEventEngine; 100 101 using CallAttributes = 102 std::vector<ServiceConfigCallData::CallAttributeInterface*>; 103 104 // Channel-level subchannel state for a specific address and channel args. 105 // This is analogous to the real subchannel in the ClientChannel code. 106 class SubchannelState { 107 public: 108 // A fake SubchannelInterface object, to be returned to the LB 109 // policy when it calls the helper's CreateSubchannel() method. 110 // There may be multiple FakeSubchannel objects associated with a 111 // given SubchannelState object. 112 class FakeSubchannel : public SubchannelInterface { 113 public: FakeSubchannel(SubchannelState * state)114 explicit FakeSubchannel(SubchannelState* state) : state_(state) {} 115 ~FakeSubchannel()116 ~FakeSubchannel() override { 117 if (orca_watcher_ != nullptr) { 118 MutexLock lock(&state_->backend_metric_watcher_mu_); 119 state_->orca_watchers_.erase(orca_watcher_.get()); 120 } 121 for (const auto& p : watcher_map_) { 122 state_->state_tracker_.RemoveWatcher(p.second); 123 } 124 } 125 state()126 SubchannelState* state() const { return state_; } 127 address()128 std::string address() const override { return state_->address_; } 129 130 private: 131 // Converts between 132 // SubchannelInterface::ConnectivityStateWatcherInterface and 133 // ConnectivityStateWatcherInterface. 134 // 135 // We support both unique_ptr<> and shared_ptr<>, since raw 136 // connectivity watches use the latter but health watches use the 137 // former. 138 // TODO(roth): Clean this up. 139 class WatcherWrapper : public AsyncConnectivityStateWatcherInterface { 140 public: WatcherWrapper(std::shared_ptr<WorkSerializer> work_serializer,std::unique_ptr<SubchannelInterface::ConnectivityStateWatcherInterface> watcher)141 WatcherWrapper( 142 std::shared_ptr<WorkSerializer> work_serializer, 143 std::unique_ptr< 144 SubchannelInterface::ConnectivityStateWatcherInterface> 145 watcher) 146 : AsyncConnectivityStateWatcherInterface( 147 std::move(work_serializer)), 148 watcher_(std::move(watcher)) {} 149 WatcherWrapper(std::shared_ptr<WorkSerializer> work_serializer,std::shared_ptr<SubchannelInterface::ConnectivityStateWatcherInterface> watcher)150 WatcherWrapper( 151 std::shared_ptr<WorkSerializer> work_serializer, 152 std::shared_ptr< 153 SubchannelInterface::ConnectivityStateWatcherInterface> 154 watcher) 155 : AsyncConnectivityStateWatcherInterface( 156 std::move(work_serializer)), 157 watcher_(std::move(watcher)) {} 158 OnConnectivityStateChange(grpc_connectivity_state new_state,const absl::Status & status)159 void OnConnectivityStateChange(grpc_connectivity_state new_state, 160 const absl::Status& status) override { 161 LOG(INFO) << "notifying watcher: state=" 162 << ConnectivityStateName(new_state) << " status=" << status; 163 watcher_->OnConnectivityStateChange(new_state, status); 164 } 165 166 private: 167 std::shared_ptr<SubchannelInterface::ConnectivityStateWatcherInterface> 168 watcher_; 169 }; 170 WatchConnectivityState(std::unique_ptr<SubchannelInterface::ConnectivityStateWatcherInterface> watcher)171 void WatchConnectivityState( 172 std::unique_ptr< 173 SubchannelInterface::ConnectivityStateWatcherInterface> 174 watcher) override 175 ABSL_EXCLUSIVE_LOCKS_REQUIRED(*state_->test_->work_serializer_) { 176 auto* watcher_ptr = watcher.get(); 177 auto watcher_wrapper = MakeOrphanable<WatcherWrapper>( 178 state_->work_serializer(), std::move(watcher)); 179 watcher_map_[watcher_ptr] = watcher_wrapper.get(); 180 state_->state_tracker_.AddWatcher(GRPC_CHANNEL_SHUTDOWN, 181 std::move(watcher_wrapper)); 182 } 183 CancelConnectivityStateWatch(ConnectivityStateWatcherInterface * watcher)184 void CancelConnectivityStateWatch( 185 ConnectivityStateWatcherInterface* watcher) override 186 ABSL_EXCLUSIVE_LOCKS_REQUIRED(*state_->test_->work_serializer_) { 187 auto it = watcher_map_.find(watcher); 188 if (it == watcher_map_.end()) return; 189 state_->state_tracker_.RemoveWatcher(it->second); 190 watcher_map_.erase(it); 191 } 192 RequestConnection()193 void RequestConnection() override { 194 MutexLock lock(&state_->requested_connection_mu_); 195 state_->requested_connection_ = true; 196 } 197 AddDataWatcher(std::unique_ptr<DataWatcherInterface> watcher)198 void AddDataWatcher( 199 std::unique_ptr<DataWatcherInterface> watcher) override 200 ABSL_EXCLUSIVE_LOCKS_REQUIRED(*state_->test_->work_serializer_) { 201 MutexLock lock(&state_->backend_metric_watcher_mu_); 202 auto* w = 203 static_cast<InternalSubchannelDataWatcherInterface*>(watcher.get()); 204 if (w->type() == OrcaProducer::Type()) { 205 CHECK(orca_watcher_ == nullptr); 206 orca_watcher_.reset(static_cast<OrcaWatcher*>(watcher.release())); 207 state_->orca_watchers_.insert(orca_watcher_.get()); 208 } else if (w->type() == HealthProducer::Type()) { 209 // TODO(roth): Support health checking in test framework. 210 // For now, we just hard-code this to the raw connectivity state. 211 CHECK(health_watcher_ == nullptr); 212 CHECK_EQ(health_watcher_wrapper_, nullptr); 213 health_watcher_.reset(static_cast<HealthWatcher*>(watcher.release())); 214 auto connectivity_watcher = health_watcher_->TakeWatcher(); 215 auto* connectivity_watcher_ptr = connectivity_watcher.get(); 216 auto watcher_wrapper = MakeOrphanable<WatcherWrapper>( 217 state_->work_serializer(), std::move(connectivity_watcher)); 218 health_watcher_wrapper_ = watcher_wrapper.get(); 219 state_->state_tracker_.AddWatcher(GRPC_CHANNEL_SHUTDOWN, 220 std::move(watcher_wrapper)); 221 LOG(INFO) << "AddDataWatcher(): added HealthWatch=" 222 << health_watcher_.get() 223 << " connectivity_watcher=" << connectivity_watcher_ptr 224 << " watcher_wrapper=" << health_watcher_wrapper_; 225 } 226 } 227 CancelDataWatcher(DataWatcherInterface * watcher)228 void CancelDataWatcher(DataWatcherInterface* watcher) override 229 ABSL_EXCLUSIVE_LOCKS_REQUIRED(*state_->test_->work_serializer_) { 230 MutexLock lock(&state_->backend_metric_watcher_mu_); 231 auto* w = static_cast<InternalSubchannelDataWatcherInterface*>(watcher); 232 if (w->type() == OrcaProducer::Type()) { 233 if (orca_watcher_.get() != static_cast<OrcaWatcher*>(watcher)) return; 234 state_->orca_watchers_.erase(orca_watcher_.get()); 235 orca_watcher_.reset(); 236 } else if (w->type() == HealthProducer::Type()) { 237 if (health_watcher_.get() != static_cast<HealthWatcher*>(watcher)) { 238 return; 239 } 240 LOG(INFO) << "CancelDataWatcher(): cancelling HealthWatch=" 241 << health_watcher_.get() 242 << " watcher_wrapper=" << health_watcher_wrapper_; 243 state_->state_tracker_.RemoveWatcher(health_watcher_wrapper_); 244 health_watcher_wrapper_ = nullptr; 245 health_watcher_.reset(); 246 } 247 } 248 249 // Don't need this method, so it's a no-op. ResetBackoff()250 void ResetBackoff() override {} 251 252 SubchannelState* state_; 253 std::map<SubchannelInterface::ConnectivityStateWatcherInterface*, 254 WatcherWrapper*> 255 watcher_map_; 256 std::unique_ptr<HealthWatcher> health_watcher_; 257 WatcherWrapper* health_watcher_wrapper_ = nullptr; 258 std::unique_ptr<OrcaWatcher> orca_watcher_; 259 }; 260 SubchannelState(absl::string_view address,LoadBalancingPolicyTest * test)261 SubchannelState(absl::string_view address, LoadBalancingPolicyTest* test) 262 : address_(address), 263 test_(test), 264 state_tracker_("LoadBalancingPolicyTest") {} 265 address()266 const std::string& address() const { return address_; } 267 268 void AssertValidConnectivityStateTransition( 269 grpc_connectivity_state from_state, grpc_connectivity_state to_state, 270 SourceLocation location = SourceLocation()) { 271 switch (from_state) { 272 case GRPC_CHANNEL_IDLE: 273 ASSERT_EQ(to_state, GRPC_CHANNEL_CONNECTING) 274 << ConnectivityStateName(from_state) << "=>" 275 << ConnectivityStateName(to_state) << "\n" 276 << location.file() << ":" << location.line(); 277 break; 278 case GRPC_CHANNEL_CONNECTING: 279 ASSERT_THAT(to_state, 280 ::testing::AnyOf(GRPC_CHANNEL_READY, 281 GRPC_CHANNEL_TRANSIENT_FAILURE)) 282 << ConnectivityStateName(from_state) << "=>" 283 << ConnectivityStateName(to_state) << "\n" 284 << location.file() << ":" << location.line(); 285 break; 286 case GRPC_CHANNEL_READY: 287 ASSERT_EQ(to_state, GRPC_CHANNEL_IDLE) 288 << ConnectivityStateName(from_state) << "=>" 289 << ConnectivityStateName(to_state) << "\n" 290 << location.file() << ":" << location.line(); 291 break; 292 case GRPC_CHANNEL_TRANSIENT_FAILURE: 293 ASSERT_EQ(to_state, GRPC_CHANNEL_IDLE) 294 << ConnectivityStateName(from_state) << "=>" 295 << ConnectivityStateName(to_state) << "\n" 296 << location.file() << ":" << location.line(); 297 break; 298 default: 299 FAIL() << ConnectivityStateName(from_state) << "=>" 300 << ConnectivityStateName(to_state) << "\n" 301 << location.file() << ":" << location.line(); 302 break; 303 } 304 } 305 306 // Sets the connectivity state for this subchannel. The updated state 307 // will be reported to all associated SubchannelInterface objects. 308 void SetConnectivityState( 309 grpc_connectivity_state state, 310 const absl::Status& status = absl::OkStatus(), 311 bool validate_state_transition = true, 312 absl::AnyInvocable<void()> run_before_flush = nullptr, 313 SourceLocation location = SourceLocation()) { 314 ExecCtx exec_ctx; 315 if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) { 316 EXPECT_FALSE(status.ok()) 317 << "bug in test: TRANSIENT_FAILURE must have non-OK status"; 318 } else { 319 EXPECT_TRUE(status.ok()) 320 << "bug in test: " << ConnectivityStateName(state) 321 << " must have OK status: " << status; 322 } 323 // Updating the state in the state tracker will enqueue 324 // notifications to watchers on the WorkSerializer. If any 325 // subchannel reports READY, the pick_first leaf policy will then 326 // start a health watch, whose initial notification will also be 327 // scheduled on the WorkSerializer. We don't want to return until 328 // all of those notifications have been delivered. 329 absl::Notification notification; 330 test_->work_serializer_->Run( 331 [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(*test_->work_serializer_) { 332 if (validate_state_transition) { 333 AssertValidConnectivityStateTransition(state_tracker_.state(), 334 state, location); 335 } 336 LOG(INFO) << "Setting state on tracker"; 337 state_tracker_.SetState(state, status, "set from test"); 338 // SetState() enqueued the connectivity state notifications for 339 // the subchannel, so we add another callback to the queue to be 340 // executed after that state notifications has been delivered. 341 if (run_before_flush != nullptr) run_before_flush(); 342 LOG(INFO) << "Waiting for state notifications to be delivered"; 343 test_->work_serializer_->Run( 344 [&]() { 345 LOG(INFO) << "State notifications delivered, waiting for " 346 "health notifications"; 347 // Now the connectivity state notifications has been 348 // delivered. If the state reported was READY, then the 349 // pick_first leaf policy will have started a health watch, so 350 // we add another callback to the queue to be executed after 351 // the initial health watch notification has been delivered. 352 test_->work_serializer_->Run([&]() { notification.Notify(); }, 353 DEBUG_LOCATION); 354 }, 355 DEBUG_LOCATION); 356 }, 357 DEBUG_LOCATION); 358 while (!notification.HasBeenNotified()) { 359 test_->fuzzing_ee_->Tick(); 360 } 361 LOG(INFO) << "Health notifications delivered"; 362 } 363 364 // Indicates if any of the associated SubchannelInterface objects 365 // have requested a connection attempt since the last time this 366 // method was called. ConnectionRequested()367 bool ConnectionRequested() { 368 MutexLock lock(&requested_connection_mu_); 369 return std::exchange(requested_connection_, false); 370 } 371 372 // To be invoked by FakeHelper. CreateSubchannel()373 RefCountedPtr<SubchannelInterface> CreateSubchannel() { 374 return MakeRefCounted<FakeSubchannel>(this); 375 } 376 377 // Sends an OOB backend metric report to all watchers. SendOobBackendMetricReport(const BackendMetricData & backend_metrics)378 void SendOobBackendMetricReport(const BackendMetricData& backend_metrics) { 379 MutexLock lock(&backend_metric_watcher_mu_); 380 for (const auto* watcher : orca_watchers_) { 381 watcher->watcher()->OnBackendMetricReport(backend_metrics); 382 } 383 } 384 385 // Checks that all OOB watchers have the expected reporting period. 386 void CheckOobReportingPeriod(Duration expected, 387 SourceLocation location = SourceLocation()) { 388 MutexLock lock(&backend_metric_watcher_mu_); 389 for (const auto* watcher : orca_watchers_) { 390 EXPECT_EQ(watcher->report_interval(), expected) 391 << location.file() << ":" << location.line(); 392 } 393 } 394 NumWatchers()395 size_t NumWatchers() const { 396 size_t num_watchers; 397 absl::Notification notification; 398 work_serializer()->Run( 399 [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(*test_->work_serializer_) { 400 num_watchers = state_tracker_.NumWatchers(); 401 notification.Notify(); 402 }, 403 DEBUG_LOCATION); 404 while (!notification.HasBeenNotified()) { 405 test_->fuzzing_ee_->Tick(); 406 } 407 return num_watchers; 408 } 409 work_serializer()410 std::shared_ptr<WorkSerializer> work_serializer() const { 411 return test_->work_serializer_; 412 } 413 state_tracker()414 ConnectivityStateTracker& state_tracker() { return state_tracker_; } 415 416 private: 417 const std::string address_; 418 LoadBalancingPolicyTest* const test_; 419 ConnectivityStateTracker state_tracker_ 420 ABSL_GUARDED_BY(*test_->work_serializer_); 421 422 Mutex requested_connection_mu_; 423 bool requested_connection_ ABSL_GUARDED_BY(&requested_connection_mu_) = 424 false; 425 426 Mutex backend_metric_watcher_mu_; 427 std::set<OrcaWatcher*> orca_watchers_ 428 ABSL_GUARDED_BY(&backend_metric_watcher_mu_); 429 }; 430 431 // A fake helper to be passed to the LB policy. 432 class FakeHelper : public LoadBalancingPolicy::ChannelControlHelper { 433 public: 434 // Represents a state update reported by the LB policy. 435 struct StateUpdate { 436 grpc_connectivity_state state; 437 absl::Status status; 438 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker; 439 ToStringStateUpdate440 std::string ToString() const { 441 return absl::StrFormat("UPDATE{state=%s, status=%s, picker=%p}", 442 ConnectivityStateName(state), status.ToString(), 443 picker.get()); 444 } 445 }; 446 447 // Represents a re-resolution request from the LB policy. 448 struct ReresolutionRequested { ToStringReresolutionRequested449 std::string ToString() const { return "RERESOLUTION"; } 450 }; 451 FakeHelper(LoadBalancingPolicyTest * test)452 explicit FakeHelper(LoadBalancingPolicyTest* test) : test_(test) {} 453 QueueEmpty()454 bool QueueEmpty() { 455 MutexLock lock(&mu_); 456 return queue_.empty(); 457 } 458 459 // Called at test tear-down time to ensure that we have not left any 460 // unexpected events in the queue. 461 void ExpectQueueEmpty(SourceLocation location = SourceLocation()) { 462 MutexLock lock(&mu_); 463 EXPECT_TRUE(queue_.empty()) 464 << location.file() << ":" << location.line() << "\n" 465 << QueueString(); 466 } 467 468 // Returns the next event in the queue if it is a state update. 469 // If the queue is empty or the next event is not a state update, 470 // fails the test and returns nullopt without removing anything from 471 // the queue. 472 absl::optional<StateUpdate> GetNextStateUpdate( 473 SourceLocation location = SourceLocation()) { 474 MutexLock lock(&mu_); 475 EXPECT_FALSE(queue_.empty()) << location.file() << ":" << location.line(); 476 if (queue_.empty()) return absl::nullopt; 477 Event& event = queue_.front(); 478 auto* update = absl::get_if<StateUpdate>(&event); 479 EXPECT_NE(update, nullptr) 480 << "unexpected event " << EventString(event) << " at " 481 << location.file() << ":" << location.line(); 482 if (update == nullptr) return absl::nullopt; 483 StateUpdate result = std::move(*update); 484 LOG(INFO) << "dequeued next state update: " << result.ToString(); 485 queue_.pop_front(); 486 return std::move(result); 487 } 488 489 // Returns the next event in the queue if it is a re-resolution. 490 // If the queue is empty or the next event is not a re-resolution, 491 // fails the test and returns nullopt without removing anything 492 // from the queue. 493 absl::optional<ReresolutionRequested> GetNextReresolution( 494 SourceLocation location = SourceLocation()) { 495 MutexLock lock(&mu_); 496 EXPECT_FALSE(queue_.empty()) << location.file() << ":" << location.line(); 497 if (queue_.empty()) return absl::nullopt; 498 Event& event = queue_.front(); 499 auto* reresolution = absl::get_if<ReresolutionRequested>(&event); 500 EXPECT_NE(reresolution, nullptr) 501 << "unexpected event " << EventString(event) << " at " 502 << location.file() << ":" << location.line(); 503 if (reresolution == nullptr) return absl::nullopt; 504 ReresolutionRequested result = *reresolution; 505 queue_.pop_front(); 506 return result; 507 } 508 509 private: 510 // A wrapper for a picker that hops into the WorkSerializer to 511 // release the ref to the picker. 512 class PickerWrapper : public LoadBalancingPolicy::SubchannelPicker { 513 public: PickerWrapper(LoadBalancingPolicyTest * test,RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker)514 PickerWrapper(LoadBalancingPolicyTest* test, 515 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker) 516 : test_(test), picker_(std::move(picker)) { 517 LOG(INFO) << "creating wrapper " << this << " for picker " 518 << picker_.get(); 519 } 520 Orphaned()521 void Orphaned() override { 522 absl::Notification notification; 523 ExecCtx exec_ctx; 524 test_->work_serializer_->Run( 525 [notification = ¬ification, 526 picker = std::move(picker_)]() mutable { 527 picker.reset(); 528 notification->Notify(); 529 }, 530 DEBUG_LOCATION); 531 while (!notification.HasBeenNotified()) { 532 test_->fuzzing_ee_->Tick(); 533 } 534 } 535 Pick(LoadBalancingPolicy::PickArgs args)536 LoadBalancingPolicy::PickResult Pick( 537 LoadBalancingPolicy::PickArgs args) override { 538 return picker_->Pick(args); 539 } 540 541 private: 542 LoadBalancingPolicyTest* const test_; 543 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker_; 544 }; 545 546 // Represents an event reported by the LB policy. 547 using Event = absl::variant<StateUpdate, ReresolutionRequested>; 548 549 // Returns a human-readable representation of an event. EventString(const Event & event)550 static std::string EventString(const Event& event) { 551 return Match( 552 event, [](const StateUpdate& update) { return update.ToString(); }, 553 [](const ReresolutionRequested& reresolution) { 554 return reresolution.ToString(); 555 }); 556 } 557 QueueString()558 std::string QueueString() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(&mu_) { 559 std::vector<std::string> parts = {"Queue:"}; 560 for (const Event& event : queue_) { 561 parts.push_back(EventString(event)); 562 } 563 return absl::StrJoin(parts, "\n "); 564 } 565 CreateSubchannel(const grpc_resolved_address & address,const ChannelArgs &,const ChannelArgs & args)566 RefCountedPtr<SubchannelInterface> CreateSubchannel( 567 const grpc_resolved_address& address, 568 const ChannelArgs& /*per_address_args*/, 569 const ChannelArgs& args) override { 570 // TODO(roth): Need to use per_address_args here. 571 SubchannelKey key( 572 address, args.RemoveAllKeysWithPrefix(GRPC_ARG_NO_SUBCHANNEL_PREFIX)); 573 auto it = test_->subchannel_pool_.find(key); 574 if (it == test_->subchannel_pool_.end()) { 575 auto address_uri = grpc_sockaddr_to_uri(&address); 576 CHECK(address_uri.ok()); 577 it = test_->subchannel_pool_ 578 .emplace(std::piecewise_construct, std::forward_as_tuple(key), 579 std::forward_as_tuple(std::move(*address_uri), test_)) 580 .first; 581 } 582 return it->second.CreateSubchannel(); 583 } 584 UpdateState(grpc_connectivity_state state,const absl::Status & status,RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker)585 void UpdateState( 586 grpc_connectivity_state state, const absl::Status& status, 587 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker) override { 588 MutexLock lock(&mu_); 589 StateUpdate update{ 590 state, status, 591 IsWorkSerializerDispatchEnabled() 592 ? std::move(picker) 593 : MakeRefCounted<PickerWrapper>(test_, std::move(picker))}; 594 LOG(INFO) << "enqueuing state update from LB policy: " 595 << update.ToString(); 596 queue_.push_back(std::move(update)); 597 } 598 RequestReresolution()599 void RequestReresolution() override { 600 MutexLock lock(&mu_); 601 queue_.push_back(ReresolutionRequested()); 602 } 603 GetTarget()604 absl::string_view GetTarget() override { return test_->target_; } 605 GetAuthority()606 absl::string_view GetAuthority() override { return test_->authority_; } 607 GetChannelCredentials()608 RefCountedPtr<grpc_channel_credentials> GetChannelCredentials() override { 609 return nullptr; 610 } 611 GetUnsafeChannelCredentials()612 RefCountedPtr<grpc_channel_credentials> GetUnsafeChannelCredentials() 613 override { 614 return nullptr; 615 } 616 GetEventEngine()617 EventEngine* GetEventEngine() override { return test_->fuzzing_ee_.get(); } 618 GetStatsPluginGroup()619 GlobalStatsPluginRegistry::StatsPluginGroup& GetStatsPluginGroup() 620 override { 621 return test_->stats_plugin_group_; 622 } 623 AddTraceEvent(TraceSeverity,absl::string_view)624 void AddTraceEvent(TraceSeverity, absl::string_view) override {} 625 626 LoadBalancingPolicyTest* test_; 627 628 Mutex mu_; 629 std::deque<Event> queue_ ABSL_GUARDED_BY(&mu_); 630 }; 631 632 // A fake MetadataInterface implementation, for use in PickArgs. 633 class FakeMetadata : public LoadBalancingPolicy::MetadataInterface { 634 public: FakeMetadata(std::map<std::string,std::string> metadata)635 explicit FakeMetadata(std::map<std::string, std::string> metadata) 636 : metadata_(std::move(metadata)) {} 637 638 private: Lookup(absl::string_view key,std::string *)639 absl::optional<absl::string_view> Lookup( 640 absl::string_view key, std::string* /*buffer*/) const override { 641 auto it = metadata_.find(std::string(key)); 642 if (it == metadata_.end()) return absl::nullopt; 643 return it->second; 644 } 645 646 std::map<std::string, std::string> metadata_; 647 }; 648 649 // A fake CallState implementation, for use in PickArgs. 650 class FakeCallState : public ClientChannelLbCallState { 651 public: FakeCallState(const CallAttributes & attributes)652 explicit FakeCallState(const CallAttributes& attributes) { 653 for (const auto& attribute : attributes) { 654 attributes_.emplace(attribute->type(), attribute); 655 } 656 } 657 ~FakeCallState()658 ~FakeCallState() override { 659 for (void* allocation : allocations_) { 660 gpr_free(allocation); 661 } 662 } 663 664 private: Alloc(size_t size)665 void* Alloc(size_t size) override { 666 void* allocation = gpr_malloc(size); 667 allocations_.push_back(allocation); 668 return allocation; 669 } 670 GetCallAttribute(UniqueTypeName type)671 ServiceConfigCallData::CallAttributeInterface* GetCallAttribute( 672 UniqueTypeName type) const override { 673 auto it = attributes_.find(type); 674 if (it != attributes_.end()) { 675 return it->second; 676 } 677 return nullptr; 678 } 679 GetCallAttemptTracer()680 ClientCallTracer::CallAttemptTracer* GetCallAttemptTracer() const override { 681 return nullptr; 682 } 683 684 std::vector<void*> allocations_; 685 std::map<UniqueTypeName, ServiceConfigCallData::CallAttributeInterface*> 686 attributes_; 687 }; 688 689 // A fake BackendMetricAccessor implementation, for passing to 690 // SubchannelCallTrackerInterface::Finish(). 691 class FakeBackendMetricAccessor 692 : public LoadBalancingPolicy::BackendMetricAccessor { 693 public: FakeBackendMetricAccessor(absl::optional<BackendMetricData> backend_metric_data)694 explicit FakeBackendMetricAccessor( 695 absl::optional<BackendMetricData> backend_metric_data) 696 : backend_metric_data_(std::move(backend_metric_data)) {} 697 GetBackendMetricData()698 const BackendMetricData* GetBackendMetricData() override { 699 if (backend_metric_data_.has_value()) return &*backend_metric_data_; 700 return nullptr; 701 } 702 703 private: 704 const absl::optional<BackendMetricData> backend_metric_data_; 705 }; 706 707 explicit LoadBalancingPolicyTest(absl::string_view lb_policy_name, 708 ChannelArgs channel_args = ChannelArgs()) lb_policy_name_(lb_policy_name)709 : lb_policy_name_(lb_policy_name), 710 channel_args_(std::move(channel_args)) {} 711 SetUp()712 void SetUp() override { 713 // Order is important here: Fuzzing EE needs to be created before 714 // grpc_init(). 715 fuzzing_ee_ = std::make_shared<FuzzingEventEngine>( 716 FuzzingEventEngine::Options(), fuzzing_event_engine::Actions()); 717 grpc_timer_manager_set_start_threaded(false); 718 grpc_init(); 719 work_serializer_ = std::make_shared<WorkSerializer>(fuzzing_ee_); 720 auto helper = std::make_unique<FakeHelper>(this); 721 helper_ = helper.get(); 722 LoadBalancingPolicy::Args args = {work_serializer_, std::move(helper), 723 channel_args_}; 724 lb_policy_ = 725 CoreConfiguration::Get().lb_policy_registry().CreateLoadBalancingPolicy( 726 lb_policy_name_, std::move(args)); 727 CHECK(lb_policy_ != nullptr); 728 } 729 TearDown()730 void TearDown() override { 731 ExecCtx exec_ctx; 732 fuzzing_ee_->FuzzingDone(); 733 // Make sure pickers (and transitively, subchannels) are unreffed before 734 // destroying the fixture. 735 WaitForWorkSerializerToFlush(); 736 work_serializer_.reset(); 737 exec_ctx.Flush(); 738 if (lb_policy_ != nullptr) { 739 // Note: Can't safely trigger this from inside the FakeHelper dtor, 740 // because if there is a picker in the queue that is holding a ref 741 // to the LB policy, that will prevent the LB policy from being 742 // destroyed, and therefore the FakeHelper will not be destroyed. 743 // (This will cause an ASAN failure, but it will not display the 744 // queued events, so the failure will be harder to diagnose.) 745 helper_->ExpectQueueEmpty(); 746 lb_policy_.reset(); 747 } 748 fuzzing_ee_->TickUntilIdle(); 749 grpc_event_engine::experimental::WaitForSingleOwner(std::move(fuzzing_ee_)); 750 grpc_shutdown_blocking(); 751 } 752 lb_policy()753 LoadBalancingPolicy* lb_policy() const { 754 CHECK(lb_policy_ != nullptr); 755 return lb_policy_.get(); 756 } 757 758 // Creates an LB policy config from json. 759 static RefCountedPtr<LoadBalancingPolicy::Config> MakeConfig( 760 const Json& json, SourceLocation location = SourceLocation()) { 761 auto status_or_config = 762 CoreConfiguration::Get().lb_policy_registry().ParseLoadBalancingConfig( 763 json); 764 EXPECT_TRUE(status_or_config.ok()) 765 << status_or_config.status() << "\n" 766 << location.file() << ":" << location.line(); 767 return status_or_config.value(); 768 } 769 770 // Converts an address URI into a grpc_resolved_address. MakeAddress(absl::string_view address_uri)771 static grpc_resolved_address MakeAddress(absl::string_view address_uri) { 772 auto uri = URI::Parse(address_uri); 773 CHECK(uri.ok()); 774 grpc_resolved_address address; 775 CHECK(grpc_parse_uri(*uri, &address)); 776 return address; 777 } 778 MakeAddressList(absl::Span<const absl::string_view> addresses)779 std::vector<grpc_resolved_address> MakeAddressList( 780 absl::Span<const absl::string_view> addresses) { 781 std::vector<grpc_resolved_address> addrs; 782 for (const absl::string_view& address : addresses) { 783 addrs.emplace_back(MakeAddress(address)); 784 } 785 return addrs; 786 } 787 788 EndpointAddresses MakeEndpointAddresses( 789 absl::Span<const absl::string_view> addresses, 790 const ChannelArgs& args = ChannelArgs()) { 791 return EndpointAddresses(MakeAddressList(addresses), args); 792 } 793 794 // Constructs an update containing a list of endpoints. 795 LoadBalancingPolicy::UpdateArgs BuildUpdate( 796 absl::Span<const EndpointAddresses> endpoints, 797 RefCountedPtr<LoadBalancingPolicy::Config> config, 798 ChannelArgs args = ChannelArgs()) { 799 LoadBalancingPolicy::UpdateArgs update; 800 update.addresses = std::make_shared<EndpointAddressesListIterator>( 801 EndpointAddressesList(endpoints.begin(), endpoints.end())); 802 update.config = std::move(config); 803 update.args = std::move(args); 804 return update; 805 } 806 MakeEndpointAddressesListFromAddressList(absl::Span<const absl::string_view> addresses)807 std::vector<EndpointAddresses> MakeEndpointAddressesListFromAddressList( 808 absl::Span<const absl::string_view> addresses) { 809 std::vector<EndpointAddresses> endpoints; 810 for (const absl::string_view address : addresses) { 811 endpoints.emplace_back(MakeAddress(address), ChannelArgs()); 812 } 813 return endpoints; 814 } 815 816 // Convenient overload that takes a flat address list. 817 LoadBalancingPolicy::UpdateArgs BuildUpdate( 818 absl::Span<const absl::string_view> addresses, 819 RefCountedPtr<LoadBalancingPolicy::Config> config, 820 ChannelArgs args = ChannelArgs()) { 821 return BuildUpdate(MakeEndpointAddressesListFromAddressList(addresses), 822 std::move(config), std::move(args)); 823 } 824 825 // Applies the update on the LB policy. ApplyUpdate(LoadBalancingPolicy::UpdateArgs update_args,LoadBalancingPolicy * lb_policy)826 absl::Status ApplyUpdate(LoadBalancingPolicy::UpdateArgs update_args, 827 LoadBalancingPolicy* lb_policy) { 828 ExecCtx exec_ctx; 829 absl::Status status; 830 // When the LB policy gets the update, it will create new 831 // subchannels, and it will register connectivity state watchers and 832 // optionally health watchers for each one. We don't want to return 833 // until all the initial notifications for all of those watchers 834 // have been delivered to the LB policy. 835 absl::Notification notification; 836 work_serializer_->Run( 837 [&]() { 838 status = lb_policy->UpdateLocked(std::move(update_args)); 839 // UpdateLocked() enqueued the initial connectivity state 840 // notifications for the subchannels, so we add another 841 // callback to the queue to be executed after those initial 842 // state notifications have been delivered. 843 LOG(INFO) << "Applied update, waiting for initial connectivity state " 844 "notifications"; 845 work_serializer_->Run( 846 [&]() { 847 LOG(INFO) << "Initial connectivity state notifications " 848 "delivered; waiting for health notifications"; 849 // Now that the initial state notifications have been 850 // delivered, the queue will contain the health watch 851 // notifications for any subchannels in state READY, 852 // so we add another callback to the queue to be 853 // executed after those health watch notifications have 854 // been delivered. 855 work_serializer_->Run([&]() { notification.Notify(); }, 856 DEBUG_LOCATION); 857 }, 858 DEBUG_LOCATION); 859 }, 860 DEBUG_LOCATION); 861 while (!notification.HasBeenNotified()) { 862 fuzzing_ee_->Tick(); 863 } 864 LOG(INFO) << "health notifications delivered"; 865 return status; 866 } 867 868 // Invoke ExitIdle on the LB policy ExitIdle()869 void ExitIdle() { 870 ExecCtx exec_ctx; 871 absl::Notification notification; 872 // Note: ExitIdle() will enqueue a bunch of connectivity state 873 // notifications on the WorkSerializer, and we want to wait until 874 // those are delivered to the LB policy. 875 work_serializer_->Run( 876 [&]() { 877 lb_policy_->ExitIdleLocked(); 878 work_serializer_->Run([&]() { notification.Notify(); }, 879 DEBUG_LOCATION); 880 }, 881 DEBUG_LOCATION); 882 while (!notification.HasBeenNotified()) { 883 fuzzing_ee_->Tick(); 884 } 885 } 886 887 void ExpectQueueEmpty(SourceLocation location = SourceLocation()) { 888 helper_->ExpectQueueEmpty(location); 889 } 890 891 // Keeps reading state updates until continue_predicate() returns false. 892 // Returns false if the helper reports no events or if the event is 893 // not a state update; otherwise (if continue_predicate() tells us to 894 // stop) returns true. 895 bool WaitForStateUpdate( 896 std::function<bool(FakeHelper::StateUpdate update)> continue_predicate, 897 SourceLocation location = SourceLocation()) { 898 LOG(INFO) << "==> WaitForStateUpdate()"; 899 while (true) { 900 auto update = helper_->GetNextStateUpdate(location); 901 if (!update.has_value()) { 902 LOG(INFO) << "WaitForStateUpdate() returning false"; 903 return false; 904 } 905 if (!continue_predicate(std::move(*update))) { 906 LOG(INFO) << "WaitForStateUpdate() returning true"; 907 return true; 908 } 909 } 910 } 911 912 void ExpectReresolutionRequest(SourceLocation location = SourceLocation()) { 913 ASSERT_TRUE(helper_->GetNextReresolution(location)) 914 << location.file() << ":" << location.line(); 915 } 916 917 // Expects that the LB policy has reported the specified connectivity 918 // state to helper_. Returns the picker from the state update. 919 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> ExpectState( 920 grpc_connectivity_state expected_state, 921 absl::Status expected_status = absl::OkStatus(), 922 SourceLocation location = SourceLocation()) { 923 auto update = helper_->GetNextStateUpdate(location); 924 if (!update.has_value()) return nullptr; 925 EXPECT_EQ(update->state, expected_state) 926 << "got " << ConnectivityStateName(update->state) << ", expected " 927 << ConnectivityStateName(expected_state) << "\n" 928 << "at " << location.file() << ":" << location.line(); 929 EXPECT_EQ(update->status, expected_status) 930 << update->status << "\n" 931 << location.file() << ":" << location.line(); 932 EXPECT_NE(update->picker, nullptr) 933 << location.file() << ":" << location.line(); 934 return std::move(update->picker); 935 } 936 937 // Waits for the LB policy to get connected, then returns the final 938 // picker. There can be any number of CONNECTING updates, each of 939 // which must return a picker that queues picks, followed by one 940 // update for state READY, whose picker is returned. 941 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> WaitForConnected( 942 SourceLocation location = SourceLocation()) { 943 LOG(INFO) << "==> WaitForConnected()"; 944 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> final_picker; 945 WaitForStateUpdate( 946 [&](FakeHelper::StateUpdate update) { 947 if (update.state == GRPC_CHANNEL_CONNECTING) { 948 EXPECT_TRUE(update.status.ok()) 949 << update.status << " at " << location.file() << ":" 950 << location.line(); 951 ExpectPickQueued(update.picker.get(), {}, {}, location); 952 return true; // Keep going. 953 } 954 EXPECT_EQ(update.state, GRPC_CHANNEL_READY) 955 << ConnectivityStateName(update.state) << " at " 956 << location.file() << ":" << location.line(); 957 final_picker = std::move(update.picker); 958 return false; // Stop. 959 }, 960 location); 961 return final_picker; 962 } 963 964 void ExpectTransientFailureUpdate( 965 absl::Status expected_status, 966 SourceLocation location = SourceLocation()) { 967 auto picker = 968 ExpectState(GRPC_CHANNEL_TRANSIENT_FAILURE, expected_status, location); 969 ASSERT_NE(picker, nullptr); 970 ExpectPickFail( 971 picker.get(), 972 [&](const absl::Status& status) { 973 EXPECT_EQ(status, expected_status) 974 << location.file() << ":" << location.line(); 975 }, 976 location); 977 } 978 979 // Waits for the LB policy to fail a connection attempt. There can be 980 // any number of CONNECTING updates, each of which must return a picker 981 // that queues picks, followed by one update for state TRANSIENT_FAILURE, 982 // whose status is passed to check_status() and whose picker must fail 983 // picks with a status that is passed to check_status(). 984 // Returns true if the reported states match expectations. 985 bool WaitForConnectionFailed( 986 std::function<void(const absl::Status&)> check_status, 987 SourceLocation location = SourceLocation()) { 988 bool retval = false; 989 WaitForStateUpdate( 990 [&](FakeHelper::StateUpdate update) { 991 if (update.state == GRPC_CHANNEL_CONNECTING) { 992 EXPECT_TRUE(update.status.ok()) 993 << update.status << " at " << location.file() << ":" 994 << location.line(); 995 ExpectPickQueued(update.picker.get(), {}, {}, location); 996 return true; // Keep going. 997 } 998 EXPECT_EQ(update.state, GRPC_CHANNEL_TRANSIENT_FAILURE) 999 << ConnectivityStateName(update.state) << " at " 1000 << location.file() << ":" << location.line(); 1001 check_status(update.status); 1002 ExpectPickFail(update.picker.get(), check_status, location); 1003 retval = update.state == GRPC_CHANNEL_TRANSIENT_FAILURE; 1004 return false; // Stop. 1005 }, 1006 location); 1007 return retval; 1008 } 1009 1010 // Waits for the round_robin policy to start using an updated address list. 1011 // There can be any number of READY updates where the picker is still using 1012 // the old list followed by one READY update where the picker is using the 1013 // new list. Returns a picker if the reported states match expectations. 1014 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> 1015 WaitForRoundRobinListChange(absl::Span<const absl::string_view> old_addresses, 1016 absl::Span<const absl::string_view> new_addresses, 1017 const CallAttributes& call_attributes = {}, 1018 size_t num_iterations = 3, 1019 SourceLocation location = SourceLocation()) { 1020 LOG(INFO) << "Waiting for expected RR addresses..."; 1021 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> retval; 1022 size_t num_picks = 1023 std::max(new_addresses.size(), old_addresses.size()) * num_iterations; 1024 WaitForStateUpdate( 1025 [&](FakeHelper::StateUpdate update) { 1026 EXPECT_EQ(update.state, GRPC_CHANNEL_READY) 1027 << location.file() << ":" << location.line(); 1028 if (update.state != GRPC_CHANNEL_READY) return false; 1029 // Get enough picks to round-robin num_iterations times across all 1030 // expected addresses. 1031 auto picks = GetCompletePicks(update.picker.get(), num_picks, 1032 call_attributes, nullptr, location); 1033 EXPECT_TRUE(picks.has_value()) 1034 << location.file() << ":" << location.line(); 1035 if (!picks.has_value()) return false; 1036 LOG(INFO) << "PICKS: " << absl::StrJoin(*picks, " "); 1037 // If the picks still match the old list, then keep going. 1038 if (PicksAreRoundRobin(old_addresses, *picks)) return true; 1039 // Otherwise, the picks should match the new list. 1040 bool matches = PicksAreRoundRobin(new_addresses, *picks); 1041 EXPECT_TRUE(matches) 1042 << "Expected: " << absl::StrJoin(new_addresses, ", ") 1043 << "\nActual: " << absl::StrJoin(*picks, ", ") << "\nat " 1044 << location.file() << ":" << location.line(); 1045 if (matches) { 1046 retval = std::move(update.picker); 1047 } 1048 return false; // Stop. 1049 }, 1050 location); 1051 LOG(INFO) << "done waiting for expected RR addresses"; 1052 return retval; 1053 } 1054 1055 // Expects a state update for the specified state and status, and then 1056 // expects the resulting picker to queue picks. 1057 bool ExpectStateAndQueuingPicker( 1058 grpc_connectivity_state expected_state, 1059 absl::Status expected_status = absl::OkStatus(), 1060 SourceLocation location = SourceLocation()) { 1061 auto picker = ExpectState(expected_state, expected_status, location); 1062 return ExpectPickQueued(picker.get(), {}, {}, location); 1063 } 1064 1065 // Convenient frontend to ExpectStateAndQueuingPicker() for CONNECTING. 1066 bool ExpectConnectingUpdate(SourceLocation location = SourceLocation()) { 1067 return ExpectStateAndQueuingPicker(GRPC_CHANNEL_CONNECTING, 1068 absl::OkStatus(), location); 1069 } 1070 1071 static std::unique_ptr<LoadBalancingPolicy::MetadataInterface> MakeMetadata( 1072 std::map<std::string, std::string> init = {}) { 1073 return std::make_unique<FakeMetadata>(init); 1074 } 1075 1076 // Does a pick and returns the result. 1077 LoadBalancingPolicy::PickResult DoPick( 1078 LoadBalancingPolicy::SubchannelPicker* picker, 1079 const CallAttributes& call_attributes = {}, 1080 const std::map<std::string, std::string>& metadata = {}) { 1081 ExecCtx exec_ctx; 1082 FakeMetadata md(metadata); 1083 FakeCallState call_state(call_attributes); 1084 return picker->Pick({"/service/method", &md, &call_state}); 1085 } 1086 1087 // Requests a pick on picker and expects a Queue result. 1088 bool ExpectPickQueued(LoadBalancingPolicy::SubchannelPicker* picker, 1089 const CallAttributes call_attributes = {}, 1090 const std::map<std::string, std::string>& metadata = {}, 1091 SourceLocation location = SourceLocation()) { 1092 EXPECT_NE(picker, nullptr) << location.file() << ":" << location.line(); 1093 if (picker == nullptr) return false; 1094 auto pick_result = DoPick(picker, call_attributes, metadata); 1095 EXPECT_TRUE(absl::holds_alternative<LoadBalancingPolicy::PickResult::Queue>( 1096 pick_result.result)) 1097 << PickResultString(pick_result) << "\nat " << location.file() << ":" 1098 << location.line(); 1099 return absl::holds_alternative<LoadBalancingPolicy::PickResult::Queue>( 1100 pick_result.result); 1101 } 1102 1103 // Requests a pick on picker and expects a Complete result. 1104 // The address of the resulting subchannel is returned, or nullopt if 1105 // the result was something other than Complete. 1106 // If the complete pick includes a SubchannelCallTrackerInterface, then if 1107 // subchannel_call_tracker is non-null, it will be set to point to the 1108 // call tracker; otherwise, the call tracker will be invoked 1109 // automatically to represent a complete call with no backend metric data. 1110 absl::optional<std::string> ExpectPickComplete( 1111 LoadBalancingPolicy::SubchannelPicker* picker, 1112 const CallAttributes& call_attributes = {}, 1113 const std::map<std::string, std::string>& metadata = {}, 1114 std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface>* 1115 subchannel_call_tracker = nullptr, 1116 SubchannelState::FakeSubchannel** picked_subchannel = nullptr, 1117 SourceLocation location = SourceLocation()) { 1118 EXPECT_NE(picker, nullptr); 1119 if (picker == nullptr) { 1120 return absl::nullopt; 1121 } 1122 auto pick_result = DoPick(picker, call_attributes, metadata); 1123 auto* complete = absl::get_if<LoadBalancingPolicy::PickResult::Complete>( 1124 &pick_result.result); 1125 EXPECT_NE(complete, nullptr) << PickResultString(pick_result) << " at " 1126 << location.file() << ":" << location.line(); 1127 if (complete == nullptr) return absl::nullopt; 1128 auto* subchannel = static_cast<SubchannelState::FakeSubchannel*>( 1129 complete->subchannel.get()); 1130 if (picked_subchannel != nullptr) *picked_subchannel = subchannel; 1131 std::string address = subchannel->state()->address(); 1132 if (complete->subchannel_call_tracker != nullptr) { 1133 if (subchannel_call_tracker != nullptr) { 1134 *subchannel_call_tracker = std::move(complete->subchannel_call_tracker); 1135 } else { 1136 ReportCompletionToCallTracker( 1137 std::move(complete->subchannel_call_tracker), address); 1138 } 1139 } 1140 return address; 1141 } 1142 1143 void ReportCompletionToCallTracker( 1144 std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface> 1145 subchannel_call_tracker, 1146 absl::string_view address, absl::Status status = absl::OkStatus()) { 1147 subchannel_call_tracker->Start(); 1148 FakeMetadata metadata({}); 1149 FakeBackendMetricAccessor backend_metric_accessor({}); 1150 LoadBalancingPolicy::SubchannelCallTrackerInterface::FinishArgs args = { 1151 address, status, &metadata, &backend_metric_accessor}; 1152 subchannel_call_tracker->Finish(args); 1153 } 1154 1155 // Gets num_picks complete picks from picker and returns the resulting 1156 // list of addresses, or nullopt if a non-complete pick was returned. 1157 absl::optional<std::vector<std::string>> GetCompletePicks( 1158 LoadBalancingPolicy::SubchannelPicker* picker, size_t num_picks, 1159 const CallAttributes& call_attributes = {}, 1160 std::vector< 1161 std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface>>* 1162 subchannel_call_trackers = nullptr, 1163 SourceLocation location = SourceLocation()) { 1164 EXPECT_NE(picker, nullptr); 1165 if (picker == nullptr) { 1166 return absl::nullopt; 1167 } 1168 std::vector<std::string> results; 1169 for (size_t i = 0; i < num_picks; ++i) { 1170 std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface> 1171 subchannel_call_tracker; 1172 auto address = ExpectPickComplete(picker, call_attributes, 1173 /*metadata=*/{}, 1174 subchannel_call_trackers == nullptr 1175 ? nullptr 1176 : &subchannel_call_tracker, 1177 nullptr, location); 1178 if (!address.has_value()) return absl::nullopt; 1179 results.emplace_back(std::move(*address)); 1180 if (subchannel_call_trackers != nullptr) { 1181 subchannel_call_trackers->emplace_back( 1182 std::move(subchannel_call_tracker)); 1183 } 1184 } 1185 return results; 1186 } 1187 1188 // Returns true if the list of actual pick result addresses matches the 1189 // list of expected addresses for round_robin. Note that the actual 1190 // addresses may start anywhere in the list of expected addresses but 1191 // must then continue in round-robin fashion, with wrap-around. PicksAreRoundRobin(absl::Span<const absl::string_view> expected,absl::Span<const std::string> actual)1192 bool PicksAreRoundRobin(absl::Span<const absl::string_view> expected, 1193 absl::Span<const std::string> actual) { 1194 absl::optional<size_t> expected_index; 1195 for (const auto& address : actual) { 1196 auto it = std::find(expected.begin(), expected.end(), address); 1197 if (it == expected.end()) return false; 1198 size_t index = it - expected.begin(); 1199 if (expected_index.has_value() && index != *expected_index) return false; 1200 expected_index = (index + 1) % expected.size(); 1201 } 1202 return true; 1203 } 1204 1205 // Checks that the picker has round-robin behavior over the specified 1206 // set of addresses. 1207 void ExpectRoundRobinPicks(LoadBalancingPolicy::SubchannelPicker* picker, 1208 absl::Span<const absl::string_view> addresses, 1209 const CallAttributes& call_attributes = {}, 1210 size_t num_iterations = 3, 1211 SourceLocation location = SourceLocation()) { 1212 auto picks = GetCompletePicks(picker, num_iterations * addresses.size(), 1213 call_attributes, nullptr, location); 1214 ASSERT_TRUE(picks.has_value()) << location.file() << ":" << location.line(); 1215 EXPECT_TRUE(PicksAreRoundRobin(addresses, *picks)) 1216 << " Actual: " << absl::StrJoin(*picks, ", ") 1217 << "\n Expected: " << absl::StrJoin(addresses, ", ") << "\n" 1218 << location.file() << ":" << location.line(); 1219 } 1220 1221 // Expect startup with RR with a set of addresses. 1222 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> ExpectRoundRobinStartup( 1223 absl::Span<const EndpointAddresses> endpoints, 1224 SourceLocation location = SourceLocation()) { 1225 CHECK(!endpoints.empty()); 1226 // There should be a subchannel for every address. 1227 // We will wind up connecting to the first address for every endpoint. 1228 std::vector<std::vector<SubchannelState*>> endpoint_subchannels; 1229 endpoint_subchannels.reserve(endpoints.size()); 1230 std::vector<std::string> chosen_addresses_storage; 1231 chosen_addresses_storage.reserve(endpoints.size()); 1232 std::vector<absl::string_view> chosen_addresses; 1233 chosen_addresses.reserve(endpoints.size()); 1234 for (const EndpointAddresses& endpoint : endpoints) { 1235 endpoint_subchannels.emplace_back(); 1236 endpoint_subchannels.back().reserve(endpoint.addresses().size()); 1237 for (size_t i = 0; i < endpoint.addresses().size(); ++i) { 1238 const grpc_resolved_address& address = endpoint.addresses()[i]; 1239 std::string address_str = grpc_sockaddr_to_uri(&address).value(); 1240 auto* subchannel = FindSubchannel(address_str); 1241 EXPECT_NE(subchannel, nullptr) 1242 << address_str << "\n" 1243 << location.file() << ":" << location.line(); 1244 if (subchannel == nullptr) return nullptr; 1245 endpoint_subchannels.back().push_back(subchannel); 1246 if (i == 0) { 1247 chosen_addresses_storage.emplace_back(std::move(address_str)); 1248 chosen_addresses.emplace_back(chosen_addresses_storage.back()); 1249 } 1250 } 1251 } 1252 // We should request a connection to the first address of each endpoint, 1253 // and not to any of the subsequent addresses. 1254 for (const auto& subchannels : endpoint_subchannels) { 1255 EXPECT_TRUE(subchannels[0]->ConnectionRequested()) 1256 << location.file() << ":" << location.line(); 1257 for (size_t i = 1; i < subchannels.size(); ++i) { 1258 EXPECT_FALSE(subchannels[i]->ConnectionRequested()) 1259 << "i=" << i << "\n" 1260 << location.file() << ":" << location.line(); 1261 } 1262 } 1263 // The subchannels that we've asked to connect should report 1264 // CONNECTING state. 1265 for (size_t i = 0; i < endpoint_subchannels.size(); ++i) { 1266 endpoint_subchannels[i][0]->SetConnectivityState(GRPC_CHANNEL_CONNECTING); 1267 if (i == 0) ExpectConnectingUpdate(location); 1268 } 1269 // The connection attempts should succeed. 1270 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker; 1271 for (size_t i = 0; i < endpoint_subchannels.size(); ++i) { 1272 endpoint_subchannels[i][0]->SetConnectivityState(GRPC_CHANNEL_READY); 1273 if (i == 0) { 1274 // When the first subchannel becomes READY, accept any number of 1275 // CONNECTING updates with a picker that queues followed by a READY 1276 // update with a picker that repeatedly returns only the first address. 1277 picker = WaitForConnected(location); 1278 ExpectRoundRobinPicks(picker.get(), {chosen_addresses[0]}, {}, 3, 1279 location); 1280 } else { 1281 // When each subsequent subchannel becomes READY, we accept any number 1282 // of READY updates where the picker returns only the previously 1283 // connected subchannel(s) followed by a READY update where the picker 1284 // returns the previously connected subchannel(s) *and* the newly 1285 // connected subchannel. 1286 picker = WaitForRoundRobinListChange( 1287 absl::MakeSpan(chosen_addresses).subspan(0, i), 1288 absl::MakeSpan(chosen_addresses).subspan(0, i + 1), {}, 3, 1289 location); 1290 } 1291 } 1292 return picker; 1293 } 1294 1295 // A convenient override that takes a flat list of addresses, one per 1296 // endpoint. 1297 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> ExpectRoundRobinStartup( 1298 absl::Span<const absl::string_view> addresses, 1299 SourceLocation location = SourceLocation()) { 1300 return ExpectRoundRobinStartup( 1301 MakeEndpointAddressesListFromAddressList(addresses), location); 1302 } 1303 1304 // Expects zero or more picker updates, each of which returns 1305 // round-robin picks for the specified set of addresses. 1306 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> 1307 DrainRoundRobinPickerUpdates(absl::Span<const absl::string_view> addresses, 1308 SourceLocation location = SourceLocation()) { 1309 LOG(INFO) << "Draining RR picker updates..."; 1310 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker; 1311 while (!helper_->QueueEmpty()) { 1312 auto update = helper_->GetNextStateUpdate(location); 1313 EXPECT_TRUE(update.has_value()) 1314 << location.file() << ":" << location.line(); 1315 if (!update.has_value()) return nullptr; 1316 EXPECT_EQ(update->state, GRPC_CHANNEL_READY) 1317 << location.file() << ":" << location.line(); 1318 if (update->state != GRPC_CHANNEL_READY) return nullptr; 1319 ExpectRoundRobinPicks(update->picker.get(), addresses, 1320 /*call_attributes=*/{}, /*num_iterations=*/3, 1321 location); 1322 picker = std::move(update->picker); 1323 } 1324 LOG(INFO) << "Done draining RR picker updates"; 1325 return picker; 1326 } 1327 1328 // Expects zero or more CONNECTING updates. 1329 void DrainConnectingUpdates(SourceLocation location = SourceLocation()) { 1330 LOG(INFO) << "Draining CONNECTING updates..."; 1331 while (!helper_->QueueEmpty()) { 1332 ASSERT_TRUE(ExpectConnectingUpdate(location)); 1333 } 1334 LOG(INFO) << "Done draining CONNECTING updates"; 1335 } 1336 1337 // Triggers a connection failure for the current address for an 1338 // endpoint and expects a reconnection to the specified new address. 1339 void ExpectEndpointAddressChange( 1340 absl::Span<const absl::string_view> addresses, size_t current_index, 1341 size_t new_index, absl::AnyInvocable<void()> expect_after_disconnect, 1342 SourceLocation location = SourceLocation()) { 1343 LOG(INFO) << "Expecting endpoint address change: addresses={" 1344 << absl::StrJoin(addresses, ", ") 1345 << "}, current_index=" << current_index 1346 << ", new_index=" << new_index; 1347 ASSERT_LT(current_index, addresses.size()); 1348 ASSERT_LT(new_index, addresses.size()); 1349 // Find all subchannels. 1350 std::vector<SubchannelState*> subchannels; 1351 subchannels.reserve(addresses.size()); 1352 for (absl::string_view address : addresses) { 1353 SubchannelState* subchannel = FindSubchannel(address); 1354 ASSERT_NE(subchannel, nullptr) 1355 << "can't find subchannel for " << address << "\n" 1356 << location.file() << ":" << location.line(); 1357 subchannels.push_back(subchannel); 1358 } 1359 // Cause current_address to become disconnected. 1360 subchannels[current_index]->SetConnectivityState(GRPC_CHANNEL_IDLE); 1361 ExpectReresolutionRequest(location); 1362 if (expect_after_disconnect != nullptr) expect_after_disconnect(); 1363 // Attempt each address in the list until we hit the desired new address. 1364 for (size_t i = 0; i < subchannels.size(); ++i) { 1365 // A connection should be requested on the subchannel for this 1366 // index, and none of the others. 1367 for (size_t j = 0; j < addresses.size(); ++j) { 1368 EXPECT_EQ(subchannels[j]->ConnectionRequested(), j == i) 1369 << location.file() << ":" << location.line(); 1370 } 1371 // Subchannel will report CONNECTING. 1372 SubchannelState* subchannel = subchannels[i]; 1373 subchannel->SetConnectivityState(GRPC_CHANNEL_CONNECTING); 1374 // If this is the one we want to stick with, it will report READY. 1375 if (i == new_index) { 1376 subchannel->SetConnectivityState(GRPC_CHANNEL_READY); 1377 break; 1378 } 1379 // Otherwise, report TF. 1380 subchannel->SetConnectivityState( 1381 GRPC_CHANNEL_TRANSIENT_FAILURE, 1382 absl::UnavailableError("connection failed")); 1383 // Report IDLE to leave it in the expected state in case the test 1384 // interacts with it again. 1385 subchannel->SetConnectivityState(GRPC_CHANNEL_IDLE); 1386 } 1387 LOG(INFO) << "Done with endpoint address change"; 1388 } 1389 1390 // Requests a picker on picker and expects a Fail result. 1391 // The failing status is passed to check_status. 1392 void ExpectPickFail(LoadBalancingPolicy::SubchannelPicker* picker, 1393 std::function<void(const absl::Status&)> check_status, 1394 SourceLocation location = SourceLocation()) { 1395 auto pick_result = DoPick(picker); 1396 auto* fail = absl::get_if<LoadBalancingPolicy::PickResult::Fail>( 1397 &pick_result.result); 1398 ASSERT_NE(fail, nullptr) << PickResultString(pick_result) << " at " 1399 << location.file() << ":" << location.line(); 1400 check_status(fail->status); 1401 } 1402 1403 // Returns a human-readable string for a pick result. PickResultString(const LoadBalancingPolicy::PickResult & result)1404 static std::string PickResultString( 1405 const LoadBalancingPolicy::PickResult& result) { 1406 return Match( 1407 result.result, 1408 [](const LoadBalancingPolicy::PickResult::Complete& complete) { 1409 auto* subchannel = static_cast<SubchannelState::FakeSubchannel*>( 1410 complete.subchannel.get()); 1411 return absl::StrFormat( 1412 "COMPLETE{subchannel=%s, subchannel_call_tracker=%p}", 1413 subchannel->state()->address(), 1414 complete.subchannel_call_tracker.get()); 1415 }, 1416 [](const LoadBalancingPolicy::PickResult::Queue&) -> std::string { 1417 return "QUEUE{}"; 1418 }, 1419 [](const LoadBalancingPolicy::PickResult::Fail& fail) -> std::string { 1420 return absl::StrFormat("FAIL{%s}", fail.status.ToString()); 1421 }, 1422 [](const LoadBalancingPolicy::PickResult::Drop& drop) -> std::string { 1423 return absl::StrFormat("FAIL{%s}", drop.status.ToString()); 1424 }); 1425 } 1426 1427 // Returns the entry in the subchannel pool, or null if not present. 1428 SubchannelState* FindSubchannel(absl::string_view address, 1429 const ChannelArgs& args = ChannelArgs()) { 1430 SubchannelKey key(MakeAddress(address), args); 1431 auto it = subchannel_pool_.find(key); 1432 if (it == subchannel_pool_.end()) return nullptr; 1433 return &it->second; 1434 } 1435 1436 // Creates and returns an entry in the subchannel pool. 1437 // This can be used in cases where we want to test that a subchannel 1438 // already exists when the LB policy creates it (e.g., due to it being 1439 // created by another channel and shared via the global subchannel 1440 // pool, or by being created by another LB policy in this channel). 1441 SubchannelState* CreateSubchannel(absl::string_view address, 1442 const ChannelArgs& args = ChannelArgs()) { 1443 SubchannelKey key(MakeAddress(address), args); 1444 auto it = subchannel_pool_ 1445 .emplace(std::piecewise_construct, std::forward_as_tuple(key), 1446 std::forward_as_tuple(address, this)) 1447 .first; 1448 return &it->second; 1449 } 1450 WaitForWorkSerializerToFlush()1451 void WaitForWorkSerializerToFlush() { 1452 ExecCtx exec_ctx; 1453 LOG(INFO) << "waiting for WorkSerializer to flush..."; 1454 absl::Notification notification; 1455 work_serializer_->Run([&]() { notification.Notify(); }, DEBUG_LOCATION); 1456 while (!notification.HasBeenNotified()) { 1457 fuzzing_ee_->Tick(); 1458 } 1459 LOG(INFO) << "WorkSerializer flush complete"; 1460 } 1461 1462 void IncrementTimeBy(Duration duration, bool flush_work_serializer = true) { 1463 ExecCtx exec_ctx; 1464 LOG(INFO) << "Incrementing time by " << duration; 1465 fuzzing_ee_->TickForDuration(duration); 1466 LOG(INFO) << "Done incrementing time"; 1467 // Flush WorkSerializer, in case the timer callback enqueued anything. 1468 if (flush_work_serializer) WaitForWorkSerializerToFlush(); 1469 } 1470 1471 void SetExpectedTimerDuration(absl::optional<EventEngine::Duration> duration, 1472 SourceLocation location = SourceLocation()) { 1473 if (duration.has_value()) { 1474 fuzzing_ee_->SetRunAfterDurationCallback( 1475 [expected = *duration, 1476 location = location](EventEngine::Duration duration) { 1477 EXPECT_EQ(duration, expected) 1478 << "Expected: " << expected.count() 1479 << "ns\n Actual: " << duration.count() << "ns\n" 1480 << location.file() << ":" << location.line(); 1481 }); 1482 } else { 1483 fuzzing_ee_->SetRunAfterDurationCallback(nullptr); 1484 } 1485 } 1486 1487 std::shared_ptr<FuzzingEventEngine> fuzzing_ee_; 1488 std::shared_ptr<WorkSerializer> work_serializer_; 1489 FakeHelper* helper_ = nullptr; 1490 std::map<SubchannelKey, SubchannelState> subchannel_pool_; 1491 OrphanablePtr<LoadBalancingPolicy> lb_policy_; 1492 const absl::string_view lb_policy_name_; 1493 const ChannelArgs channel_args_; 1494 GlobalStatsPluginRegistry::StatsPluginGroup stats_plugin_group_; 1495 std::string target_ = "dns:server.example.com"; 1496 std::string authority_ = "server.example.com"; 1497 }; 1498 1499 } // namespace testing 1500 } // namespace grpc_core 1501 1502 #endif // GRPC_TEST_CORE_LOAD_BALANCING_LB_POLICY_TEST_LIB_H 1503