• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fl/server/round.h"
18 #include <memory>
19 #include <string>
20 #include "fl/server/server.h"
21 #include "fl/server/iteration.h"
22 
23 namespace mindspore {
24 namespace fl {
25 namespace server {
26 class Server;
27 class Iteration;
Round(const std::string & name,bool check_timeout,size_t time_window,bool check_count,size_t threshold_count,bool server_num_as_threshold)28 Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count,
29              bool server_num_as_threshold)
30     : name_(name),
31       check_timeout_(check_timeout),
32       time_window_(time_window),
33       check_count_(check_count),
34       threshold_count_(threshold_count),
35       server_num_as_threshold_(server_num_as_threshold) {}
36 
Initialize(const std::shared_ptr<ps::core::CommunicatorBase> & communicator,const TimeOutCb & timeout_cb,const FinishIterCb & finish_iteration_cb)37 void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
38                        const FinishIterCb &finish_iteration_cb) {
39   MS_EXCEPTION_IF_NULL(communicator);
40   communicator_ = communicator;
41 
42   // Register callback for round kernel.
43   communicator_->RegisterMsgCallBack(name_, [&](std::shared_ptr<ps::core::MessageHandler> message) {
44     MS_ERROR_IF_NULL_WO_RET_VAL(message);
45     LaunchRoundKernel(message);
46   });
47 
48   // Callback when the iteration is finished.
49   finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
50     std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
51     finish_iteration_cb(is_iteration_valid, reason);
52   };
53 
54   // Callback for finalizing the server. This can only be called once.
55   finalize_cb_ = [&](void) -> void {
56     MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
57     (void)communicator_->Stop();
58   };
59 
60   if (check_timeout_) {
61     iter_timer_ = std::make_shared<IterationTimer>();
62     MS_EXCEPTION_IF_NULL(iter_timer_);
63 
64     // 1.Set the timeout callback for the timer.
65     iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void {
66       std::string reason = "Round " + name_ + " timeout! This iteration is invalid. Proceed to next iteration.";
67       timeout_cb(is_iteration_valid, reason);
68     });
69 
70     // 2.Stopping timer callback which will be set to the round kernel.
71     stop_timer_cb_ = [&](void) -> void {
72       MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
73       MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer.";
74       iter_timer_->Stop();
75     };
76   }
77 
78   // Set counter event callbacks for this round if the round kernel is stateful.
79   if (check_count_) {
80     auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
81     auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
82     DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
83                                                            {first_count_handler, last_count_handler});
84   }
85 }
86 
ReInitForScaling(uint32_t server_num)87 bool Round::ReInitForScaling(uint32_t server_num) {
88   // If this round requires up-to-date server number as threshold count, update threshold_count_.
89   if (server_num_as_threshold_) {
90     MS_LOG(INFO) << "Round " << name_ << " uses up-to-date server number " << server_num << " as its threshold count.";
91     threshold_count_ = server_num;
92   }
93   if (check_count_) {
94     auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
95     auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
96     DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
97                                                            {first_count_handler, last_count_handler});
98   }
99 
100   MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
101   kernel_->InitKernel(threshold_count_);
102   return true;
103 }
104 
ReInitForUpdatingHyperParams(size_t updated_threshold_count,size_t updated_time_window)105 bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window) {
106   time_window_ = updated_time_window;
107   threshold_count_ = updated_threshold_count;
108   if (check_count_) {
109     if (!DistributedCountService::GetInstance().ReInitCounter(name_, threshold_count_)) {
110       MS_LOG(ERROR) << "Reinitializing count for " << name_ << " failed.";
111       return false;
112     }
113   }
114 
115   MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
116   kernel_->InitKernel(threshold_count_);
117   return true;
118 }
119 
BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> & kernel)120 void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel) {
121   MS_EXCEPTION_IF_NULL(kernel);
122   kernel_ = kernel;
123   kernel_->set_stop_timer_cb(stop_timer_cb_);
124   kernel_->set_finish_iteration_cb(finish_iteration_cb_);
125   return;
126 }
127 
LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> & message)128 void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
129   MS_ERROR_IF_NULL_WO_RET_VAL(message);
130   MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
131   MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
132 
133   std::string reason = "";
134   if (!IsServerAvailable(&reason)) {
135     if (!communicator_->SendResponse(reason.c_str(), reason.size(), message)) {
136       MS_LOG(ERROR) << "Sending response failed.";
137       return;
138     }
139     return;
140   }
141 
142   ++Iteration::GetInstance().running_round_num_;
143   AddressPtr input = std::make_shared<Address>();
144   AddressPtr output = std::make_shared<Address>();
145   MS_ERROR_IF_NULL_WO_RET_VAL(input);
146   MS_ERROR_IF_NULL_WO_RET_VAL(output);
147   input->addr = message->data();
148   input->size = message->len();
149   bool ret = kernel_->Launch({input}, {}, {output});
150   if (output->size == 0) {
151     reason = "The output of the round " + name_ + " is empty.";
152     MS_LOG(WARNING) << reason;
153     if (!communicator_->SendResponse(reason.c_str(), reason.size(), message)) {
154       MS_LOG(ERROR) << "Sending response failed.";
155       return;
156     }
157     return;
158   }
159   if (!communicator_->SendResponse(output->addr, output->size, message)) {
160     MS_LOG(ERROR) << "Sending response failed.";
161     return;
162   }
163   kernel_->Release(output);
164 
165   // Must send response back no matter what value Launch method returns.
166   if (!ret) {
167     reason = "Launching round kernel of round " + name_ + " failed.";
168     Iteration::GetInstance().NotifyNext(false, reason);
169   }
170   --Iteration::GetInstance().running_round_num_;
171   return;
172 }
173 
Reset()174 void Round::Reset() {
175   MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
176   (void)kernel_->Reset();
177 }
178 
name() const179 const std::string &Round::name() const { return name_; }
180 
threshold_count() const181 size_t Round::threshold_count() const { return threshold_count_; }
182 
check_timeout() const183 bool Round::check_timeout() const { return check_timeout_; }
184 
time_window() const185 size_t Round::time_window() const { return time_window_; }
186 
OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> & message)187 void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
188   MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
189   MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
190   // The timer starts only after the first count event is triggered by DistributedCountService.
191   if (check_timeout_) {
192     MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
193     iter_timer_->Start(std::chrono::milliseconds(time_window_));
194   }
195 
196   // Some kernels override the OnFirstCountEvent method.
197   kernel_->OnFirstCountEvent(message);
198   return;
199 }
200 
OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> & message)201 void Round::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
202   MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
203   MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
204   // Same as the first count event, the timer must be stopped by DistributedCountService.
205   if (check_timeout_) {
206     MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
207     iter_timer_->Stop();
208   }
209 
210   // Some kernels override the OnLastCountEvent method.
211   kernel_->OnLastCountEvent(message);
212   return;
213 }
214 
IsServerAvailable(std::string * reason)215 bool Round::IsServerAvailable(std::string *reason) {
216   MS_ERROR_IF_NULL_W_RET_VAL(reason, false);
217   // After one instance is completed, the model should be accessed by clients.
218   if (Iteration::GetInstance().instance_state() == InstanceState::kFinish && name_ == "getModel") {
219     return true;
220   }
221 
222   // If the server state is Disable or Finish, refuse the request.
223   if (Iteration::GetInstance().instance_state() == InstanceState::kDisable ||
224       Iteration::GetInstance().instance_state() == InstanceState::kFinish) {
225     MS_LOG(WARNING) << "The server's training job is disabled or finished, please retry " + name_ + " later.";
226     *reason = ps::kJobNotAvailable;
227     return false;
228   }
229 
230   // If the server is still in the process of scaling, reject the request.
231   if (Server::GetInstance().IsSafeMode()) {
232     MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later.";
233     *reason = ps::kClusterSafeMode;
234     return false;
235   }
236   return true;
237 }
238 }  // namespace server
239 }  // namespace fl
240 }  // namespace mindspore
241