• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/follower_scaler.h"
18 #include "ps/core/communicator/tcp_communicator.h"
19 
20 namespace mindspore {
21 namespace ps {
22 namespace core {
FollowerScaler(AbstractNode * node)23 FollowerScaler::FollowerScaler(AbstractNode *node)
24     : node_(node), scaling_state_(NodeScaleState::kNormal), running_(true) {
25   process_before_scale_out_thread_ = std::thread([&]() {
26     while (running_.load()) {
27       std::unique_lock<std::mutex> lock(scale_out_mtx_);
28       scale_out_cv_.wait(
29         lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kPreparing; });
30       if (!running_.load()) {
31         break;
32       }
33       ProcessBeforeScaleOut();
34     }
35   });
36 
37   process_before_scale_in_thread_ = std::thread([&]() {
38     while (running_.load()) {
39       std::unique_lock<std::mutex> lock(scale_in_mtx_);
40       scale_in_cv_.wait(
41         lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kPreparing; });
42       // In scaling in scenario, abstract node will trigger CLUSTER_SCALE_IN_DONE event in the same thread if this node
43       // is the one to be scaled in, so we need to release the lock here to avoid dead lock.
44       lock.unlock();
45       if (!running_.load()) {
46         break;
47       }
48       ProcessBeforeScaleIn();
49     }
50   });
51 
52   process_after_scale_out_thread_ = std::thread([&]() {
53     while (running_.load()) {
54       std::unique_lock<std::mutex> lock(scale_out_mtx_);
55       scale_out_cv_.wait(
56         lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kScaling; });
57       if (!running_.load()) {
58         break;
59       }
60       ProcessAfterScaleOut();
61     }
62   });
63 
64   process_after_scale_in_thread_ = std::thread([&]() {
65     while (running_.load()) {
66       std::unique_lock<std::mutex> lock(scale_in_mtx_);
67       scale_in_cv_.wait(
68         lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kScaling; });
69       if (!running_.load()) {
70         break;
71       }
72       ProcessAfterScaleIn();
73     }
74   });
75 }
76 
~FollowerScaler()77 FollowerScaler::~FollowerScaler() {
78   running_ = false;
79   scale_out_cv_.notify_all();
80   scale_in_cv_.notify_all();
81   if (process_before_scale_out_thread_.joinable()) {
82     process_before_scale_out_thread_.join();
83   }
84   if (process_before_scale_in_thread_.joinable()) {
85     process_before_scale_in_thread_.join();
86   }
87   if (process_after_scale_out_thread_.joinable()) {
88     process_after_scale_out_thread_.join();
89   }
90   if (process_after_scale_in_thread_.joinable()) {
91     process_after_scale_in_thread_.join();
92   }
93 }
94 
RegisterScaleEventCallbacks()95 void FollowerScaler::RegisterScaleEventCallbacks() {
96   ready_for_scale_out_event_callback_ = [&]() -> void {
97     // Notify the thread which will call the barriers.
98     std::unique_lock<std::mutex> lock(scale_out_mtx_);
99     scaling_state_ = NodeScaleState::kPreparing;
100     scale_out_cv_.notify_all();
101   };
102 
103   ready_for_scale_in_event_callback_ = [&]() -> void {
104     std::unique_lock<std::mutex> lock(scale_in_mtx_);
105     scaling_state_ = NodeScaleState::kPreparing;
106     scale_in_cv_.notify_all();
107   };
108 
109   scale_out_done_event_callback_ = [&]() -> void {
110     std::unique_lock<std::mutex> lock(scale_out_mtx_);
111     scaling_state_ = NodeScaleState::kScaling;
112     scale_out_cv_.notify_all();
113   };
114 
115   scale_in_done_event_callback_ = [&]() -> void {
116     std::unique_lock<std::mutex> lock(scale_in_mtx_);
117     scaling_state_ = NodeScaleState::kScaling;
118     scale_in_cv_.notify_all();
119   };
120 
121   MS_EXCEPTION_IF_NULL(node_);
122   node_->RegisterEventCallback(core::ClusterEvent::READY_FOR_SCALE_OUT, ready_for_scale_out_event_callback_);
123   node_->RegisterEventCallback(core::ClusterEvent::READY_FOR_SCALE_IN, ready_for_scale_in_event_callback_);
124   node_->RegisterEventCallback(core::ClusterEvent::CLUSTER_SCALE_OUT_DONE, scale_out_done_event_callback_);
125   node_->RegisterEventCallback(core::ClusterEvent::CLUSTER_SCALE_IN_DONE, scale_in_done_event_callback_);
126 }
127 
ProcessBeforeScaleOut()128 void FollowerScaler::ProcessBeforeScaleOut() {
129   for (auto &barrier : barriers_before_scale_out_) {
130     MS_LOG(INFO) << "Calling barrier before scaling out for " << barrier.first;
131     barrier.second();
132   }
133   scaling_state_ = NodeScaleState::kWaiting;
134   // Notify scheduler that this node is ready for elastic scaling out.
135   node_->set_ready_for_scale_out();
136 }
137 
ProcessBeforeScaleIn()138 void FollowerScaler::ProcessBeforeScaleIn() {
139   for (auto &barrier : barriers_before_scale_in_) {
140     MS_LOG(INFO) << "Calling barrier before scaling in for " << barrier.first;
141     barrier.second();
142   }
143   scaling_state_ = NodeScaleState::kWaiting;
144   // Notify scheduler that this node is ready for elastic scaling in.
145   node_->set_ready_for_scale_in();
146 }
147 
ProcessAfterScaleOut()148 void FollowerScaler::ProcessAfterScaleOut() {
149   MS_LOG(INFO) << "Scaling out operation in scheduler is done. Do scaling out for this node.";
150   for (auto &handler : handlers_after_scale_out_) {
151     MS_LOG(INFO) << "Calling scaling out handler for " << handler.first;
152     handler.second();
153   }
154   scaling_state_ = NodeScaleState::kNormal;
155   // Notify scheduler that scaling out of this node is done.
156   node_->set_scale_out_done();
157 }
158 
ProcessAfterScaleIn()159 void FollowerScaler::ProcessAfterScaleIn() {
160   MS_LOG(INFO) << "Scaling in operation in scheduler is done. Do scaling in for this node.";
161   for (auto &handler : handlers_after_scale_in_) {
162     MS_LOG(INFO) << "Calling scaling in handler for " << handler.first;
163     handler.second();
164   }
165   scaling_state_ = NodeScaleState::kNormal;
166   // Notify scheduler that scaling out of this node is done.
167   node_->set_scale_in_done();
168 }
169 
RegisterBarrierBeforeScaleOut(const std::string & module,const BarrierBeforeScaleOut & barrier)170 void FollowerScaler::RegisterBarrierBeforeScaleOut(const std::string &module, const BarrierBeforeScaleOut &barrier) {
171   (void)barriers_before_scale_out_.try_emplace(module, barrier);
172 }
173 
RegisterBarrierBeforeScaleIn(const std::string & module,const BarrierBeforeScaleIn & barrier)174 void FollowerScaler::RegisterBarrierBeforeScaleIn(const std::string &module, const BarrierBeforeScaleIn &barrier) {
175   (void)barriers_before_scale_in_.try_emplace(module, barrier);
176 }
177 
RegisterHandlerAfterScaleOut(const std::string & module,const HandlerAfterScaleOut & handler)178 void FollowerScaler::RegisterHandlerAfterScaleOut(const std::string &module, const HandlerAfterScaleOut &handler) {
179   (void)handlers_after_scale_out_.try_emplace(module, handler);
180 }
181 
RegisterHandlerAfterScaleIn(const std::string & module,const HandlerAfterScaleIn & handler)182 void FollowerScaler::RegisterHandlerAfterScaleIn(const std::string &module, const HandlerAfterScaleIn &handler) {
183   (void)handlers_after_scale_in_.try_emplace(module, handler);
184 }
185 }  // namespace core
186 }  // namespace ps
187 }  // namespace mindspore
188