/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/optimizer_info_builder.h"
#include <vector>
#include <memory>
#include <functional>
#include <numeric>
#include <string>
#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"

namespace mindspore {
namespace ps {
using mindspore::kernel::ps::SparseApplyFtrlPSKernel;
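
// Build() assembles a complete OptimizerInfo for one weight on the parameter server:
// the subclass-specific BuildInputs() slices the optimizer inputs out of the flat
// `values` buffer, then the workspaces and outputs reported by the PS kernel are attached.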
OptimizerInfo *OptimizerInfoBuilder::Build(const std::shared_ptr<PServerKernel> &pserver_kernel,
                                           const WeightPtr &weight, const Keys &keys, const Values &values,
                                           const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num,
                                           bool sharded) {
  MS_EXCEPTION_IF_NULL(pserver_kernel);
  MS_EXCEPTION_IF_NULL(weight);
  MS_EXCEPTION_IF_NULL(inputs_shape);
  OptimizerInfo *optim_info =
    BuildInputs(weight, keys, values, lens, inputs_shape, worker_num, pserver_kernel, sharded);
  MS_EXCEPTION_IF_NULL(optim_info);
  std::vector<size_t> ws_sizes = pserver_kernel->workspace_sizes();
  BuildWorkspaces(optim_info, ws_sizes, worker_num);
  BuildOutputs(optim_info, worker_num);
  return optim_info;
}

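// Allocates one raw buffer per workspace size reported by the kernel. Note the buffer is
// allocated as float[size] even though `size` is presumably a byte count; this
// over-allocates rather than truncates, so it is safe, if wasteful.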
void OptimizerInfoBuilder::BuildWorkspaces(OptimizerInfo *info, const std::vector<size_t> &ws_sizes, size_t) {
  MS_EXCEPTION_IF_NULL(info);
  for (size_t i = 0; i < ws_sizes.size(); i++) {
    size_t size = ws_sizes[i];
    AddressPtr workspace = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(workspace);
    workspace->addr = new float[size];
    MS_EXCEPTION_IF_NULL(workspace->addr);
    workspace->size = size;
    info->AddWorkspace(workspace);
  }
}

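// GenInputAddrPtr slices a single optimizer input (e.g. "lr" or "grad") out of the flat
// ps_data buffer received from the workers. The input's offset is the sum of the lengths
// of all inputs sent before it; its size comes from ps_lens unless inputs_shape is
// passed, in which case the buffer is sized from the origin shape times worker_num_ so
// it can hold the data gathered from every worker.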
template <typename T>
AddressPtr OptimizerInfoBuilder::GenInputAddrPtr(const std::string &optim_type, const std::string &input_name,
                                                 void *ps_data, const Lengths &ps_lens,
                                                 const InputsShapePtr &inputs_shape) {
  MS_EXCEPTION_IF_NULL(ps_data);
  // Note that the data types in ps_data may be inconsistent.
  MS_LOG(INFO) << "Get input address pointer for optimizer:" << optim_type << ", input name:" << input_name;
  AddressPtr addr_ptr = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(addr_ptr);

  if (kOptimToOriginIdx.count(optim_type) == 0 || kOptimToPSSendIdx.count(optim_type) == 0) {
    MS_LOG(EXCEPTION) << "Optimizer type " << optim_type << " is not supported.";
  }
  const OptimOriginIdx &origin_input_map = kOptimToOriginIdx.at(optim_type);
  const OptimPSSendIdx &ps_send_index_map = kOptimToPSSendIdx.at(optim_type);
  if (ps_send_index_map.count(input_name) == 0 || origin_input_map.count(input_name) == 0) {
    MS_LOG(EXCEPTION) << "Optimizer " << optim_type << " has no input for " << input_name;
  }
  size_t ps_index = ps_send_index_map.at(input_name);
  if (ps_index == INDEX_NOT_SEND) {
    MS_LOG(EXCEPTION) << "Input " << input_name << " is not supposed to be sent to PS.";
  }

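  // For sparse inputs ("grad"/"indices"), callers pass inputs_shape so the buffer can
  // hold the concatenated data from every worker; worker_num_ as the initial accumulator
  // value multiplies the origin shape accordingly.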
  size_t addr_data_size = 0;
  size_t addr_data_offset = 0;
  if (inputs_shape != nullptr) {
    // addr_data_size should be calculated from inputs_shape if it's passed.
    size_t origin_index = origin_input_map.at(input_name);
    EXC_IF_VEC_IDX_OOB((*inputs_shape), origin_index);
    MS_EXCEPTION_IF_NULL((*inputs_shape)[origin_index]);
    auto shape = *((*inputs_shape)[origin_index]);
    addr_data_size = std::accumulate(shape.begin(), shape.end(), worker_num_, std::multiplies<size_t>());
  } else {
    EXC_IF_VEC_IDX_OOB(ps_lens, ps_index);
    addr_data_size = IntToSize(ps_lens[ps_index]);
  }
  addr_data_offset =
    IntToSize(std::accumulate(ps_lens.begin(), ps_lens.begin() + SizeToInt(ps_index), 0, std::plus<int>()));

  // The size of the real data is given by ps_lens, not addr_data_size.
  T *buffer = new T[addr_data_size];
  addr_ptr->size = IntToSize(ps_lens[ps_index]) * sizeof(T);
  addr_ptr->addr = buffer;

  size_t dst_size = addr_ptr->size;
  size_t src_size = addr_ptr->size;
  void *dst_data = addr_ptr->addr;
  void *src_data = reinterpret_cast<T *>(ps_data) + addr_data_offset;
  MS_EXCEPTION_IF_NULL(dst_data);
  MS_EXCEPTION_IF_NULL(src_data);
  int64_t ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret != 0) {
    // Log an error rather than an exception so the buffer is actually released before
    // returning; MS_LOG(EXCEPTION) throws, which made this cleanup unreachable.
    MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
    delete[] buffer;
    buffer = nullptr;
    return nullptr;
  }
  return addr_ptr;
}

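// ApplyMomentum inputs: the weight is referenced in place, the accumulator is a freshly
// zeroed buffer of the same size, and lr/grad/momentum are copied out of the flat values
// buffer received from the workers.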
OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &, const Values &values,
                                                     const Lengths &lens, const InputsShapePtr &, size_t,
                                                     const std::shared_ptr<PServerKernel> &, bool) {
  MS_EXCEPTION_IF_NULL(weight);
  AddressPtr weight_addr = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(weight_addr);
  weight_addr->addr = weight->data();
  weight_addr->size = weight->size() * sizeof(float);

  AddressPtr accumulate = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(accumulate);

  accumulate->addr = new float[weight->size()];
  MS_EXCEPTION_IF_NULL(accumulate->addr);
  accumulate->size = weight->size() * sizeof(float);
  int64_t ret = memset_s(accumulate->addr, accumulate->size, 0x00, accumulate->size);
  if (ret != 0) {
    // Log an error (not an exception) so the buffer can be released before returning.
    MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";
    delete[] reinterpret_cast<float *>(accumulate->addr);
    accumulate->addr = nullptr;
    return nullptr;
  }

  AddressPtr learning_rate = GenInputAddrPtr<float>(kApplyMomentum, "lr", const_cast<float *>(values.data()), lens);
  MS_EXCEPTION_IF_NULL(learning_rate);
  AddressPtr gradient = GenInputAddrPtr<float>(kApplyMomentum, "grad", const_cast<float *>(values.data()), lens);
  MS_EXCEPTION_IF_NULL(gradient);
  AddressPtr momentum = GenInputAddrPtr<float>(kApplyMomentum, "momentum", const_cast<float *>(values.data()), lens);
  MS_EXCEPTION_IF_NULL(momentum);
  return new MomentumOptimInfo(weight_addr, accumulate, learning_rate, gradient, momentum);
}

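// Sparse Adam inputs: m and v are zero-initialized server-side state buffers; the scalar
// hyper-parameters (beta1_power, beta2_power, lr, beta1, beta2, eps) and the sparse
// grad/indices pair are sliced out of the received values. grad and indices pass
// inputs_shape so their buffers can hold the gradients gathered from all workers.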
OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &, const Values &values,
                                                       const Lengths &lens, const InputsShapePtr &inputs_shape, size_t,
                                                       const std::shared_ptr<PServerKernel> &, bool sharded) {
  MS_EXCEPTION_IF_NULL(weight);
  AddressPtr weight_addr = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(weight_addr);
  weight_addr->addr = weight->data();
  weight_addr->size = weight->size() * sizeof(float);

  AddressPtr m = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(m);

  m->addr = new float[weight->size()];
  MS_EXCEPTION_IF_NULL(m->addr);
  m->size = weight->size() * sizeof(float);
  int64_t ret = memset_s(m->addr, m->size, 0x00, m->size);
  if (ret != 0) {
    MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";
    delete[] reinterpret_cast<float *>(m->addr);
    m->addr = nullptr;
    return nullptr;
  }

  AddressPtr v = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(v);

  v->addr = new float[weight->size()];
  MS_EXCEPTION_IF_NULL(v->addr);
  v->size = weight->size() * sizeof(float);
  ret = memset_s(v->addr, v->size, 0x00, v->size);
  if (ret != 0) {
    // Release both state buffers on failure; m was allocated first.
    MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";
    delete[] reinterpret_cast<float *>(v->addr);
    v->addr = nullptr;
    delete[] reinterpret_cast<float *>(m->addr);
    m->addr = nullptr;
    return nullptr;
  }

  AddressPtr beta1_power = GenInputAddrPtr<float>(kSparseAdam, "beta1_power", const_cast<float *>(values.data()), lens);
  MS_EXCEPTION_IF_NULL(beta1_power);
  AddressPtr beta2_power = GenInputAddrPtr<float>(kSparseAdam, "beta2_power", const_cast<float *>(values.data()), lens);
  MS_EXCEPTION_IF_NULL(beta2_power);
  AddressPtr learning_rate = GenInputAddrPtr<float>(kSparseAdam, "lr", const_cast<float *>(values.data()), lens);
  MS_EXCEPTION_IF_NULL(learning_rate);
  AddressPtr beta1 = GenInputAddrPtr<float>(kSparseAdam, "beta1", const_cast<float *>(values.data()), lens);
  MS_EXCEPTION_IF_NULL(beta1);
  AddressPtr beta2 = GenInputAddrPtr<float>(kSparseAdam, "beta2", const_cast<float *>(values.data()), lens);
  MS_EXCEPTION_IF_NULL(beta2);
  AddressPtr epsilon = GenInputAddrPtr<float>(kSparseAdam, "eps", const_cast<float *>(values.data()), lens);
  MS_EXCEPTION_IF_NULL(epsilon);
  AddressPtr grad = GenInputAddrPtr<float>(kSparseAdam, "grad", const_cast<float *>(values.data()), lens, inputs_shape);
  MS_EXCEPTION_IF_NULL(grad);
  AddressPtr indices =
    GenInputAddrPtr<float>(kSparseAdam, "indices", const_cast<float *>(values.data()), lens, inputs_shape);
  MS_EXCEPTION_IF_NULL(indices);
  return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon,
                                 grad, indices, sharded);
}

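// Sparse FTRL inputs: accum is filled with the kernel's configured init_accum value,
// linear starts zeroed, and grad/indices are sliced from the received values with
// buffers sized to hold the data from all workers.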
OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &, const Values &values,
                                                       const Lengths &lens, const InputsShapePtr &inputs_shape, size_t,
                                                       const std::shared_ptr<PServerKernel> &pserver_kernel,
                                                       bool sharded) {
  MS_EXCEPTION_IF_NULL(inputs_shape);
  MS_EXCEPTION_IF_NULL(weight);
  AddressPtr weight_addr = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(weight_addr);
  weight_addr->addr = weight->data();
  weight_addr->size = weight->size() * sizeof(float);

  AddressPtr accum = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(accum);

  accum->addr = new float[weight->size()];
  MS_EXCEPTION_IF_NULL(accum->addr);
  accum->size = weight->size() * sizeof(float);
  // Check the cast once and hoist init_accum() out of the loop instead of re-casting and
  // re-reading it on every iteration.
  auto ftrl_kernel = std::dynamic_pointer_cast<SparseApplyFtrlPSKernel>(pserver_kernel);
  MS_EXCEPTION_IF_NULL(ftrl_kernel);
  const float init_accum = ftrl_kernel->init_accum();
  float *accum_data = reinterpret_cast<float *>(accum->addr);
  for (size_t i = 0; i < weight->size(); i++) {
    accum_data[i] = init_accum;
  }

  AddressPtr linear = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(linear);

  linear->addr = new float[weight->size()];
  MS_EXCEPTION_IF_NULL(linear->addr);
  linear->size = weight->size() * sizeof(float);
  int64_t ret = memset_s(linear->addr, linear->size, 0x00, linear->size);
  if (ret != 0) {
    MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";
    delete[] reinterpret_cast<float *>(linear->addr);
    linear->addr = nullptr;
    return nullptr;
  }

  AddressPtr grad = GenInputAddrPtr<float>(kSparseFtrl, "grad", const_cast<float *>(values.data()), lens, inputs_shape);
  MS_EXCEPTION_IF_NULL(grad);
  AddressPtr indices =
    GenInputAddrPtr<float>(kSparseFtrl, "indices", const_cast<float *>(values.data()), lens, inputs_shape);
  MS_EXCEPTION_IF_NULL(indices);
  return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, sharded);
}
}  // namespace ps
}  // namespace mindspore