/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fl/server/parameter_aggregator.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include <algorithm>

namespace mindspore {
namespace fl {
namespace server {
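// Initialize this ParameterAggregator for the given kernel node: create the memory register, record the
// push/pull threshold counts, and build the aggregation and optimizer kernels.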
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
  MS_EXCEPTION_IF_NULL(cnode);
  memory_register_ = std::make_shared<MemoryRegister>();
  MS_EXCEPTION_IF_NULL(memory_register_);

  required_push_count_ = threshold_count;
  // The required_pull_count_ is the count for Pull, which should be the same as required_push_count_.
  // required_pull_count_ is normally used in parameter server training mode.
  required_pull_count_ = threshold_count;

  MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode);
  if (!InitAggregationKernels(cnode)) {
    MS_LOG(EXCEPTION) << "Initializing aggregation kernels failed.";
    return false;
  }
  if (!InitOptimizerKernels(cnode)) {
    MS_LOG(EXCEPTION) << "Initializing optimizer kernels failed.";
    return false;
  }
  return true;
}

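// Re-initialize every registered aggregation kernel after the cluster is scaled in or out. Returns false
// if any kernel fails to re-initialize.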
bool ParameterAggregator::ReInitForScaling() {
  auto result = std::find_if(aggregation_kernel_parameters_.begin(), aggregation_kernel_parameters_.end(),
                             [](auto aggregation_kernel) {
                               MS_ERROR_IF_NULL_W_RET_VAL(aggregation_kernel.first, true);
                               return !aggregation_kernel.first->ReInitForScaling();
                             });
  if (result != aggregation_kernel_parameters_.end()) {
    MS_LOG(ERROR) << "Reinitializing aggregation kernel after scaling failed.";
    return false;
  }
  return true;
}

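// Update the aggregation threshold at runtime and propagate the new threshold to every aggregation kernel.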
bool ParameterAggregator::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
  required_push_count_ = aggr_threshold;
  required_pull_count_ = aggr_threshold;
  auto result = std::find_if(aggregation_kernel_parameters_.begin(), aggregation_kernel_parameters_.end(),
                             [aggr_threshold](auto aggregation_kernel) {
                               MS_ERROR_IF_NULL_W_RET_VAL(aggregation_kernel.first, true);
                               return !aggregation_kernel.first->ReInitForUpdatingHyperParams(aggr_threshold);
                             });
  if (result != aggregation_kernel_parameters_.end()) {
    MS_LOG(ERROR) << "Reinitializing aggregation kernel after updating hyper-parameters failed.";
    return false;
  }
  return true;
}

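// Copy the given address data into the registered memory of this aggregator. Entries whose names are not
// registered in the memory register are silently skipped.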
bool ParameterAggregator::UpdateData(const std::map<std::string, Address> &new_data) {
  std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
  for (const auto &data : new_data) {
    const std::string &name = data.first;
    if (name_to_addr.count(name) == 0) {
      continue;
    }

    MS_ERROR_IF_NULL_W_RET_VAL(name_to_addr[name], false);
    MS_ERROR_IF_NULL_W_RET_VAL(name_to_addr[name]->addr, false);
    MS_ERROR_IF_NULL_W_RET_VAL(data.second.addr, false);
    MS_LOG(DEBUG) << "Update data for " << name << ". Destination size: " << name_to_addr[name]->size
                  << ". Source size: " << data.second.size;
    int ret = memcpy_s(name_to_addr[name]->addr, name_to_addr[name]->size, data.second.addr, data.second.size);
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, errno(" << ret << ")";
      return false;
    }
  }
  return true;
}

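// Launch every aggregation kernel with its pre-generated inputs/workspace/outputs. The first kernel that
// fails to launch aborts the whole call.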
bool ParameterAggregator::LaunchAggregators() {
  for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
    KernelParams &params = aggregator_with_params.second;
    std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
    MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
    bool ret = aggr_kernel->Launch(params.inputs, params.workspace, params.outputs);
    if (!ret) {
      MS_LOG(ERROR) << "Launching aggregation kernel " << typeid(aggr_kernel.get()).name() << " failed.";
      return false;
    }
  }
  return true;
}

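// Launch every optimizer kernel. Unlike LaunchAggregators, a failed kernel is logged and skipped, and
// optimizing is marked done once all kernels have been launched.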
bool ParameterAggregator::LaunchOptimizers() {
  for (auto &optimizer_with_params : optimizer_kernel_parameters_) {
    KernelParams &params = optimizer_with_params.second;
    std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel = optimizer_with_params.first;
    MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false);
    bool ret = optimizer_kernel->Launch(params.inputs, params.workspace, params.outputs);
    if (!ret) {
      MS_LOG(ERROR) << "Launching optimizer kernel " << typeid(optimizer_kernel.get()).name() << " failed.";
      continue;
    }
  }
  // As long as all the optimizer kernels have been launched, optimizing for this ParameterAggregator is
  // considered done.
  optimizing_done_ = true;
  return true;
}

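// Serve one Pull request: increment the pull counter, mark pulling done once required_pull_count_ is
// reached, and return the address of the "weight" parameter.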
AddressPtr ParameterAggregator::Pull() {
  if (memory_register_ == nullptr) {
    MS_LOG(ERROR)
      << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first.";
    return nullptr;
  }

  current_pull_count_++;
  if (current_pull_count_ == required_pull_count_) {
    pulling_done_ = true;
  }
  MS_LOG(DEBUG) << "Pull count: " << current_pull_count_ << ". Pulling done status: " << pulling_done_;

  std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
  return name_to_addr["weight"];
}

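// Return the address of the "weight" parameter without touching the pull counter.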
AddressPtr ParameterAggregator::GetWeight() {
  if (memory_register_ == nullptr) {
    MS_LOG(ERROR)
      << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first.";
    return nullptr;
  }
  std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
  return name_to_addr["weight"];
}

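// Reset every aggregation kernel so a new aggregation round can start.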
void ParameterAggregator::ResetAggregationStatus() {
  for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
    std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
    if (aggr_kernel == nullptr) {
      MS_LOG(ERROR) << "The aggregation kernel is nullptr.";
      continue;
    }
    aggr_kernel->Reset();
  }
}

void ParameterAggregator::ResetOptimizingStatus() { optimizing_done_ = false; }

void ParameterAggregator::ResetPullingStatus() {
  pulling_done_ = false;
  current_pull_count_ = 0;
}

bool ParameterAggregator::IsAggregationDone() const {
  // Only consider aggregation done after each aggregation kernel is done.
  for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
    std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
    MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
    if (!aggr_kernel->IsAggregationDone()) {
      return false;
    }
  }
  return true;
}

bool ParameterAggregator::IsOptimizingDone() const { return optimizing_done_; }

bool ParameterAggregator::IsPullingDone() const { return pulling_done_; }

bool ParameterAggregator::requires_aggr() const { return requires_aggr_; }

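// Create, initialize, and register the aggregation kernels selected for this kernel node, assigning memory
// and generating launch parameters for each of them.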
bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!JudgeRequiresAggr(cnode)) {
    MS_LOG(WARNING) << "Aggregation of the weight for kernel " << AnfAlgo::GetCNodeName(cnode) << " is not required.";
  }

  std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode);
  for (const std::string &name : aggr_kernel_names) {
    auto aggr_kernel = kernel::AggregationKernelFactory::GetInstance().Create(name, cnode);
    if (aggr_kernel == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to create aggregation kernel " << name << " for " << AnfAlgo::GetCNodeName(cnode);
      return false;
    }

    // set_done_count must be called before InitKernel because InitKernel may use this count.
    aggr_kernel->set_done_count(required_push_count_);
    aggr_kernel->InitKernel(cnode);

    const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = aggr_kernel->reuse_kernel_node_inputs_info();
    if (!AssignMemory(aggr_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) {
      MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed.";
      return false;
    }

    if (!GenerateAggregationKernelParams(aggr_kernel, memory_register_)) {
      MS_LOG(EXCEPTION) << "Generating aggregation kernel parameters for " << name << " failed.";
      return false;
    }
  }
  return true;
}

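// Create and initialize the optimizer kernel for this kernel node. Federated learning and hybrid server
// modes do not need optimizer kernels, so they return early.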
bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
  if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
      ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
    MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
    return true;
  }
  MS_EXCEPTION_IF_NULL(cnode);
  const std::string &name = AnfAlgo::GetCNodeName(cnode);
  auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode);
  if (optimizer_kernel == nullptr) {
    MS_LOG(EXCEPTION) << "Failed to create optimizer kernel for " << name;
    return false;
  }

  optimizer_kernel->InitKernel(cnode);

  const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = optimizer_kernel->reuse_kernel_node_inputs_info();
  if (!AssignMemory(optimizer_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) {
    MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed.";
    return false;
  }

  if (!GenerateOptimizerKernelParams(optimizer_kernel, memory_register_)) {
    MS_LOG(ERROR) << "Generating optimizer kernel parameters failed.";
    return false;
  }
  return true;
}

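// Assign memory for every input of a server kernel. Inputs listed in reuse_kernel_node_inputs_info reuse
// the memory of the corresponding kernel node input (a parameter node already allocated by the front end);
// all other inputs get newly allocated buffers registered by name.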
template <typename K>
bool ParameterAggregator::AssignMemory(const K server_kernel, const CNodePtr &cnode,
                                       const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                                       const std::shared_ptr<MemoryRegister> &memory_register) {
  MS_EXCEPTION_IF_NULL(server_kernel);
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(memory_register);

  const std::vector<std::string> &input_names = server_kernel->input_names();
  const std::vector<size_t> &input_size_list = server_kernel->GetInputSizeList();
  if (input_names.size() != input_size_list.size()) {
    MS_LOG(EXCEPTION) << "Server kernel " << typeid(server_kernel.get()).name()
                      << " input counts do not match: input_names size is " << input_names.size()
                      << ", input_size_list size is " << input_size_list.size();
    return false;
  }

  if (reuse_kernel_node_inputs_info.size() > input_names.size()) {
    MS_LOG(EXCEPTION) << "The reuse kernel node information number is invalid: got "
                      << reuse_kernel_node_inputs_info.size() << ", but input_names size is " << input_names.size();
    return false;
  }

  for (size_t i = 0; i < input_names.size(); i++) {
    const std::string &name = input_names[i];
    if (memory_register->addresses().count(name) != 0) {
      MS_LOG(DEBUG) << "The memory for " << name << " is already assigned.";
      continue;
    }
    if (reuse_kernel_node_inputs_info.count(name) != 0) {
      // Reusing memory of the kernel node means the memory of the input is already assigned by the front end,
      // which is to say, the input node is a parameter node.
      size_t index = reuse_kernel_node_inputs_info.at(name);
      MS_LOG(INFO) << "Try to reuse memory of kernel node " << AnfAlgo::GetCNodeName(cnode) << " for parameter " << name
                   << ", kernel node index " << index;
      AddressPtr input_addr = GenerateParameterNodeAddrPtr(cnode, index);
      MS_EXCEPTION_IF_NULL(input_addr);
      memory_register->RegisterAddressPtr(name, input_addr);
    } else {
      MS_LOG(INFO) << "Assign new memory for " << name;
      auto input_addr = std::make_unique<char[]>(input_size_list[i]);
      MS_EXCEPTION_IF_NULL(input_addr);
      memory_register->RegisterArray(name, &input_addr, input_size_list[i]);
    }
  }
  return true;
}

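// Resolve the kernel's input/workspace/output names to the registered addresses, hand the addresses to the
// kernel, and record the (kernel, params) pair for later launches.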
bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
                                                          const std::shared_ptr<MemoryRegister> &memory_register) {
  MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
  MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
  KernelParams aggr_params = {};

  const std::vector<std::string> &input_names = aggr_kernel->input_names();
  (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs),
                       [&](const std::string &name) { return memory_register->addresses()[name]; });

  const std::vector<std::string> &workspace_names = aggr_kernel->workspace_names();
  (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace),
                       [&](const std::string &name) { return memory_register->addresses()[name]; });

  const std::vector<std::string> &output_names = aggr_kernel->output_names();
  (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs),
                       [&](const std::string &name) { return memory_register->addresses()[name]; });

  aggr_kernel->SetParameterAddress(aggr_params.inputs, aggr_params.workspace, aggr_params.outputs);
  aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params));
  return true;
}

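// Same address resolution as GenerateAggregationKernelParams, but for the optimizer kernel.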
bool ParameterAggregator::GenerateOptimizerKernelParams(
  const std::shared_ptr<kernel::OptimizerKernel> &optimizer_kernel,
  const std::shared_ptr<MemoryRegister> &memory_register) {
  MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false);
  MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
  KernelParams optimizer_params = {};

  const std::vector<std::string> &input_names = optimizer_kernel->input_names();
  (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs),
                       [&](const std::string &name) { return memory_register->addresses()[name]; });

  const std::vector<std::string> &workspace_names = optimizer_kernel->workspace_names();
  (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace),
                       [&](const std::string &name) { return memory_register->addresses()[name]; });

  const std::vector<std::string> &output_names = optimizer_kernel->output_names();
  (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs),
                       [&](const std::string &name) { return memory_register->addresses()[name]; });

  optimizer_kernel_parameters_.push_back(std::make_pair(optimizer_kernel, optimizer_params));
  return true;
}

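// Select the aggregation algorithm by server mode: FedAvg for federated learning and hybrid modes,
// DenseGradAccum for parameter server mode.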
std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) {
  std::vector<std::string> aggregation_algorithm = {};
  if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
      ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
    (void)aggregation_algorithm.emplace_back("FedAvg");
  } else if (ps::PSContext::instance()->server_mode() == ps::kServerModePS) {
    (void)aggregation_algorithm.emplace_back("DenseGradAccum");
  } else {
    MS_LOG(EXCEPTION) << "Server doesn't support mode " << ps::PSContext::instance()->server_mode();
    return aggregation_algorithm;
  }

  MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
  return aggregation_algorithm;
}

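// Look up the weight input of this kernel node and read its param_info to decide whether the weight
// requires aggregation. The result is cached in requires_aggr_.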
bool ParameterAggregator::JudgeRequiresAggr(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
  if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
      kNameToIdxMap.at(cnode_name).at("inputs").count("weight") == 0) {
    MS_LOG(EXCEPTION) << "Can't find index info of weight for kernel " << cnode_name;
    return false;
  }
  size_t cnode_weight_idx = kNameToIdxMap.at(cnode_name).at("inputs").at("weight");
  auto weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, cnode_weight_idx), 0).first;
  MS_EXCEPTION_IF_NULL(weight_node);

  if (!weight_node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << weight_node->fullname_with_scope() << " is not a parameter node.";
    return false;
  }
  auto param_info = weight_node->cast<ParameterPtr>()->param_info();
  MS_EXCEPTION_IF_NULL(param_info);
  requires_aggr_ = param_info->requires_aggr();
  return requires_aggr_;
}

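// Explicit template instantiations so AssignMemory can be defined in this translation unit and still be
// used with both optimizer and aggregation kernels.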
template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::OptimizerKernel> server_kernel,
                                                const CNodePtr &cnode,
                                                const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                                                const std::shared_ptr<MemoryRegister> &memory_register);

template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::AggregationKernel> server_kernel,
                                                const CNodePtr &cnode,
                                                const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                                                const std::shared_ptr<MemoryRegister> &memory_register);
}  // namespace server
}  // namespace fl
}  // namespace mindspore