1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_OPTIMIZER_INFO_H_ 18 #define MINDSPORE_CCSRC_PS_OPTIMIZER_INFO_H_ 19 20 #include <vector> 21 #include <string> 22 #include "backend/kernel_compiler/kernel.h" 23 #include "ps/constants.h" 24 25 namespace mindspore { 26 namespace ps { 27 using mindspore::kernel::AddressPtr; 28 class OptimizerInfo { 29 public: 30 OptimizerInfo() = default; 31 virtual ~OptimizerInfo() = default; 32 Update(const Values & values,const Lengths & lengths)33 virtual void Update(const Values &values, const Lengths &lengths) {} 34 virtual void Accumulate(const Values &values, const Lengths &lengths) = 0; ComputeMean(const std::vector<std::vector<size_t>> & shapes,size_t n,size_t server_num,size_t rank_id)35 virtual void ComputeMean(const std::vector<std::vector<size_t>> &shapes, size_t n, size_t server_num, 36 size_t rank_id) {} Reset()37 virtual void Reset() {} 38 void AddWorkspace(const AddressPtr &workspace); 39 40 virtual const AddressPtr &gradient() = 0; 41 virtual const AddressPtr &indices() = 0; 42 virtual const size_t indice_size() const; 43 const std::vector<AddressPtr> &inputs() const; 44 const std::vector<AddressPtr> &workspaces() const; 45 const std::vector<AddressPtr> &outputs() const; 46 47 virtual bool IsSparse() const; 48 virtual size_t grad_index(); 49 virtual size_t indices_index(); 50 51 protected: 52 template <typename T> 53 void UpdateOptimInputValue(const std::string &optim_type, const std::string &input_name, void *data, 54 const Lengths &lens); 55 std::vector<AddressPtr> inputs_; 56 std::vector<AddressPtr> workspaces_; 57 std::vector<AddressPtr> outputs_; 58 }; 59 60 class DenseOptimInfo : public OptimizerInfo { 61 public: 62 DenseOptimInfo() = default; 63 ~DenseOptimInfo() override = default; 64 65 void Accumulate(const Values &values, const Lengths &lens) override; 66 void ComputeMean(const std::vector<std::vector<size_t>> &shapes, size_t n, size_t server_num, 67 size_t rank_id) override; 68 void Reset() override; 69 }; 70 71 class SparseOptimInfo : public OptimizerInfo { 72 public: 73 SparseOptimInfo() = default; 74 ~SparseOptimInfo() override = default; 75 76 void Accumulate(const Values &values, const Lengths &lens) override; 77 void ComputeMean(const std::vector<std::vector<size_t>> &shapes, size_t n, size_t server_num, 78 size_t rank_id) override; 79 void Reset() override; 80 const size_t indice_size() const override; 81 82 protected: 83 size_t grads_offset_{0}; 84 size_t indices_offset_{0}; 85 bool sharded_{true}; 86 }; 87 88 class MomentumOptimInfo : public DenseOptimInfo { 89 public: 90 MomentumOptimInfo(const AddressPtr &weight, const AddressPtr &accumulate, const AddressPtr &learning_rate, 91 const AddressPtr &gradient, const AddressPtr &momentum); 92 ~MomentumOptimInfo() override = default; 93 94 void Update(const Values &values, const Lengths &lens) override; 95 const AddressPtr &gradient(); 96 const AddressPtr &indices(); 97 size_t grad_index() override; 98 }; 99 100 class SparseAdamOptimInfo : public SparseOptimInfo { 101 public: 102 SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, const AddressPtr &beta1_power, 103 const AddressPtr &beta2_power, const AddressPtr &learning_rate, const AddressPtr &beta1, 104 const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad, 105 const AddressPtr &indices, bool sharded); 106 ~SparseAdamOptimInfo() override = default; 107 108 void Update(const Values &values, const Lengths &lens) override; 109 const AddressPtr &gradient(); 110 const AddressPtr &indices(); 111 bool IsSparse() const override; 112 size_t grad_index() override; 113 size_t indices_index() override; 114 }; 115 116 class SparseFtrlOptimInfo : public SparseOptimInfo { 117 public: 118 SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear, 119 const AddressPtr &grad, const AddressPtr &indices, bool sharded); 120 ~SparseFtrlOptimInfo() override = default; 121 122 const AddressPtr &gradient(); 123 const AddressPtr &indices(); 124 bool IsSparse() const override; 125 size_t grad_index() override; 126 size_t indices_index() override; 127 }; 128 } // namespace ps 129 } // namespace mindspore 130 #endif // MINDSPORE_CCSRC_PS_OPTIMIZER_INFO_H_ 131