• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <memory>
18 #include <string>
19 #include <vector>
20 #include <utility>
21 #include "fl/worker/fl_worker.h"
22 #include "utils/ms_exception.h"
23 
24 namespace mindspore {
25 namespace fl {
26 namespace worker {
Run()27 void FLWorker::Run() {
28   if (running_.load()) {
29     return;
30   }
31   running_ = true;
32   worker_num_ = ps::PSContext::instance()->worker_num();
33   server_num_ = ps::PSContext::instance()->server_num();
34   scheduler_ip_ = ps::PSContext::instance()->scheduler_ip();
35   scheduler_port_ = ps::PSContext::instance()->scheduler_port();
36   worker_step_num_per_iteration_ = ps::PSContext::instance()->worker_step_num_per_iteration();
37   ps::PSContext::instance()->cluster_config().scheduler_host = scheduler_ip_;
38   ps::PSContext::instance()->cluster_config().scheduler_port = scheduler_port_;
39   ps::PSContext::instance()->cluster_config().initial_worker_num = worker_num_;
40   ps::PSContext::instance()->cluster_config().initial_server_num = server_num_;
41   MS_LOG(INFO) << "Initialize cluster config for worker. Worker number:" << worker_num_
42                << ", Server number:" << server_num_ << ", Scheduler ip:" << scheduler_ip_
43                << ", Scheduler port:" << scheduler_port_
44                << ", Worker training step per iteration:" << worker_step_num_per_iteration_;
45 
46   worker_node_ = std::make_shared<ps::core::WorkerNode>();
47   MS_EXCEPTION_IF_NULL(worker_node_);
48 
49   worker_node_->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
50     Finalize();
51     running_ = false;
52     try {
53       MS_LOG(EXCEPTION)
54         << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
55     } catch (std::exception &e) {
56       MsException::Instance().SetException();
57     }
58   });
59   worker_node_->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [this]() {
60     Finalize();
61     running_ = false;
62     try {
63       MS_LOG(EXCEPTION)
64         << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
65            "network building phase.";
66     } catch (std::exception &e) {
67       MsException::Instance().SetException();
68     }
69   });
70 
71   InitializeFollowerScaler();
72   if (!worker_node_->Start()) {
73     MS_LOG(EXCEPTION) << "Starting worker node failed.";
74     return;
75   }
76   rank_id_ = worker_node_->rank_id();
77 
78   std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));
79   return;
80 }
81 
Finalize()82 void FLWorker::Finalize() {
83   if (worker_node_ == nullptr) {
84     MS_LOG(INFO) << "The worker is not initialized yet.";
85     return;
86   }
87 
88   // In some cases, worker calls the Finish function while other nodes don't. So timeout is acceptable.
89   if (!worker_node_->Finish()) {
90     MS_LOG(WARNING) << "Finishing worker node timeout.";
91   }
92   if (!worker_node_->Stop()) {
93     MS_LOG(ERROR) << "Stopping worker node failed.";
94     return;
95   }
96 }
97 
SendToServer(uint32_t server_rank,const void * data,size_t size,ps::core::TcpUserCommand command,std::shared_ptr<std::vector<unsigned char>> * output)98 bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
99                             std::shared_ptr<std::vector<unsigned char>> *output) {
100   MS_EXCEPTION_IF_NULL(data);
101   // If the worker is in safemode, do not communicate with server.
102   while (safemode_.load()) {
103     std::this_thread::yield();
104   }
105 
106   std::shared_ptr<unsigned char[]> message;
107   std::unique_ptr<unsigned char[]> message_addr = std::make_unique<unsigned char[]>(size);
108   MS_EXCEPTION_IF_NULL(message_addr);
109   message = std::move(message_addr);
110   MS_EXCEPTION_IF_NULL(message);
111 
112   uint64_t src_size = size;
113   uint64_t dst_size = size;
114   int ret = memcpy_s(message.get(), dst_size, data, src_size);
115   if (ret != 0) {
116     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
117     return false;
118   }
119 
120   if (output != nullptr) {
121     while (true) {
122       if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output,
123                               kWorkerTimeout)) {
124         MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
125         return false;
126       }
127       if (*output == nullptr) {
128         MS_LOG(WARNING) << "Response from server " << server_rank << " is empty.";
129         return false;
130       }
131 
132       std::string response_str = std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size());
133       if (response_str == ps::kClusterSafeMode || response_str == ps::kJobNotAvailable) {
134         MS_LOG(INFO) << "The server " << server_rank << " is in safemode or finished.";
135         std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerRetryDurationForSafeMode));
136       } else {
137         break;
138       }
139     }
140   } else {
141     if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command),
142                             kWorkerTimeout)) {
143       MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
144       return false;
145     }
146   }
147   return true;
148 }
149 
server_num() const150 uint32_t FLWorker::server_num() const { return server_num_; }
151 
worker_num() const152 uint32_t FLWorker::worker_num() const { return worker_num_; }
153 
rank_id() const154 uint32_t FLWorker::rank_id() const { return rank_id_; }
155 
worker_step_num_per_iteration() const156 uint64_t FLWorker::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; }
157 
running() const158 bool FLWorker::running() const { return running_.load(); }
159 
SetIterationRunning()160 void FLWorker::SetIterationRunning() {
161   MS_LOG(INFO) << "Worker iteration starts.";
162   worker_iteration_state_ = IterationState::kRunning;
163 }
164 
SetIterationCompleted()165 void FLWorker::SetIterationCompleted() {
166   MS_LOG(INFO) << "Worker iteration completes.";
167   worker_iteration_state_ = IterationState::kCompleted;
168 }
169 
set_fl_iteration_num(uint64_t iteration_num)170 void FLWorker::set_fl_iteration_num(uint64_t iteration_num) { iteration_num_ = iteration_num; }
171 
fl_iteration_num() const172 uint64_t FLWorker::fl_iteration_num() const { return iteration_num_.load(); }
173 
set_data_size(int data_size)174 void FLWorker::set_data_size(int data_size) { data_size_ = data_size; }
175 
data_size() const176 int FLWorker::data_size() const { return data_size_; }
177 
fl_name() const178 std::string FLWorker::fl_name() const { return ps::kServerModeFL; }
179 
fl_id() const180 std::string FLWorker::fl_id() const { return std::to_string(rank_id_); }
181 
InitializeFollowerScaler()182 void FLWorker::InitializeFollowerScaler() {
183   MS_EXCEPTION_IF_NULL(worker_node_);
184   if (!worker_node_->InitFollowerScaler()) {
185     MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";
186     return;
187   }
188 
189   // Set scaling barriers before scaling.
190   worker_node_->RegisterFollowerScalerBarrierBeforeScaleOut("WorkerPipeline",
191                                                             std::bind(&FLWorker::ProcessBeforeScalingOut, this));
192   worker_node_->RegisterFollowerScalerBarrierBeforeScaleIn("WorkerPipeline",
193                                                            std::bind(&FLWorker::ProcessBeforeScalingIn, this));
194 
195   // Set handlers after scheduler scaling operations are done.
196   worker_node_->RegisterFollowerScalerHandlerAfterScaleOut("WorkerPipeline",
197                                                            std::bind(&FLWorker::ProcessAfterScalingOut, this));
198   worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline",
199                                                           std::bind(&FLWorker::ProcessAfterScalingIn, this));
200   worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
201                                             std::bind(&FLWorker::HandleIterationRunningEvent, this));
202   worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
203                                             std::bind(&FLWorker::HandleIterationCompletedEvent, this));
204 }
205 
HandleIterationRunningEvent()206 void FLWorker::HandleIterationRunningEvent() {
207   MS_LOG(INFO) << "Server iteration starts, safemode is " << safemode_.load();
208   server_iteration_state_ = IterationState::kRunning;
209   if (safemode_.load() == true) {
210     safemode_ = false;
211   }
212 }
213 
HandleIterationCompletedEvent()214 void FLWorker::HandleIterationCompletedEvent() {
215   MS_LOG(INFO) << "Server iteration completes";
216   server_iteration_state_ = IterationState::kCompleted;
217 }
218 
ProcessBeforeScalingOut()219 void FLWorker::ProcessBeforeScalingOut() {
220   MS_LOG(INFO) << "Starting Worker scaling out barrier.";
221   while (server_iteration_state_.load() != IterationState::kCompleted ||
222          worker_iteration_state_.load() != IterationState::kCompleted) {
223     std::this_thread::yield();
224   }
225   MS_LOG(INFO) << "Ending Worker scaling out barrier. Switch to safemode.";
226   safemode_ = true;
227 }
228 
ProcessBeforeScalingIn()229 void FLWorker::ProcessBeforeScalingIn() {
230   MS_LOG(INFO) << "Starting Worker scaling in barrier.";
231   while (server_iteration_state_.load() != IterationState::kCompleted ||
232          worker_iteration_state_.load() != IterationState::kCompleted) {
233     std::this_thread::yield();
234   }
235   MS_LOG(INFO) << "Ending Worker scaling in barrier. Switch to safemode.";
236   safemode_ = true;
237 }
238 
ProcessAfterScalingOut()239 void FLWorker::ProcessAfterScalingOut() {
240   MS_ERROR_IF_NULL_WO_RET_VAL(worker_node_);
241   MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker.";
242   server_num_ = IntToUint(worker_node_->server_num());
243   worker_num_ = IntToUint(worker_node_->worker_num());
244   MS_LOG(INFO) << "After scheduler scaling out, worker number is " << worker_num_ << ", server number is "
245                << server_num_ << ". Exit safemode.";
246   std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));
247   safemode_ = false;
248 }
249 
ProcessAfterScalingIn()250 void FLWorker::ProcessAfterScalingIn() {
251   MS_ERROR_IF_NULL_WO_RET_VAL(worker_node_);
252   MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker.";
253   server_num_ = IntToUint(worker_node_->server_num());
254   worker_num_ = IntToUint(worker_node_->worker_num());
255   MS_LOG(INFO) << "After scheduler scaling in, worker number is " << worker_num_ << ", server number is " << server_num_
256                << ". Exit safemode.";
257   std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));
258   safemode_ = false;
259 }
260 }  // namespace worker
261 }  // namespace fl
262 }  // namespace mindspore
263