• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
18 #define MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 #include <utility>
25 #include "fl/server/common.h"
26 #include "fl/server/memory_register.h"
27 #include "fl/server/kernel/aggregation_kernel_factory.h"
28 #include "fl/server/kernel/optimizer_kernel_factory.h"
29 
30 namespace mindspore {
31 namespace fl {
32 namespace server {
33 // Encapsulate the parameters for a kernel into a struct to make it convenient for ParameterAggregator to launch server
34 // kernels.
35 typedef struct {
36   std::vector<AddressPtr> inputs;
37   std::vector<AddressPtr> workspace;
38   std::vector<AddressPtr> outputs;
39 } KernelParams;
40 
41 // ParameterAggregator includes methods for aggregating gradients and optimizing weights(launching aggregation and
42 // optimizer kernels), getting weights, etc. It's not thread-safe, which means the caller must acquire lock before
43 // calling ParameterAggregator methods concurrently.
44 
45 // Each ParameterAggregator is corresponding to one weight for now.
46 
47 // ParameterAggregator is stateful because the process of aggregation and optimizing could be stateful.
48 // For example, the finite-state machine for the ParameterAggregator in parameter server training mode is below:
49 // Initial->Aggregating->Aggregation done->Optimizing->Optimizing done->Pulling->Pull done->Initial.
50 class ParameterAggregator {
51  public:
ParameterAggregator()52   ParameterAggregator()
53       : server_mode_(ServerMode::PARAMETER_SERVER),
54         required_push_count_(0),
55         required_pull_count_(0),
56         current_pull_count_(0),
57         aggregation_done_(false),
58         optimizing_done_(false),
59         pulling_done_(true),
60         memory_register_(nullptr),
61         requires_aggr_(true) {}
62   ~ParameterAggregator() = default;
63 
64   // Initialize ParameterAggregator with a cnode. This cnode is normally a optimizer kernel for now.
65   // The parameter threshold_count helps ParameterAggregator to judge the current status if it's stateful.
66   bool Init(const CNodePtr &cnode, size_t threshold_count = 0);
67 
68   // Reinitialize the parameter aggregator after scaling operations are done.
69   bool ReInitForScaling();
70 
71   // After hyper-parameters are updated, some parameter aggregators should be reinitialized.
72   bool ReInitForUpdatingHyperParams(size_t aggr_threshold);
73 
74   // Update old data stored in ParameterAggregator with new data.
75   // The data could have many meanings: weights, gradients, learning_rate, momentum, etc.
76   bool UpdateData(const std::map<std::string, Address> &new_data);
77 
78   // Launch aggregators/optimizers of this ParameterAggregator in order.
79   bool LaunchAggregators();
80   bool LaunchOptimizers();
81 
82   // The implementation for primitive Pull in parameter server training mode.
83   // Every call of this method will increase the count for pull by 1.
84   AddressPtr Pull();
85 
86   // Different from the method Pull, this method simply returns the weight of this ParameterAggregator without causing
87   // any change of status.
88   AddressPtr GetWeight();
89 
90   // After aggregation/optimizing/pulling of one iteration is done, caller must reset the status to ensure the
91   // correctness of the aggregation/optimizing/pulling for next iteration.
92   void ResetAggregationStatus();
93   void ResetOptimizingStatus();
94   void ResetPullingStatus();
95 
96   // Returns the aggregation/optimizing/pulling status to the caller.
97   bool IsAggregationDone() const;
98   bool IsOptimizingDone() const;
99   bool IsPullingDone() const;
100 
101   // Return whether this parameter requires aggragation.
102   bool requires_aggr() const;
103 
104  private:
105   // Initializing aggregation/optimizer kenerls based on the cnode. The reason of this is described in the file
106   // kernel/kernel_factory.h.
107   bool InitAggregationKernels(const CNodePtr &cnode);
108   bool InitOptimizerKernels(const CNodePtr &cnode);
109 
110   // Assign memory for server kernel K(AggregationKernel/OptimizerKernel).
111   // The memory assigned can be accessed by MemoryRegister. The memory could be weights, gradients, learning_rate,
112   // momentum, etc.
113   template <typename K>
114   bool AssignMemory(const K server_kernel, const CNodePtr &cnode,
115                     const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
116                     const std::shared_ptr<MemoryRegister> &memory_register);
117 
118   // Generate kernel parameters for aggregation/optimizer kernels. All the parameters is registered and stored in
119   // memory_register.
120   bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
121                                        const std::shared_ptr<MemoryRegister> &memory_register);
122   bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> &optim_kernel,
123                                      const std::shared_ptr<MemoryRegister> &memory_register);
124 
125   // The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user
126   // configuration, etc.
127   std::vector<std::string> SelectAggregationAlgorithm(const CNodePtr &cnode);
128 
129   // Judge whether the parameter needs to be aggregated.
130   bool JudgeRequiresAggr(const CNodePtr &cnode);
131 
132   ServerMode server_mode_;
133   size_t required_push_count_;
134   size_t required_pull_count_;
135   size_t current_pull_count_;
136 
137   // The status of aggregation/optimizing/pulling.
138   bool aggregation_done_;
139   bool optimizing_done_;
140   bool pulling_done_;
141 
142   // ParameterAggregator stores all data that it needs for aggregation, optimizing, etc.
143   std::shared_ptr<MemoryRegister> memory_register_;
144 
145   // Update could have multiple aggregation and optimizer server kernels.
146   // Here stores multiple pairs of server kernels to parameters of their Launch function.
147   std::vector<std::pair<std::shared_ptr<kernel::AggregationKernel>, KernelParams>> aggregation_kernel_parameters_;
148   std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_;
149 
150   // Whether this parameter needs to be aggregated.
151   bool requires_aggr_;
152 };
153 }  // namespace server
154 }  // namespace fl
155 }  // namespace mindspore
156 #endif  // MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
157