/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fl/server/executor.h"
#include <algorithm>
#include <chrono>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <thread>
#include <vector>

namespace mindspore {
namespace fl {
namespace server {
void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (aggregation_count == 0) {
    MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
    return;
  }
  aggregation_count_ = aggregation_count;

  // Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and
  // optimizers.
  bool ret = InitParamAggregator(func_graph);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
    return;
  }
  initialized_ = true;
  return;
}

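// Re-initialize every parameter aggregator after the server cluster is scaled. Returns false if any aggregator
// fails to re-initialize.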
bool Executor::ReInitForScaling() {
  auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(),
                             [](auto param_aggr) { return !param_aggr.second->ReInitForScaling(); });
  if (result != param_aggrs_.end()) {
    MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for scaling failed.";
    return false;
  }
  return true;
}

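// Apply a new aggregation threshold and re-initialize every parameter aggregator with it.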
bool Executor::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
  aggregation_count_ = aggr_threshold;
  auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(), [this](auto param_aggr) {
    return !param_aggr.second->ReInitForUpdatingHyperParams(aggregation_count_);
  });
  if (result != param_aggrs_.end()) {
    MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for updating hyper-parameters failed.";
    return false;
  }
  return true;
}

bool Executor::initialized() const { return initialized_; }

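// Handle the Push request for one parameter: update the aggregator with the uploaded data, launch aggregation and,
// once aggregation is done, run the optimizers and reset the pulling/aggregation status.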
bool Executor::HandlePush(const std::string &param_name, const UploadData &upload_data) {
  MS_LOG(DEBUG) << "Do Push for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return false;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  // Push operation needs to wait until the pulling process is done.
  while (!param_aggr->IsPullingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
    lock.lock();
  }

  // 1.Update data with the uploaded data of the worker.
  if (!param_aggr->UpdateData(upload_data)) {
    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
    return false;
  }
  // 2.Launch aggregation for this trainable parameter.
  if (!param_aggr->LaunchAggregators()) {
    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
    return false;
  }
  if (param_aggr->IsAggregationDone()) {
    // 3.After the aggregation is done, optimize the trainable parameter.
    if (!param_aggr->LaunchOptimizers()) {
      MS_LOG(ERROR) << "Optimizing for parameter " << param_name << " failed.";
      return false;
    }
    // 4.Reset pulling and aggregation status after optimizing is done.
    param_aggr->ResetPullingStatus();
    param_aggr->ResetAggregationStatus();
  }
  return true;
}

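// Handle the UpdateModel request for one parameter: update the aggregator with the uploaded data and launch
// aggregation without checking the aggregation status.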
bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) {
  MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    // The param_name could be an auxiliary parameter such as momentum, which is not considered invalid, so only
    // print a warning log and return true.
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return true;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  if (!param_aggr->UpdateData(upload_data)) {
    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
    return false;
  }
  // Different from Push, UpdateModel doesn't need to check the aggregation status.
  if (!param_aggr->LaunchAggregators()) {
    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
    return false;
  }
  return true;
}

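// Handle the UpdateModel request for a whole feature map under the model lock, updating and aggregating each
// registered parameter in turn.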
bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map) {
  std::unique_lock<std::mutex> model_lock(model_mutex_);
  for (const auto &trainable_param : feature_map) {
    const std::string &param_name = trainable_param.first;
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
      continue;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[param_name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
    const UploadData &upload_data = trainable_param.second;
    if (!param_aggr->UpdateData(upload_data)) {
      MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
      return false;
    }
    if (!param_aggr->LaunchAggregators()) {
      MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
      return false;
    }
  }
  return true;
}

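// Overwrite the weights stored on the server with the pushed weights, parameter by parameter.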
bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) {
  for (const auto &trainable_param : feature_map) {
    const std::string &param_name = trainable_param.first;
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
      continue;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[param_name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
    AddressPtr old_weight = param_aggr->GetWeight();
    const Address &new_weight = trainable_param.second;
    MS_ERROR_IF_NULL_W_RET_VAL(old_weight, false);
    MS_ERROR_IF_NULL_W_RET_VAL(old_weight->addr, false);
    MS_ERROR_IF_NULL_W_RET_VAL(new_weight.addr, false);
    int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, error number (" << ret << ")";
      return false;
    }
  }
  return true;
}

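// Handle the blocking Pull request: wait until optimizing is done, return the weight address, and reset the
// optimizing status once the last pull is done.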
AddressPtr Executor::HandlePull(const std::string &param_name) {
  MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return nullptr;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, nullptr);
  // Pulling must wait until the optimizing process is done.
  while (!param_aggr->IsOptimizingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
    lock.lock();
  }
  AddressPtr addr = param_aggr->Pull();
  // If this Pull is the last one, reset pulling and optimizing status.
  if (param_aggr->IsPullingDone()) {
    param_aggr->ResetOptimizingStatus();
  }
  return addr;
}

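// Return the current weight addresses of the requested parameters.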
std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> &param_names) {
  std::map<std::string, AddressPtr> weights;
  for (const auto &param_name : param_names) {
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
      return weights;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    const auto &param_aggr = param_aggrs_[param_name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, weights);
    AddressPtr addr = param_aggr->GetWeight();
    if (addr == nullptr) {
      MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
      continue;
    }
    weights[param_name] = addr;
  }
  return weights;
}

bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }

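// Check whether aggregation is done for every given parameter that requires aggregation.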
bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
  for (const auto &name : param_names) {
    if (param_aggrs_.count(name) == 0) {
      MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
      return false;
    }

    std::mutex &mtx = parameter_mutex_[name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
    if (!param_aggr->requires_aggr()) {
      continue;
    }
    if (!param_aggr->IsAggregationDone()) {
      MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
      return false;
    }
  }
  return true;
}

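// Reset the aggregation status of every registered parameter aggregator.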
void Executor::ResetAggregationStatus() {
  for (const auto &param_name : param_names_) {
    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[param_name];
    MS_ERROR_IF_NULL_WO_RET_VAL(param_aggr);
    param_aggr->ResetAggregationStatus();
  }
  return;
}

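// Collect the current weight address of every registered parameter into a model map.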
std::map<std::string, AddressPtr> Executor::GetModel() {
  std::map<std::string, AddressPtr> model = {};
  for (const auto &name : param_names_) {
    std::mutex &mtx = parameter_mutex_[name];
    std::unique_lock<std::mutex> lock(mtx);
    AddressPtr addr = param_aggrs_[name]->GetWeight();
    if (addr == nullptr) {
      MS_LOG(WARNING) << "Get weight of " << name << " failed.";
      continue;
    }
    model[name] = addr;
  }
  return model;
}

const std::vector<std::string> &Executor::param_names() const { return param_names_; }

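// Remove the encryption masks from the aggregated model. Only effective when ENABLE_ARMOUR is defined.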
bool Executor::Unmask() {
#ifdef ENABLE_ARMOUR
  auto model = GetModel();
  return cipher_unmask_.UnMask(model);
#else
  return false;
#endif
}

void Executor::set_unmasked(bool unmasked) { unmasked_ = unmasked; }

bool Executor::unmasked() const {
  std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  if (encrypt_type == ps::kPWEncryptType) {
    return unmasked_.load();
  } else {
    // If the pairwise encryption algorithm is not enabled, consider the unmasked flag to be true.
    return true;
  }
}

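// Return the full name of the trainable parameter (weight) consumed by an optimizer cnode, or an empty string if
// the cnode is not a supported optimizer.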
std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
  if (kNameToIdxMap.count(cnode_name) == 0) {
    return "";
  }
  const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
  size_t weight_idx = index_info.at("inputs").at(kWeight);
  AnfNodePtr weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
  MS_EXCEPTION_IF_NULL(weight_node);
  if (!weight_node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
    return "";
  }
  return weight_node->fullname_with_scope();
}

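// Create a ParameterAggregator and a mutex for every trainable parameter found in the func_graph.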
bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &cnodes = func_graph->GetOrderedCnodes();
  for (const auto &cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    const std::string &param_name = GetTrainableParamName(cnode);
    if (param_name.empty()) {
      continue;
    }
    if (param_aggrs_.count(param_name) != 0) {
      MS_LOG(WARNING) << param_name << " already has a parameter aggregator registered.";
      continue;
    }

    std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
    MS_EXCEPTION_IF_NULL(param_aggr);
    param_names_.push_back(param_name);
    param_aggrs_[param_name] = param_aggr;
    parameter_mutex_[param_name];  // Default-construct the mutex for this parameter.
    if (!param_aggr->Init(cnode, aggregation_count_)) {
      MS_LOG(EXCEPTION) << "Initializing parameter aggregator for " << param_name << " failed.";
      return false;
    }
    MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success.";
  }
  return true;
}
}  // namespace server
}  // namespace fl
}  // namespace mindspore