/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
#define MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "fl/server/common.h"
#include "fl/server/memory_register.h"
#include "fl/server/kernel/aggregation_kernel_factory.h"
#include "fl/server/kernel/optimizer_kernel_factory.h"

namespace mindspore {
namespace fl {
namespace server {
// Encapsulate the parameters for a kernel into a struct so that it is convenient for ParameterAggregator to launch
// server kernels.
typedef struct {
  std::vector<AddressPtr> inputs;
  std::vector<AddressPtr> workspace;
  std::vector<AddressPtr> outputs;
} KernelParams;

// ParameterAggregator includes methods for aggregating gradients and optimizing weights (launching aggregation and
// optimizer kernels), getting weights, etc. It is not thread-safe, so the caller must acquire a lock before calling
// ParameterAggregator methods concurrently.

// Each ParameterAggregator corresponds to one weight for now.

// ParameterAggregator is stateful because the process of aggregating and optimizing can be stateful.
// For example, the finite-state machine for a ParameterAggregator in parameter server training mode is:
// Initial->Aggregating->Aggregation done->Optimizing->Optimizing done->Pulling->Pull done->Initial.
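//
// A minimal usage sketch of one iteration of this state machine in parameter server mode. It is illustrative only:
// `cnode`, `kThresholdCount` and `push_data` are assumed to be provided by the caller, and error handling is omitted.
//
//   auto aggregator = std::make_shared<ParameterAggregator>();
//   (void)aggregator->Init(cnode, kThresholdCount);
//   (void)aggregator->UpdateData(push_data);  // Push new gradients: Initial -> Aggregating.
//   (void)aggregator->LaunchAggregators();    // Aggregating -> Aggregation done.
//   (void)aggregator->LaunchOptimizers();     // Optimizing -> Optimizing done.
//   AddressPtr weight = aggregator->Pull();   // Pulling: every call increases the pull count by 1.
//   if (aggregator->IsPullingDone()) {        // Pull done -> Initial: reset the status for the next iteration.
//     aggregator->ResetAggregationStatus();
//     aggregator->ResetOptimizingStatus();
//     aggregator->ResetPullingStatus();
//   }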
class ParameterAggregator {
 public:
  ParameterAggregator()
      : server_mode_(ServerMode::PARAMETER_SERVER),
        required_push_count_(0),
        required_pull_count_(0),
        current_pull_count_(0),
        aggregation_done_(false),
        optimizing_done_(false),
        pulling_done_(true),
        memory_register_(nullptr),
        requires_aggr_(true) {}
  ~ParameterAggregator() = default;

  // Initialize ParameterAggregator with a cnode. For now this cnode is normally an optimizer kernel.
  // The parameter threshold_count helps ParameterAggregator judge its current status if it is stateful.
  bool Init(const CNodePtr &cnode, size_t threshold_count = 0);

  // Reinitialize the parameter aggregator after scaling operations are done.
  bool ReInitForScaling();

  // After hyper-parameters are updated, some parameter aggregators should be reinitialized.
  bool ReInitForUpdatingHyperParams(size_t aggr_threshold);

  // Update old data stored in ParameterAggregator with new data.
  // The data could have many meanings: weights, gradients, learning rate, momentum, etc.
  bool UpdateData(const std::map<std::string, Address> &new_data);

  // Launch the aggregators/optimizers of this ParameterAggregator in order.
  bool LaunchAggregators();
  bool LaunchOptimizers();

  // The implementation of the primitive Pull in parameter server training mode.
  // Every call of this method increases the pull count by 1.
  AddressPtr Pull();

  // Different from the method Pull, this method simply returns the weight of this ParameterAggregator without causing
  // any status change.
  AddressPtr GetWeight();

  // After the aggregation/optimizing/pulling of one iteration is done, the caller must reset the status to ensure the
  // correctness of the aggregation/optimizing/pulling of the next iteration.
  void ResetAggregationStatus();
  void ResetOptimizingStatus();
  void ResetPullingStatus();

  // Return the aggregation/optimizing/pulling status to the caller.
  bool IsAggregationDone() const;
  bool IsOptimizingDone() const;
  bool IsPullingDone() const;

  // Return whether this parameter requires aggregation.
  bool requires_aggr() const;

 private:
  // Initialize aggregation/optimizer kernels based on the cnode. The reason for this is described in the file
  // kernel/kernel_factory.h.
  bool InitAggregationKernels(const CNodePtr &cnode);
  bool InitOptimizerKernels(const CNodePtr &cnode);

  // Assign memory for a server kernel K (AggregationKernel/OptimizerKernel).
  // The assigned memory can be accessed through MemoryRegister. The memory could be weights, gradients, learning
  // rate, momentum, etc.
  template <typename K>
  bool AssignMemory(const K server_kernel, const CNodePtr &cnode,
                    const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                    const std::shared_ptr<MemoryRegister> &memory_register);

  // Generate kernel parameters for aggregation/optimizer kernels. All the parameters are registered and stored in
  // memory_register.
  bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
                                       const std::shared_ptr<MemoryRegister> &memory_register);
  bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> &optim_kernel,
                                     const std::shared_ptr<MemoryRegister> &memory_register);

  // The selection of the aggregation algorithm depends on multiple factors: for example, server mode, user
  // configuration, etc.
  std::vector<std::string> SelectAggregationAlgorithm(const CNodePtr &cnode);

  // Judge whether the parameter needs to be aggregated.
  bool JudgeRequiresAggr(const CNodePtr &cnode);

  ServerMode server_mode_;
  size_t required_push_count_;
  size_t required_pull_count_;
  size_t current_pull_count_;

  // The status of aggregation/optimizing/pulling.
  bool aggregation_done_;
  bool optimizing_done_;
  bool pulling_done_;

  // ParameterAggregator stores all data that it needs for aggregation, optimizing, etc.
  std::shared_ptr<MemoryRegister> memory_register_;

  // An update could involve multiple aggregation and optimizer server kernels.
  // Here we store pairs of server kernels and the parameters of their Launch functions.
  std::vector<std::pair<std::shared_ptr<kernel::AggregationKernel>, KernelParams>> aggregation_kernel_parameters_;
  std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_;

  // Whether this parameter needs to be aggregated.
  bool requires_aggr_;
};
}  // namespace server
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
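// Implementation note: a minimal sketch of how a stored (kernel, KernelParams) pair is expected to be launched,
// assuming the server kernels expose the usual Launch(inputs, workspace, outputs) entry point declared in
// fl/server/kernel. This is illustrative, not a verbatim excerpt of LaunchAggregators:
//
//   for (auto &pair : aggregation_kernel_parameters_) {
//     const std::shared_ptr<kernel::AggregationKernel> &kernel = pair.first;
//     const KernelParams &params = pair.second;
//     if (!kernel->Launch(params.inputs, params.workspace, params.outputs)) {
//       return false;  // Abort if any aggregation kernel fails.
//     }
//   }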