• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fl/server/parameter_aggregator.h"
18 #include <map>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 #include <utility>
23 #include <algorithm>
24 
25 namespace mindspore {
26 namespace fl {
27 namespace server {
// Initializes this aggregator for the weight owned by kernel node `cnode`.
// `threshold_count` is the number of Push operations required before aggregation
// is considered done; the same count is reused as the Pull threshold.
// Throws (via MS_LOG(EXCEPTION)) on any initialization failure.
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
  MS_EXCEPTION_IF_NULL(cnode);
  // Fresh memory register per Init call: all kernel buffers are registered here.
  memory_register_ = std::make_shared<MemoryRegister>();
  MS_EXCEPTION_IF_NULL(memory_register_);

  required_push_count_ = threshold_count;
  // The required_pull_count_ is the count for Pull, which should be the same as required_push_count_.
  // required_pull_count_ normally used in parameter server training mode.
  required_pull_count_ = threshold_count;

  MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode);
  // Aggregation kernels must be initialized before optimizer kernels: both register
  // buffers into memory_register_ and the optimizer may reuse already-assigned names.
  if (!InitAggregationKernels(cnode)) {
    // NOTE(review): MS_LOG(EXCEPTION) appears to throw, making the return defensive only.
    MS_LOG(EXCEPTION) << "Initializing aggregation kernels failed.";
    return false;
  }
  if (!InitOptimizerKernels(cnode)) {
    MS_LOG(EXCEPTION) << "Initializing optimizer kernels failed.";
    return false;
  }
  return true;
}
49 
ReInitForScaling()50 bool ParameterAggregator::ReInitForScaling() {
51   auto result = std::find_if(aggregation_kernel_parameters_.begin(), aggregation_kernel_parameters_.end(),
52                              [](auto aggregation_kernel) {
53                                MS_ERROR_IF_NULL_W_RET_VAL(aggregation_kernel.first, true);
54                                return !aggregation_kernel.first->ReInitForScaling();
55                              });
56   if (result != aggregation_kernel_parameters_.end()) {
57     MS_LOG(ERROR) << "Reinitializing aggregation kernel after scaling failed";
58     return false;
59   }
60   return true;
61 }
62 
ReInitForUpdatingHyperParams(size_t aggr_threshold)63 bool ParameterAggregator::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
64   required_push_count_ = aggr_threshold;
65   required_pull_count_ = aggr_threshold;
66   auto result = std::find_if(aggregation_kernel_parameters_.begin(), aggregation_kernel_parameters_.end(),
67                              [aggr_threshold](auto aggregation_kernel) {
68                                MS_ERROR_IF_NULL_W_RET_VAL(aggregation_kernel.first, true);
69                                return !aggregation_kernel.first->ReInitForUpdatingHyperParams(aggr_threshold);
70                              });
71   if (result != aggregation_kernel_parameters_.end()) {
72     MS_LOG(ERROR) << "Reinitializing aggregation kernel after scaling failed";
73     return false;
74   }
75   return true;
76 }
77 
UpdateData(const std::map<std::string,Address> & new_data)78 bool ParameterAggregator::UpdateData(const std::map<std::string, Address> &new_data) {
79   std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
80   for (const auto &data : new_data) {
81     const std::string &name = data.first;
82     if (name_to_addr.count(name) == 0) {
83       continue;
84     }
85 
86     MS_ERROR_IF_NULL_W_RET_VAL(name_to_addr[name], false);
87     MS_ERROR_IF_NULL_W_RET_VAL(name_to_addr[name]->addr, false);
88     MS_ERROR_IF_NULL_W_RET_VAL(data.second.addr, false);
89     MS_LOG(DEBUG) << "Update data for " << name << ". Destination size: " << name_to_addr[name]->size
90                   << ". Source size: " << data.second.size;
91     int ret = memcpy_s(name_to_addr[name]->addr, name_to_addr[name]->size, data.second.addr, data.second.size);
92     if (ret != 0) {
93       MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
94       return false;
95     }
96   }
97   return true;
98 }
99 
LaunchAggregators()100 bool ParameterAggregator::LaunchAggregators() {
101   for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
102     KernelParams &params = aggregator_with_params.second;
103     std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
104     MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
105     bool ret = aggr_kernel->Launch(params.inputs, params.workspace, params.outputs);
106     if (!ret) {
107       MS_LOG(ERROR) << "Launching aggregation kernel " << typeid(aggr_kernel.get()).name() << " failed.";
108       return false;
109     }
110   }
111   return true;
112 }
113 
LaunchOptimizers()114 bool ParameterAggregator::LaunchOptimizers() {
115   for (auto &optimizer_with_params : optimizer_kernel_parameters_) {
116     KernelParams &params = optimizer_with_params.second;
117     std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel = optimizer_with_params.first;
118     MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false);
119     bool ret = optimizer_kernel->Launch(params.inputs, params.workspace, params.outputs);
120     if (!ret) {
121       MS_LOG(ERROR) << "Launching optimizer kernel " << typeid(optimizer_kernel.get()).name() << " failed.";
122       continue;
123     }
124   }
125   // As long as all the optimizer kernels are launched, consider optimizing for this ParameterAggregator as done.
126   optimizing_done_ = true;
127   return true;
128 }
129 
Pull()130 AddressPtr ParameterAggregator::Pull() {
131   if (memory_register_ == nullptr) {
132     MS_LOG(ERROR)
133       << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first.";
134     return nullptr;
135   }
136 
137   current_pull_count_++;
138   if (current_pull_count_ == required_pull_count_) {
139     pulling_done_ = true;
140   }
141   MS_LOG(DEBUG) << "The " << current_pull_count_ << " time of Pull. Pulling done status: " << pulling_done_;
142 
143   std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
144   return name_to_addr["weight"];
145 }
146 
GetWeight()147 AddressPtr ParameterAggregator::GetWeight() {
148   if (memory_register_ == nullptr) {
149     MS_LOG(ERROR)
150       << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first.";
151     return nullptr;
152   }
153   std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
154   return name_to_addr["weight"];
155 }
156 
ResetAggregationStatus()157 void ParameterAggregator::ResetAggregationStatus() {
158   for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
159     std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
160     if (aggr_kernel == nullptr) {
161       MS_LOG(ERROR) << "The aggregation kernel is nullptr.";
162       continue;
163     }
164     aggr_kernel->Reset();
165   }
166   return;
167 }
168 
ResetOptimizingStatus()169 void ParameterAggregator::ResetOptimizingStatus() { optimizing_done_ = false; }
170 
// Clears the pulling-done flag and the pull counter so a new pulling round can begin.
void ParameterAggregator::ResetPullingStatus() {
  pulling_done_ = false;
  current_pull_count_ = 0;
}
175 
IsAggregationDone() const176 bool ParameterAggregator::IsAggregationDone() const {
177   // Only consider aggregation done after each aggregation kernel is done.
178   for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
179     std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
180     MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
181     if (!aggr_kernel->IsAggregationDone()) {
182       return false;
183     }
184   }
185   return true;
186 }
187 
IsOptimizingDone() const188 bool ParameterAggregator::IsOptimizingDone() const { return optimizing_done_; }
189 
IsPullingDone() const190 bool ParameterAggregator::IsPullingDone() const { return pulling_done_; }
191 
requires_aggr() const192 bool ParameterAggregator::requires_aggr() const { return requires_aggr_; }
193 
// Creates, initializes, and wires up the aggregation kernels selected for `cnode`.
// For each selected algorithm: create the kernel, set its done-count, init it,
// assign (or reuse) memory for its inputs, and generate its launch parameters.
// Failures raise via MS_LOG(EXCEPTION).
bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  // Aggregation may be unnecessary for this weight; record that fact but continue,
  // since the kernels are still constructed.
  if (!JudgeRequiresAggr(cnode)) {
    MS_LOG(WARNING) << "Aggregation for weight for kernel " << AnfAlgo::GetCNodeName(cnode) << " is not required.";
  }

  std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode);
  for (const std::string &name : aggr_kernel_names) {
    auto aggr_kernel = kernel::AggregationKernelFactory::GetInstance().Create(name, cnode);
    if (aggr_kernel == nullptr) {
      // NOTE(review): MS_LOG(EXCEPTION) appears to throw, making the return defensive only.
      MS_LOG(EXCEPTION) << "Fail to create aggregation kernel " << name << " for " << AnfAlgo::GetCNodeName(cnode);
      return false;
    }

    // set_done_count must be called before InitKernel because InitKernel may use this count.
    aggr_kernel->set_done_count(required_push_count_);
    aggr_kernel->InitKernel(cnode);

    // Inputs that alias existing kernel-node parameters are reused rather than
    // freshly allocated; AssignMemory handles both cases.
    const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = aggr_kernel->reuse_kernel_node_inputs_info();
    if (!AssignMemory(aggr_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) {
      MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed.";
      return false;
    }

    if (!GenerateAggregationKernelParams(aggr_kernel, memory_register_)) {
      MS_LOG(EXCEPTION) << "Generating aggregation kernel parameters for " << name << " failed.";
      return false;
    }
  }
  return true;
}
225 
// Creates and wires up the optimizer kernel matching `cnode`'s op name.
// Skipped entirely (returning true) in federated-learning / hybrid server modes,
// which do not run a server-side optimizer.
bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
  if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
      ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
    MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
    return true;
  }
  MS_EXCEPTION_IF_NULL(cnode);
  const std::string &name = AnfAlgo::GetCNodeName(cnode);
  auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode);
  if (optimizer_kernel == nullptr) {
    // NOTE(review): MS_LOG(EXCEPTION) appears to throw, making the return defensive only.
    MS_LOG(EXCEPTION) << "Failed to create optimizer kernel for " << name;
    return false;
  }

  optimizer_kernel->InitKernel(cnode);

  // Inputs that alias existing kernel-node parameters are reused rather than
  // freshly allocated; AssignMemory handles both cases.
  const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = optimizer_kernel->reuse_kernel_node_inputs_info();
  if (!AssignMemory(optimizer_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) {
    MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed.";
    return false;
  }

  if (!GenerateOptimizerKernelParams(optimizer_kernel, memory_register_)) {
    MS_LOG(ERROR) << "Generating optimizer kernel parameters failed.";
    return false;
  }
  return true;
}
254 
// Assigns memory for every input of `server_kernel` into `memory_register`.
// For each input name (in kernel order):
//   - already registered           -> skip;
//   - listed in the reuse info     -> register the address of the corresponding
//                                     kernel-node input (a parameter node whose
//                                     memory the front end already owns);
//   - otherwise                    -> allocate a fresh buffer of the declared size.
// K is a shared_ptr to either an AggregationKernel or an OptimizerKernel
// (see the explicit instantiations at the end of this file).
template <typename K>
bool ParameterAggregator::AssignMemory(const K server_kernel, const CNodePtr &cnode,
                                       const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                                       const std::shared_ptr<MemoryRegister> &memory_register) {
  MS_EXCEPTION_IF_NULL(server_kernel);
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(memory_register);

  const std::vector<std::string> &input_names = server_kernel->input_names();
  const std::vector<size_t> &input_size_list = server_kernel->GetInputSizeList();
  // Names and sizes are parallel arrays; a mismatch means the kernel is misconfigured.
  if (input_names.size() != input_size_list.size()) {
    MS_LOG(EXCEPTION) << "Server kernel " << typeid(server_kernel.get()).name()
                      << " input number is not matched: input_names size is " << input_names.size()
                      << ", input_size_list size is " << input_size_list.size();
    return false;
  }

  if (reuse_kernel_node_inputs_info.size() > input_names.size()) {
    MS_LOG(EXCEPTION) << "The reuse kernel node information number is invalid: got "
                      << reuse_kernel_node_inputs_info.size() << ", but input_names size is " << input_names.size();
    return false;
  }

  for (size_t i = 0; i < input_names.size(); i++) {
    const std::string &name = input_names[i];
    if (memory_register->addresses().count(name) != 0) {
      MS_LOG(DEBUG) << "The memory for " << name << " is already assigned.";
      continue;
    }
    if (reuse_kernel_node_inputs_info.count(name) != 0) {
      // Reusing memory of the kernel node means the memory of the input is already assigned by the front end, which
      // is to say, the input node is a parameter node.
      size_t index = reuse_kernel_node_inputs_info.at(name);
      MS_LOG(INFO) << "Try to reuse memory of kernel node " << AnfAlgo::GetCNodeName(cnode) << " for parameter " << name
                   << ", kernel node index " << index;
      AddressPtr input_addr = GenerateParameterNodeAddrPtr(cnode, index);
      MS_EXCEPTION_IF_NULL(input_addr);
      memory_register->RegisterAddressPtr(name, input_addr);
    } else {
      MS_LOG(INFO) << "Assign new memory for " << name;
      // Ownership of the freshly allocated buffer is transferred to the register.
      auto input_addr = std::make_unique<char[]>(input_size_list[i]);
      MS_EXCEPTION_IF_NULL(input_addr);
      memory_register->RegisterArray(name, &input_addr, input_size_list[i]);
    }
  }
  return true;
}
302 
GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> & aggr_kernel,const std::shared_ptr<MemoryRegister> & memory_register)303 bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
304                                                           const std::shared_ptr<MemoryRegister> &memory_register) {
305   MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
306   MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
307   KernelParams aggr_params = {};
308 
309   const std::vector<std::string> &input_names = aggr_kernel->input_names();
310   (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs),
311                        [&](const std::string &name) { return memory_register->addresses()[name]; });
312 
313   const std::vector<std::string> &workspace_names = aggr_kernel->workspace_names();
314   (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace),
315                        [&](const std::string &name) { return memory_register->addresses()[name]; });
316 
317   const std::vector<std::string> &output_names = aggr_kernel->output_names();
318   (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs),
319                        [&](const std::string &name) { return memory_register->addresses()[name]; });
320 
321   aggr_kernel->SetParameterAddress(aggr_params.inputs, aggr_params.workspace, aggr_params.outputs);
322   aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params));
323   return true;
324 }
325 
GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> & optimizer_kernel,const std::shared_ptr<MemoryRegister> & memory_register)326 bool ParameterAggregator::GenerateOptimizerKernelParams(
327   const std::shared_ptr<kernel::OptimizerKernel> &optimizer_kernel,
328   const std::shared_ptr<MemoryRegister> &memory_register) {
329   MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false);
330   MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
331   KernelParams optimizer_params = {};
332 
333   const std::vector<std::string> &input_names = optimizer_kernel->input_names();
334   (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs),
335                        [&](const std::string &name) { return memory_register->addresses()[name]; });
336 
337   const std::vector<std::string> &workspace_names = optimizer_kernel->workspace_names();
338   (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace),
339                        [&](const std::string &name) { return memory_register->addresses()[name]; });
340 
341   const std::vector<std::string> &output_names = optimizer_kernel->output_names();
342   (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs),
343                        [&](const std::string &name) { return memory_register->addresses()[name]; });
344 
345   optimizer_kernel_parameters_.push_back(std::make_pair(optimizer_kernel, optimizer_params));
346   return true;
347 }
348 
SelectAggregationAlgorithm(const CNodePtr &)349 std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) {
350   std::vector<std::string> aggregation_algorithm = {};
351   if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
352       ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
353     (void)aggregation_algorithm.emplace_back("FedAvg");
354   } else if (ps::PSContext::instance()->server_mode() == ps::kServerModePS) {
355     (void)aggregation_algorithm.emplace_back("DenseGradAccum");
356   } else {
357     MS_LOG(EXCEPTION) << "Server doesn't support mode " << ps::PSContext::instance()->server_mode();
358     return aggregation_algorithm;
359   }
360 
361   MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
362   return aggregation_algorithm;
363 }
364 
// Determines whether the weight input of `cnode` requires aggregation, by
// locating the weight parameter node through kNameToIdxMap and reading its
// param_info. Caches the result in requires_aggr_ and returns it.
// Raises via MS_LOG(EXCEPTION) if the weight index cannot be resolved or the
// resolved node is not a parameter.
bool ParameterAggregator::JudgeRequiresAggr(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
  // kNameToIdxMap: op name -> {"inputs"/"outputs" -> {param name -> index}}.
  if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
      kNameToIdxMap.at(cnode_name).at("inputs").count("weight") == 0) {
    MS_LOG(EXCEPTION) << "Can't find index info of weight for kernel " << cnode_name;
    return false;
  }
  size_t cnode_weight_idx = kNameToIdxMap.at(cnode_name).at("inputs").at("weight");
  // Walk through any wrapping nodes to the real weight node backing this input.
  auto weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, cnode_weight_idx), 0).first;
  MS_EXCEPTION_IF_NULL(weight_node);

  if (!weight_node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << weight_node->fullname_with_scope() << " is not a parameter node.";
    return false;
  }
  auto param_info = weight_node->cast<ParameterPtr>()->param_info();
  MS_EXCEPTION_IF_NULL(param_info);
  requires_aggr_ = param_info->requires_aggr();
  return requires_aggr_;
}
386 
// Explicit instantiations of AssignMemory for the two kernel types used in this
// translation unit, so the template definition can remain in the .cc file.
template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::OptimizerKernel> server_kernel,
                                                const CNodePtr &cnode,
                                                const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                                                const std::shared_ptr<MemoryRegister> &memory_register);

template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::AggregationKernel> server_kernel,
                                                const CNodePtr &cnode,
                                                const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                                                const std::shared_ptr<MemoryRegister> &memory_register);
396 }  // namespace server
397 }  // namespace fl
398 }  // namespace mindspore
399