/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_OPTIMIZER_INFO_H_
#define MINDSPORE_CCSRC_PS_OPTIMIZER_INFO_H_

#include <vector>
#include <string>
#include "backend/kernel_compiler/kernel.h"
#include "ps/constants.h"

namespace mindspore {
namespace ps {
using mindspore::kernel::AddressPtr;
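// Base class describing one optimizer kernel instance on the parameter server:
// it owns the kernel's input/workspace/output address lists and defines how
// gradient payloads from workers are accumulated, averaged, and reset between
// training iterations.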
class OptimizerInfo {
 public:
  OptimizerInfo() = default;
  virtual ~OptimizerInfo() = default;

  // Update optimizer inputs (e.g. hyper-parameters such as the learning rate)
  // with values received from workers.
  virtual void Update(const Values &values, const Lengths &lengths) {}
  // Accumulate one worker's gradient payload into this info's buffers.
  virtual void Accumulate(const Values &values, const Lengths &lengths) = 0;
  // Average the accumulated gradients over the n contributing workers.
  virtual void ComputeMean(const std::vector<std::vector<size_t>> &shapes, size_t n, size_t server_num,
                           size_t rank_id) {}
  // Clear accumulation state before the next iteration.
  virtual void Reset() {}
  void AddWorkspace(const AddressPtr &workspace);

  virtual const AddressPtr &gradient() = 0;
  virtual const AddressPtr &indices() = 0;
  virtual size_t indice_size() const;
  const std::vector<AddressPtr> &inputs() const;
  const std::vector<AddressPtr> &workspaces() const;
  const std::vector<AddressPtr> &outputs() const;

  virtual bool IsSparse() const;
  virtual size_t grad_index();
  virtual size_t indices_index();

 protected:
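  // Helper for Update() overrides: writes one named input of the given
  // optimizer type from the received data buffer.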
  template <typename T>
  void UpdateOptimInputValue(const std::string &optim_type, const std::string &input_name, void *data,
                             const Lengths &lens);
  std::vector<AddressPtr> inputs_;
  std::vector<AddressPtr> workspaces_;
  std::vector<AddressPtr> outputs_;
};
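
// Dense optimizer info: every worker contributes a full-sized gradient, which
// is accumulated element-wise and averaged in ComputeMean.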
class DenseOptimInfo : public OptimizerInfo {
 public:
  DenseOptimInfo() = default;
  ~DenseOptimInfo() override = default;

  void Accumulate(const Values &values, const Lengths &lens) override;
  void ComputeMean(const std::vector<std::vector<size_t>> &shapes, size_t n, size_t server_num,
                   size_t rank_id) override;
  void Reset() override;
};
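
// Sparse optimizer info: workers send gradient values together with the row
// indices they apply to; grads_offset_ and indices_offset_ track the write
// positions as successive payloads arrive.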
class SparseOptimInfo : public OptimizerInfo {
 public:
  SparseOptimInfo() = default;
  ~SparseOptimInfo() override = default;

  void Accumulate(const Values &values, const Lengths &lens) override;
  void ComputeMean(const std::vector<std::vector<size_t>> &shapes, size_t n, size_t server_num,
                   size_t rank_id) override;
  void Reset() override;
  size_t indice_size() const override;

 protected:
  size_t grads_offset_{0};
  size_t indices_offset_{0};
  bool sharded_{true};
};
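
// Holds the inputs of a momentum optimizer kernel: weight, accumulation,
// learning rate, gradient, and momentum.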
class MomentumOptimInfo : public DenseOptimInfo {
 public:
  MomentumOptimInfo(const AddressPtr &weight, const AddressPtr &accumulate, const AddressPtr &learning_rate,
                    const AddressPtr &gradient, const AddressPtr &momentum);
  ~MomentumOptimInfo() override = default;

  void Update(const Values &values, const Lengths &lens) override;
  const AddressPtr &gradient() override;
  const AddressPtr &indices() override;
  size_t grad_index() override;
};
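
// Holds the inputs of a sparse Adam kernel: weight, first/second moment
// estimates, beta powers, hyper-parameters, and the gradient/indices pair.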
class SparseAdamOptimInfo : public SparseOptimInfo {
 public:
  SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, const AddressPtr &beta1_power,
                      const AddressPtr &beta2_power, const AddressPtr &learning_rate, const AddressPtr &beta1,
                      const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad,
                      const AddressPtr &indices, bool sharded);
  ~SparseAdamOptimInfo() override = default;

  void Update(const Values &values, const Lengths &lens) override;
  const AddressPtr &gradient() override;
  const AddressPtr &indices() override;
  bool IsSparse() const override;
  size_t grad_index() override;
  size_t indices_index() override;
};
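
// Holds the inputs of a sparse FTRL kernel: weight, accumulator, linear term,
// and the gradient/indices pair.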
class SparseFtrlOptimInfo : public SparseOptimInfo {
 public:
  SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear,
                      const AddressPtr &grad, const AddressPtr &indices, bool sharded);
  ~SparseFtrlOptimInfo() override = default;

  const AddressPtr &gradient() override;
  const AddressPtr &indices() override;
  bool IsSparse() const override;
  size_t grad_index() override;
  size_t indices_index() override;
};
}  // namespace ps
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PS_OPTIMIZER_INFO_H_
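
// A minimal usage sketch (an assumption based on this interface, not code from
// the repository): a parameter server builds a concrete info object around the
// optimizer kernel's input addresses, folds in each worker's payload, averages,
// launches the kernel, and resets for the next round.
//
//   MomentumOptimInfo info(weight, accum, lr, grad, momentum);  // AddressPtr inputs
//   info.Accumulate(values, lens);   // add one worker's gradient into grad
//   info.ComputeMean(shapes, worker_num, server_num, rank_id);  // average over workers
//   // ... launch the optimizer kernel with info.inputs() and info.workspaces() ...
//   info.Reset();                    // clear accumulation for the next iteration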