1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fl/server/round.h"
18 #include <memory>
19 #include <string>
20 #include "fl/server/server.h"
21 #include "fl/server/iteration.h"
22
23 namespace mindspore {
24 namespace fl {
25 namespace server {
26 class Server;
27 class Iteration;
Round(const std::string & name,bool check_timeout,size_t time_window,bool check_count,size_t threshold_count,bool server_num_as_threshold)28 Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count,
29 bool server_num_as_threshold)
30 : name_(name),
31 check_timeout_(check_timeout),
32 time_window_(time_window),
33 check_count_(check_count),
34 threshold_count_(threshold_count),
35 server_num_as_threshold_(server_num_as_threshold) {}
36
Initialize(const std::shared_ptr<ps::core::CommunicatorBase> & communicator,const TimeOutCb & timeout_cb,const FinishIterCb & finish_iteration_cb)37 void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
38 const FinishIterCb &finish_iteration_cb) {
39 MS_EXCEPTION_IF_NULL(communicator);
40 communicator_ = communicator;
41
42 // Register callback for round kernel.
43 communicator_->RegisterMsgCallBack(name_, [&](std::shared_ptr<ps::core::MessageHandler> message) {
44 MS_ERROR_IF_NULL_WO_RET_VAL(message);
45 LaunchRoundKernel(message);
46 });
47
48 // Callback when the iteration is finished.
49 finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
50 std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
51 finish_iteration_cb(is_iteration_valid, reason);
52 };
53
54 // Callback for finalizing the server. This can only be called once.
55 finalize_cb_ = [&](void) -> void {
56 MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
57 (void)communicator_->Stop();
58 };
59
60 if (check_timeout_) {
61 iter_timer_ = std::make_shared<IterationTimer>();
62 MS_EXCEPTION_IF_NULL(iter_timer_);
63
64 // 1.Set the timeout callback for the timer.
65 iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void {
66 std::string reason = "Round " + name_ + " timeout! This iteration is invalid. Proceed to next iteration.";
67 timeout_cb(is_iteration_valid, reason);
68 });
69
70 // 2.Stopping timer callback which will be set to the round kernel.
71 stop_timer_cb_ = [&](void) -> void {
72 MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
73 MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer.";
74 iter_timer_->Stop();
75 };
76 }
77
78 // Set counter event callbacks for this round if the round kernel is stateful.
79 if (check_count_) {
80 auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
81 auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
82 DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
83 {first_count_handler, last_count_handler});
84 }
85 }
86
ReInitForScaling(uint32_t server_num)87 bool Round::ReInitForScaling(uint32_t server_num) {
88 // If this round requires up-to-date server number as threshold count, update threshold_count_.
89 if (server_num_as_threshold_) {
90 MS_LOG(INFO) << "Round " << name_ << " uses up-to-date server number " << server_num << " as its threshold count.";
91 threshold_count_ = server_num;
92 }
93 if (check_count_) {
94 auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
95 auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
96 DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
97 {first_count_handler, last_count_handler});
98 }
99
100 MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
101 kernel_->InitKernel(threshold_count_);
102 return true;
103 }
104
ReInitForUpdatingHyperParams(size_t updated_threshold_count,size_t updated_time_window)105 bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window) {
106 time_window_ = updated_time_window;
107 threshold_count_ = updated_threshold_count;
108 if (check_count_) {
109 if (!DistributedCountService::GetInstance().ReInitCounter(name_, threshold_count_)) {
110 MS_LOG(ERROR) << "Reinitializing count for " << name_ << " failed.";
111 return false;
112 }
113 }
114
115 MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
116 kernel_->InitKernel(threshold_count_);
117 return true;
118 }
119
BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> & kernel)120 void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel) {
121 MS_EXCEPTION_IF_NULL(kernel);
122 kernel_ = kernel;
123 kernel_->set_stop_timer_cb(stop_timer_cb_);
124 kernel_->set_finish_iteration_cb(finish_iteration_cb_);
125 return;
126 }
127
LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> & message)128 void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
129 MS_ERROR_IF_NULL_WO_RET_VAL(message);
130 MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
131 MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
132
133 std::string reason = "";
134 if (!IsServerAvailable(&reason)) {
135 if (!communicator_->SendResponse(reason.c_str(), reason.size(), message)) {
136 MS_LOG(ERROR) << "Sending response failed.";
137 return;
138 }
139 return;
140 }
141
142 ++Iteration::GetInstance().running_round_num_;
143 AddressPtr input = std::make_shared<Address>();
144 AddressPtr output = std::make_shared<Address>();
145 MS_ERROR_IF_NULL_WO_RET_VAL(input);
146 MS_ERROR_IF_NULL_WO_RET_VAL(output);
147 input->addr = message->data();
148 input->size = message->len();
149 bool ret = kernel_->Launch({input}, {}, {output});
150 if (output->size == 0) {
151 reason = "The output of the round " + name_ + " is empty.";
152 MS_LOG(WARNING) << reason;
153 if (!communicator_->SendResponse(reason.c_str(), reason.size(), message)) {
154 MS_LOG(ERROR) << "Sending response failed.";
155 return;
156 }
157 return;
158 }
159 if (!communicator_->SendResponse(output->addr, output->size, message)) {
160 MS_LOG(ERROR) << "Sending response failed.";
161 return;
162 }
163 kernel_->Release(output);
164
165 // Must send response back no matter what value Launch method returns.
166 if (!ret) {
167 reason = "Launching round kernel of round " + name_ + " failed.";
168 Iteration::GetInstance().NotifyNext(false, reason);
169 }
170 --Iteration::GetInstance().running_round_num_;
171 return;
172 }
173
Reset()174 void Round::Reset() {
175 MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
176 (void)kernel_->Reset();
177 }
178
name() const179 const std::string &Round::name() const { return name_; }
180
threshold_count() const181 size_t Round::threshold_count() const { return threshold_count_; }
182
check_timeout() const183 bool Round::check_timeout() const { return check_timeout_; }
184
time_window() const185 size_t Round::time_window() const { return time_window_; }
186
OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> & message)187 void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
188 MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
189 MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
190 // The timer starts only after the first count event is triggered by DistributedCountService.
191 if (check_timeout_) {
192 MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
193 iter_timer_->Start(std::chrono::milliseconds(time_window_));
194 }
195
196 // Some kernels override the OnFirstCountEvent method.
197 kernel_->OnFirstCountEvent(message);
198 return;
199 }
200
OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> & message)201 void Round::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
202 MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
203 MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
204 // Same as the first count event, the timer must be stopped by DistributedCountService.
205 if (check_timeout_) {
206 MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
207 iter_timer_->Stop();
208 }
209
210 // Some kernels override the OnLastCountEvent method.
211 kernel_->OnLastCountEvent(message);
212 return;
213 }
214
IsServerAvailable(std::string * reason)215 bool Round::IsServerAvailable(std::string *reason) {
216 MS_ERROR_IF_NULL_W_RET_VAL(reason, false);
217 // After one instance is completed, the model should be accessed by clients.
218 if (Iteration::GetInstance().instance_state() == InstanceState::kFinish && name_ == "getModel") {
219 return true;
220 }
221
222 // If the server state is Disable or Finish, refuse the request.
223 if (Iteration::GetInstance().instance_state() == InstanceState::kDisable ||
224 Iteration::GetInstance().instance_state() == InstanceState::kFinish) {
225 MS_LOG(WARNING) << "The server's training job is disabled or finished, please retry " + name_ + " later.";
226 *reason = ps::kJobNotAvailable;
227 return false;
228 }
229
230 // If the server is still in the process of scaling, reject the request.
231 if (Server::GetInstance().IsSafeMode()) {
232 MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later.";
233 *reason = ps::kClusterSafeMode;
234 return false;
235 }
236 return true;
237 }
238 } // namespace server
239 } // namespace fl
240 } // namespace mindspore
241