1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/follower_scaler.h"
18 #include "ps/core/communicator/tcp_communicator.h"
19
20 namespace mindspore {
21 namespace ps {
22 namespace core {
FollowerScaler(AbstractNode * node)23 FollowerScaler::FollowerScaler(AbstractNode *node)
24 : node_(node), scaling_state_(NodeScaleState::kNormal), running_(true) {
25 process_before_scale_out_thread_ = std::thread([&]() {
26 while (running_.load()) {
27 std::unique_lock<std::mutex> lock(scale_out_mtx_);
28 scale_out_cv_.wait(
29 lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kPreparing; });
30 if (!running_.load()) {
31 break;
32 }
33 ProcessBeforeScaleOut();
34 }
35 });
36
37 process_before_scale_in_thread_ = std::thread([&]() {
38 while (running_.load()) {
39 std::unique_lock<std::mutex> lock(scale_in_mtx_);
40 scale_in_cv_.wait(
41 lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kPreparing; });
42 // In scaling in scenario, abstract node will trigger CLUSTER_SCALE_IN_DONE event in the same thread if this node
43 // is the one to be scaled in, so we need to release the lock here to avoid dead lock.
44 lock.unlock();
45 if (!running_.load()) {
46 break;
47 }
48 ProcessBeforeScaleIn();
49 }
50 });
51
52 process_after_scale_out_thread_ = std::thread([&]() {
53 while (running_.load()) {
54 std::unique_lock<std::mutex> lock(scale_out_mtx_);
55 scale_out_cv_.wait(
56 lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kScaling; });
57 if (!running_.load()) {
58 break;
59 }
60 ProcessAfterScaleOut();
61 }
62 });
63
64 process_after_scale_in_thread_ = std::thread([&]() {
65 while (running_.load()) {
66 std::unique_lock<std::mutex> lock(scale_in_mtx_);
67 scale_in_cv_.wait(
68 lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kScaling; });
69 if (!running_.load()) {
70 break;
71 }
72 ProcessAfterScaleIn();
73 }
74 });
75 }
76
~FollowerScaler()77 FollowerScaler::~FollowerScaler() {
78 running_ = false;
79 scale_out_cv_.notify_all();
80 scale_in_cv_.notify_all();
81 if (process_before_scale_out_thread_.joinable()) {
82 process_before_scale_out_thread_.join();
83 }
84 if (process_before_scale_in_thread_.joinable()) {
85 process_before_scale_in_thread_.join();
86 }
87 if (process_after_scale_out_thread_.joinable()) {
88 process_after_scale_out_thread_.join();
89 }
90 if (process_after_scale_in_thread_.joinable()) {
91 process_after_scale_in_thread_.join();
92 }
93 }
94
RegisterScaleEventCallbacks()95 void FollowerScaler::RegisterScaleEventCallbacks() {
96 ready_for_scale_out_event_callback_ = [&]() -> void {
97 // Notify the thread which will call the barriers.
98 std::unique_lock<std::mutex> lock(scale_out_mtx_);
99 scaling_state_ = NodeScaleState::kPreparing;
100 scale_out_cv_.notify_all();
101 };
102
103 ready_for_scale_in_event_callback_ = [&]() -> void {
104 std::unique_lock<std::mutex> lock(scale_in_mtx_);
105 scaling_state_ = NodeScaleState::kPreparing;
106 scale_in_cv_.notify_all();
107 };
108
109 scale_out_done_event_callback_ = [&]() -> void {
110 std::unique_lock<std::mutex> lock(scale_out_mtx_);
111 scaling_state_ = NodeScaleState::kScaling;
112 scale_out_cv_.notify_all();
113 };
114
115 scale_in_done_event_callback_ = [&]() -> void {
116 std::unique_lock<std::mutex> lock(scale_in_mtx_);
117 scaling_state_ = NodeScaleState::kScaling;
118 scale_in_cv_.notify_all();
119 };
120
121 MS_EXCEPTION_IF_NULL(node_);
122 node_->RegisterEventCallback(core::ClusterEvent::READY_FOR_SCALE_OUT, ready_for_scale_out_event_callback_);
123 node_->RegisterEventCallback(core::ClusterEvent::READY_FOR_SCALE_IN, ready_for_scale_in_event_callback_);
124 node_->RegisterEventCallback(core::ClusterEvent::CLUSTER_SCALE_OUT_DONE, scale_out_done_event_callback_);
125 node_->RegisterEventCallback(core::ClusterEvent::CLUSTER_SCALE_IN_DONE, scale_in_done_event_callback_);
126 }
127
ProcessBeforeScaleOut()128 void FollowerScaler::ProcessBeforeScaleOut() {
129 for (auto &barrier : barriers_before_scale_out_) {
130 MS_LOG(INFO) << "Calling barrier before scaling out for " << barrier.first;
131 barrier.second();
132 }
133 scaling_state_ = NodeScaleState::kWaiting;
134 // Notify scheduler that this node is ready for elastic scaling out.
135 node_->set_ready_for_scale_out();
136 }
137
ProcessBeforeScaleIn()138 void FollowerScaler::ProcessBeforeScaleIn() {
139 for (auto &barrier : barriers_before_scale_in_) {
140 MS_LOG(INFO) << "Calling barrier before scaling in for " << barrier.first;
141 barrier.second();
142 }
143 scaling_state_ = NodeScaleState::kWaiting;
144 // Notify scheduler that this node is ready for elastic scaling in.
145 node_->set_ready_for_scale_in();
146 }
147
ProcessAfterScaleOut()148 void FollowerScaler::ProcessAfterScaleOut() {
149 MS_LOG(INFO) << "Scaling out operation in scheduler is done. Do scaling out for this node.";
150 for (auto &handler : handlers_after_scale_out_) {
151 MS_LOG(INFO) << "Calling scaling out handler for " << handler.first;
152 handler.second();
153 }
154 scaling_state_ = NodeScaleState::kNormal;
155 // Notify scheduler that scaling out of this node is done.
156 node_->set_scale_out_done();
157 }
158
ProcessAfterScaleIn()159 void FollowerScaler::ProcessAfterScaleIn() {
160 MS_LOG(INFO) << "Scaling in operation in scheduler is done. Do scaling in for this node.";
161 for (auto &handler : handlers_after_scale_in_) {
162 MS_LOG(INFO) << "Calling scaling in handler for " << handler.first;
163 handler.second();
164 }
165 scaling_state_ = NodeScaleState::kNormal;
166 // Notify scheduler that scaling out of this node is done.
167 node_->set_scale_in_done();
168 }
169
RegisterBarrierBeforeScaleOut(const std::string & module,const BarrierBeforeScaleOut & barrier)170 void FollowerScaler::RegisterBarrierBeforeScaleOut(const std::string &module, const BarrierBeforeScaleOut &barrier) {
171 (void)barriers_before_scale_out_.try_emplace(module, barrier);
172 }
173
RegisterBarrierBeforeScaleIn(const std::string & module,const BarrierBeforeScaleIn & barrier)174 void FollowerScaler::RegisterBarrierBeforeScaleIn(const std::string &module, const BarrierBeforeScaleIn &barrier) {
175 (void)barriers_before_scale_in_.try_emplace(module, barrier);
176 }
177
RegisterHandlerAfterScaleOut(const std::string & module,const HandlerAfterScaleOut & handler)178 void FollowerScaler::RegisterHandlerAfterScaleOut(const std::string &module, const HandlerAfterScaleOut &handler) {
179 (void)handlers_after_scale_out_.try_emplace(module, handler);
180 }
181
RegisterHandlerAfterScaleIn(const std::string & module,const HandlerAfterScaleIn & handler)182 void FollowerScaler::RegisterHandlerAfterScaleIn(const std::string &module, const HandlerAfterScaleIn &handler) {
183 (void)handlers_after_scale_in_.try_emplace(module, handler);
184 }
185 } // namespace core
186 } // namespace ps
187 } // namespace mindspore
188