1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <memory>
18 #include <string>
19 #include <vector>
20 #include <utility>
21 #include "fl/worker/fl_worker.h"
22 #include "utils/ms_exception.h"
23
24 namespace mindspore {
25 namespace fl {
26 namespace worker {
Run()27 void FLWorker::Run() {
28 if (running_.load()) {
29 return;
30 }
31 running_ = true;
32 worker_num_ = ps::PSContext::instance()->worker_num();
33 server_num_ = ps::PSContext::instance()->server_num();
34 scheduler_ip_ = ps::PSContext::instance()->scheduler_ip();
35 scheduler_port_ = ps::PSContext::instance()->scheduler_port();
36 worker_step_num_per_iteration_ = ps::PSContext::instance()->worker_step_num_per_iteration();
37 ps::PSContext::instance()->cluster_config().scheduler_host = scheduler_ip_;
38 ps::PSContext::instance()->cluster_config().scheduler_port = scheduler_port_;
39 ps::PSContext::instance()->cluster_config().initial_worker_num = worker_num_;
40 ps::PSContext::instance()->cluster_config().initial_server_num = server_num_;
41 MS_LOG(INFO) << "Initialize cluster config for worker. Worker number:" << worker_num_
42 << ", Server number:" << server_num_ << ", Scheduler ip:" << scheduler_ip_
43 << ", Scheduler port:" << scheduler_port_
44 << ", Worker training step per iteration:" << worker_step_num_per_iteration_;
45
46 worker_node_ = std::make_shared<ps::core::WorkerNode>();
47 MS_EXCEPTION_IF_NULL(worker_node_);
48
49 worker_node_->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
50 Finalize();
51 running_ = false;
52 try {
53 MS_LOG(EXCEPTION)
54 << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
55 } catch (std::exception &e) {
56 MsException::Instance().SetException();
57 }
58 });
59 worker_node_->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [this]() {
60 Finalize();
61 running_ = false;
62 try {
63 MS_LOG(EXCEPTION)
64 << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
65 "network building phase.";
66 } catch (std::exception &e) {
67 MsException::Instance().SetException();
68 }
69 });
70
71 InitializeFollowerScaler();
72 if (!worker_node_->Start()) {
73 MS_LOG(EXCEPTION) << "Starting worker node failed.";
74 return;
75 }
76 rank_id_ = worker_node_->rank_id();
77
78 std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));
79 return;
80 }
81
Finalize()82 void FLWorker::Finalize() {
83 if (worker_node_ == nullptr) {
84 MS_LOG(INFO) << "The worker is not initialized yet.";
85 return;
86 }
87
88 // In some cases, worker calls the Finish function while other nodes don't. So timeout is acceptable.
89 if (!worker_node_->Finish()) {
90 MS_LOG(WARNING) << "Finishing worker node timeout.";
91 }
92 if (!worker_node_->Stop()) {
93 MS_LOG(ERROR) << "Stopping worker node failed.";
94 return;
95 }
96 }
97
SendToServer(uint32_t server_rank,const void * data,size_t size,ps::core::TcpUserCommand command,std::shared_ptr<std::vector<unsigned char>> * output)98 bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
99 std::shared_ptr<std::vector<unsigned char>> *output) {
100 MS_EXCEPTION_IF_NULL(data);
101 // If the worker is in safemode, do not communicate with server.
102 while (safemode_.load()) {
103 std::this_thread::yield();
104 }
105
106 std::shared_ptr<unsigned char[]> message;
107 std::unique_ptr<unsigned char[]> message_addr = std::make_unique<unsigned char[]>(size);
108 MS_EXCEPTION_IF_NULL(message_addr);
109 message = std::move(message_addr);
110 MS_EXCEPTION_IF_NULL(message);
111
112 uint64_t src_size = size;
113 uint64_t dst_size = size;
114 int ret = memcpy_s(message.get(), dst_size, data, src_size);
115 if (ret != 0) {
116 MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
117 return false;
118 }
119
120 if (output != nullptr) {
121 while (true) {
122 if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output,
123 kWorkerTimeout)) {
124 MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
125 return false;
126 }
127 if (*output == nullptr) {
128 MS_LOG(WARNING) << "Response from server " << server_rank << " is empty.";
129 return false;
130 }
131
132 std::string response_str = std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size());
133 if (response_str == ps::kClusterSafeMode || response_str == ps::kJobNotAvailable) {
134 MS_LOG(INFO) << "The server " << server_rank << " is in safemode or finished.";
135 std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerRetryDurationForSafeMode));
136 } else {
137 break;
138 }
139 }
140 } else {
141 if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command),
142 kWorkerTimeout)) {
143 MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
144 return false;
145 }
146 }
147 return true;
148 }
149
server_num() const150 uint32_t FLWorker::server_num() const { return server_num_; }
151
worker_num() const152 uint32_t FLWorker::worker_num() const { return worker_num_; }
153
rank_id() const154 uint32_t FLWorker::rank_id() const { return rank_id_; }
155
worker_step_num_per_iteration() const156 uint64_t FLWorker::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; }
157
running() const158 bool FLWorker::running() const { return running_.load(); }
159
SetIterationRunning()160 void FLWorker::SetIterationRunning() {
161 MS_LOG(INFO) << "Worker iteration starts.";
162 worker_iteration_state_ = IterationState::kRunning;
163 }
164
SetIterationCompleted()165 void FLWorker::SetIterationCompleted() {
166 MS_LOG(INFO) << "Worker iteration completes.";
167 worker_iteration_state_ = IterationState::kCompleted;
168 }
169
set_fl_iteration_num(uint64_t iteration_num)170 void FLWorker::set_fl_iteration_num(uint64_t iteration_num) { iteration_num_ = iteration_num; }
171
fl_iteration_num() const172 uint64_t FLWorker::fl_iteration_num() const { return iteration_num_.load(); }
173
set_data_size(int data_size)174 void FLWorker::set_data_size(int data_size) { data_size_ = data_size; }
175
data_size() const176 int FLWorker::data_size() const { return data_size_; }
177
fl_name() const178 std::string FLWorker::fl_name() const { return ps::kServerModeFL; }
179
fl_id() const180 std::string FLWorker::fl_id() const { return std::to_string(rank_id_); }
181
InitializeFollowerScaler()182 void FLWorker::InitializeFollowerScaler() {
183 MS_EXCEPTION_IF_NULL(worker_node_);
184 if (!worker_node_->InitFollowerScaler()) {
185 MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";
186 return;
187 }
188
189 // Set scaling barriers before scaling.
190 worker_node_->RegisterFollowerScalerBarrierBeforeScaleOut("WorkerPipeline",
191 std::bind(&FLWorker::ProcessBeforeScalingOut, this));
192 worker_node_->RegisterFollowerScalerBarrierBeforeScaleIn("WorkerPipeline",
193 std::bind(&FLWorker::ProcessBeforeScalingIn, this));
194
195 // Set handlers after scheduler scaling operations are done.
196 worker_node_->RegisterFollowerScalerHandlerAfterScaleOut("WorkerPipeline",
197 std::bind(&FLWorker::ProcessAfterScalingOut, this));
198 worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline",
199 std::bind(&FLWorker::ProcessAfterScalingIn, this));
200 worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
201 std::bind(&FLWorker::HandleIterationRunningEvent, this));
202 worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
203 std::bind(&FLWorker::HandleIterationCompletedEvent, this));
204 }
205
HandleIterationRunningEvent()206 void FLWorker::HandleIterationRunningEvent() {
207 MS_LOG(INFO) << "Server iteration starts, safemode is " << safemode_.load();
208 server_iteration_state_ = IterationState::kRunning;
209 if (safemode_.load() == true) {
210 safemode_ = false;
211 }
212 }
213
HandleIterationCompletedEvent()214 void FLWorker::HandleIterationCompletedEvent() {
215 MS_LOG(INFO) << "Server iteration completes";
216 server_iteration_state_ = IterationState::kCompleted;
217 }
218
ProcessBeforeScalingOut()219 void FLWorker::ProcessBeforeScalingOut() {
220 MS_LOG(INFO) << "Starting Worker scaling out barrier.";
221 while (server_iteration_state_.load() != IterationState::kCompleted ||
222 worker_iteration_state_.load() != IterationState::kCompleted) {
223 std::this_thread::yield();
224 }
225 MS_LOG(INFO) << "Ending Worker scaling out barrier. Switch to safemode.";
226 safemode_ = true;
227 }
228
ProcessBeforeScalingIn()229 void FLWorker::ProcessBeforeScalingIn() {
230 MS_LOG(INFO) << "Starting Worker scaling in barrier.";
231 while (server_iteration_state_.load() != IterationState::kCompleted ||
232 worker_iteration_state_.load() != IterationState::kCompleted) {
233 std::this_thread::yield();
234 }
235 MS_LOG(INFO) << "Ending Worker scaling in barrier. Switch to safemode.";
236 safemode_ = true;
237 }
238
ProcessAfterScalingOut()239 void FLWorker::ProcessAfterScalingOut() {
240 MS_ERROR_IF_NULL_WO_RET_VAL(worker_node_);
241 MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker.";
242 server_num_ = IntToUint(worker_node_->server_num());
243 worker_num_ = IntToUint(worker_node_->worker_num());
244 MS_LOG(INFO) << "After scheduler scaling out, worker number is " << worker_num_ << ", server number is "
245 << server_num_ << ". Exit safemode.";
246 std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));
247 safemode_ = false;
248 }
249
ProcessAfterScalingIn()250 void FLWorker::ProcessAfterScalingIn() {
251 MS_ERROR_IF_NULL_WO_RET_VAL(worker_node_);
252 MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker.";
253 server_num_ = IntToUint(worker_node_->server_num());
254 worker_num_ = IntToUint(worker_node_->worker_num());
255 MS_LOG(INFO) << "After scheduler scaling in, worker number is " << worker_num_ << ", server number is " << server_num_
256 << ". Exit safemode.";
257 std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));
258 safemode_ = false;
259 }
260 } // namespace worker
261 } // namespace fl
262 } // namespace mindspore
263