/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fl/server/executor.h"
#include <set>
#include <memory>
#include <string>
#include <vector>

namespace mindspore {
namespace fl {
namespace server {
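// Binds the executor to the training graph and sets the aggregation threshold. Every trainable
// parameter found in the graph gets its own ParameterAggregator before the executor is marked initialized.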
void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (aggregation_count == 0) {
    MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
    return;
  }
  aggregation_count_ = aggregation_count;

  // Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and optimizers.
  bool ret = InitParamAggregator(func_graph);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
    return;
  }
  initialized_ = true;
  return;
}

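// Reinitializes every parameter aggregator after the server cluster scales. Returns false as soon
// as one aggregator fails to reinitialize.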
bool Executor::ReInitForScaling() {
  auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(),
                             [](auto param_aggr) { return !param_aggr.second->ReInitForScaling(); });
  if (result != param_aggrs_.end()) {
    MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for scaling failed.";
    return false;
  }
  return true;
}

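// Applies a new aggregation threshold at runtime and propagates it to every parameter aggregator.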
bool Executor::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
  aggregation_count_ = aggr_threshold;
  auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(), [this](auto param_aggr) {
    return !param_aggr.second->ReInitForUpdatingHyperParams(aggregation_count_);
  });
  if (result != param_aggrs_.end()) {
    MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for updating hyper-parameters failed.";
    return false;
  }
  return true;
}

bool Executor::initialized() const { return initialized_; }

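// Handles a Push request from a worker: waits until the previous round of pulling is finished,
// merges the uploaded data into the aggregator, runs aggregation, and, once the aggregation
// threshold is reached, runs the optimizers and resets the pulling/aggregation status.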
bool Executor::HandlePush(const std::string &param_name, const UploadData &upload_data) {
  MS_LOG(DEBUG) << "Do Push for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return false;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  // Push operation needs to wait until the pulling process is done.
  while (!param_aggr->IsPullingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
    lock.lock();
  }

  // 1.Update data with the uploaded data of the worker.
  if (!param_aggr->UpdateData(upload_data)) {
    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
    return false;
  }
  // 2.Launch aggregation for this trainable parameter.
  if (!param_aggr->LaunchAggregators()) {
    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
    return false;
  }
  if (param_aggr->IsAggregationDone()) {
    // 3.After the aggregation is done, optimize the trainable parameter.
    if (!param_aggr->LaunchOptimizers()) {
      MS_LOG(ERROR) << "Optimizing for parameter " << param_name << " failed.";
      return false;
    }
    // 4.Reset pulling and aggregation status after optimizing is done.
    param_aggr->ResetPullingStatus();
    param_aggr->ResetAggregationStatus();
  }
  return true;
}

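// Handles an UpdateModel request for one parameter: merges the uploaded data and launches
// aggregation. Unlike HandlePush, it does not trigger the optimizers here.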
bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) {
  MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    // The param_name could refer to auxiliary parameters such as momentum that are not registered here.
    // This is not treated as invalid, so just print a warning log and return true.
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return true;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  if (!param_aggr->UpdateData(upload_data)) {
    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
    return false;
  }
  // Different from Push, UpdateModel doesn't need to check the aggregation status.
  if (!param_aggr->LaunchAggregators()) {
    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
    return false;
  }
  return true;
}

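// Asynchronous variant of UpdateModel: takes a whole feature map and updates/aggregates each
// registered parameter in turn while holding the global model mutex. Unregistered names are
// skipped with a warning.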
bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map) {
  std::unique_lock<std::mutex> model_lock(model_mutex_);
  for (const auto &trainable_param : feature_map) {
    const std::string &param_name = trainable_param.first;
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
      continue;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[param_name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
    const UploadData &upload_data = trainable_param.second;
    if (!param_aggr->UpdateData(upload_data)) {
      MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
      return false;
    }
    if (!param_aggr->LaunchAggregators()) {
      MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
      return false;
    }
  }
  return true;
}

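// Overwrites the stored weights with the ones pushed by the caller. Each weight buffer is copied
// in place with memcpy_s while holding the per-parameter mutex.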
bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) {
  for (const auto &trainable_param : feature_map) {
    const std::string &param_name = trainable_param.first;
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
      continue;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[param_name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
    AddressPtr old_weight = param_aggr->GetWeight();
    const Address &new_weight = trainable_param.second;
    MS_ERROR_IF_NULL_W_RET_VAL(old_weight, false);
    MS_ERROR_IF_NULL_W_RET_VAL(old_weight->addr, false);
    MS_ERROR_IF_NULL_W_RET_VAL(new_weight.addr, false);
    int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, errno(" << ret << ")";
      return false;
    }
  }
  return true;
}

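// Handles a blocking Pull request: waits until optimizing for the parameter is finished, returns
// the latest weight, and resets the optimizing status once the last worker has pulled.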
AddressPtr Executor::HandlePull(const std::string &param_name) {
  MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return nullptr;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, nullptr);
  // Pulling must wait until the optimizing process is done.
  while (!param_aggr->IsOptimizingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
    lock.lock();
  }
  AddressPtr addr = param_aggr->Pull();
  // If this Pull is the last one, reset pulling and optimizing status.
  if (param_aggr->IsPullingDone()) {
    param_aggr->ResetOptimizingStatus();
  }
  return addr;
}

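// Collects the current weights for the requested parameter names. Stops and returns the partial
// map if a name is not registered; entries whose weight address is unavailable are skipped.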
std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> &param_names) {
  std::map<std::string, AddressPtr> weights;
  for (const auto &param_name : param_names) {
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
      return weights;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    const auto &param_aggr = param_aggrs_[param_name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, weights);
    AddressPtr addr = param_aggr->GetWeight();
    if (addr == nullptr) {
      MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
      continue;
    }
    weights[param_name] = addr;
  }
  return weights;
}

bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }

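// Returns true only when every listed parameter that requires aggregation has finished
// aggregating in the current round; an unknown parameter name counts as failure.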
bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
  for (const auto &name : param_names) {
    if (param_aggrs_.count(name) == 0) {
      MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
      return false;
    }

    std::mutex &mtx = parameter_mutex_[name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
    if (!param_aggr->requires_aggr()) {
      continue;
    }
    if (!param_aggr->IsAggregationDone()) {
      MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
      return false;
    }
  }
  return true;
}

void Executor::ResetAggregationStatus() {
  for (const auto &param_name : param_names_) {
    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[param_name];
    MS_ERROR_IF_NULL_WO_RET_VAL(param_aggr);
    param_aggr->ResetAggregationStatus();
  }
  return;
}

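// Builds a snapshot of the current model as a map from parameter name to weight address.
// Parameters whose weight cannot be fetched are skipped with a warning.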
std::map<std::string, AddressPtr> Executor::GetModel() {
  std::map<std::string, AddressPtr> model = {};
  for (const auto &name : param_names_) {
    std::mutex &mtx = parameter_mutex_[name];
    std::unique_lock<std::mutex> lock(mtx);
    AddressPtr addr = param_aggrs_[name]->GetWeight();
    if (addr == nullptr) {
      MS_LOG(WARNING) << "Get weight of " << name << " failed.";
      continue;
    }
    model[name] = addr;
  }
  return model;
}

const std::vector<std::string> &Executor::param_names() const { return param_names_; }

bool Executor::Unmask() {
#ifdef ENABLE_ARMOUR
  auto model = GetModel();
  return cipher_unmask_.UnMask(model);
#else
  return false;
#endif
}

void Executor::set_unmasked(bool unmasked) { unmasked_ = unmasked; }

bool Executor::unmasked() const {
  std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  if (encrypt_type == ps::kPWEncryptType) {
    return unmasked_.load();
  } else {
    // If the pairwise encryption algorithm is not enabled, consider the unmasked flag as true.
    return true;
  }
}

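// Maps an optimizer CNode to the fullname of the trainable parameter it updates, or returns an
// empty string if the node is not a recognized optimizer kernel.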
std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
  if (kNameToIdxMap.count(cnode_name) == 0) {
    return "";
  }
  const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
  size_t weight_idx = index_info.at("inputs").at(kWeight);
  AnfNodePtr weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
  MS_EXCEPTION_IF_NULL(weight_node);
  if (!weight_node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
    return "";
  }
  return weight_node->fullname_with_scope();
}

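// Walks the ordered CNodes of the graph and, for every optimizer node found, registers a
// ParameterAggregator (plus its mutex) for the trainable parameter it updates.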
bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &cnodes = func_graph->GetOrderedCnodes();
  for (const auto &cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    const std::string &param_name = GetTrainableParamName(cnode);
    if (param_name.empty()) {
      continue;
    }
    if (param_aggrs_.count(param_name) != 0) {
      MS_LOG(WARNING) << param_name << " already has parameter aggregator registered.";
      continue;
    }

    std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
    MS_EXCEPTION_IF_NULL(param_aggr);
    param_names_.push_back(param_name);
    param_aggrs_[param_name] = param_aggr;
    parameter_mutex_[param_name];
    if (!param_aggr->Init(cnode, aggregation_count_)) {
      MS_LOG(EXCEPTION) << "Initializing parameter aggregator for " << param_name << " failed.";
      return false;
    }
    MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success.";
  }
  return true;
}
}  // namespace server
}  // namespace fl
}  // namespace mindspore