• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_
18 #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_
19 
20 #include <memory>
21 #include <string>
22 #include <vector>
23 #include "backend/kernel_compiler/cpu/cpu_kernel.h"
24 #include "fl/server/common.h"
25 #include "fl/server/memory_register.h"
26 #include "fl/server/kernel/params_info.h"
27 
28 namespace mindspore {
29 namespace fl {
30 namespace server {
31 namespace kernel {
32 // AggregationKernel is the kernel for weight, grad or other kinds of parameters' aggregation.
33 // For example, dense gradients accumulation, federated average, etc.
34 // Normally the aggregation process in AggregationKernel is like a finite-state machine:
35 // Initial->Aggregating->Aggregation done->Initial.
36 class AggregationKernel : public CPUKernel {
37  public:
AggregationKernel()38   AggregationKernel() : name_(""), done_(false), done_count_(0), accum_count_(0) {}
39   virtual ~AggregationKernel() = default;
40 
41   // InitKernel and Launch methods are inherited from pure virtual function of CPUKernel so it must have implementation.
InitKernel(const CNodePtr & kernel_node)42   virtual void InitKernel(const CNodePtr &kernel_node) {}
Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> & workspace,const std::vector<AddressPtr> & outputs)43   virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
44                       const std::vector<AddressPtr> &outputs) {
45     return true;
46   }
47 
48   // Server kernel's memory allocation method, which is different from the workflow in
49   // Session(GPUSession/CPUSession/AscendSession).
50   // virtual void AssignMemory(const CNodePtr &kernel_node, std::shared_ptr<MemoryRegister> memory_register) = 0;
51 
52   // Set the cumulative count this aggregation kernel needs before aggregation is done.
set_done_count(size_t count)53   void set_done_count(size_t count) { done_count_ = count; }
54 
55   // So we use Reset to set the finite-state machine state to Initial after considering this round of aggregation is
56   // done.
57   virtual void Reset() = 0;
58 
59   virtual bool IsAggregationDone() = 0;
60 
61   // Some kernels should know the inputs/workspace/outputs addresses at initializing phase. For example, FedAvgKernel.
SetParameterAddress(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> & workspace,const std::vector<AddressPtr> & outputs)62   virtual void SetParameterAddress(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
63                                    const std::vector<AddressPtr> &outputs) {
64     return;
65   }
66 
67   // Reinitialize aggregation kernel after scaling operations are done.
ReInitForScaling()68   virtual bool ReInitForScaling() { return true; }
69 
ReInitForUpdatingHyperParams(size_t)70   virtual bool ReInitForUpdatingHyperParams(size_t) { return true; }
71 
72   // Setter and getter of kernels parameters information.
set_params_info(const ParamsInfo & params_info)73   void set_params_info(const ParamsInfo &params_info) { params_info_ = params_info; }
input_names()74   const std::vector<std::string> &input_names() { return params_info_.inputs_names(); }
workspace_names()75   const std::vector<std::string> &workspace_names() { return params_info_.workspace_names(); }
output_names()76   const std::vector<std::string> &output_names() { return params_info_.outputs_names(); }
77 
78   // Returns information about whether some inputs should reuse kernel node inputs memory.
reuse_kernel_node_inputs_info()79   const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info() { return reuse_kernel_node_inputs_info_; }
80 
81  protected:
82   virtual void GenerateReuseKernelNodeInfo() = 0;
83   // Aggregation kernel's name which is set by kernel register function.
84   std::string name_;
85 
86   // The aggregation is considered done after done_count_ times of accumulation.
87   bool done_;
88 
89   // Cumulative count this aggregation kernel needs before aggregation is done.
90   size_t done_count_;
91 
92   // Current cumulative count.
93   size_t accum_count_;
94 
95   // Parameters information used for kernel register, memory assignment, etc.
96   ParamsInfo params_info_;
97 
98   // Information about server kernel reusing kernel node inputs memory from the front end.
99   // Key refers to the server kernel's input index. Value refers to the kernel node's input index.
100   ReuseKernelNodeInfo reuse_kernel_node_inputs_info_;
101 };
102 }  // namespace kernel
103 }  // namespace server
104 }  // namespace fl
105 }  // namespace mindspore
106 #endif  // MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_
107