• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "plugin/device/cpu/kernel/fused_ada_factor_cpu_kernel.h"
#include <algorithm>
#include <cmath>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "mindspore/core/ops/fused_ada_factor.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
21 
22 namespace mindspore {
23 namespace kernel {
namespace {
// Byte sizes used when validating input/workspace buffer sizes.
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kSizeFloat16 = sizeof(float16);
// Scalar inputs hold a single value at offset 0 of their buffer.
constexpr size_t kScalarIndex = 0;
// The op takes 12 inputs; the WithGlobalNorm variant appends one more.
constexpr size_t kStandardInputNum = 12;
constexpr size_t kWorkSpaceNum = 3;
// Minimum per-thread chunk size handed to ParallelFor / the RMS reduction.
constexpr size_t kBatchSize = 1000;
// Offsets (counted from the back of the shape vector) of the two innermost dims.
constexpr size_t kLastRowIndex = 1;
constexpr size_t kLastColIndex = 2;
// Lower clamp keeping divisions and sqrt away from zero.
constexpr float kEps = 1e-30;
// Input tensor ordering.
constexpr size_t kEpsIndex = 0;
constexpr size_t kClipThresholdIndex = 1;
constexpr size_t kBeta1Index = 2;
constexpr size_t kBeta2tIndex = 3;
constexpr size_t kWeightDecayIndex = 4;
constexpr size_t kLearningRateIndex = 5;
constexpr size_t kGradIndex = 6;
constexpr size_t kParamIndex = 7;
constexpr size_t kExpAvgIndex = 8;
constexpr size_t kExpAvgSQRowIndex = 9;
constexpr size_t kExpAvgSQColIndex = 10;
constexpr size_t kExpAvgSQIndex = 11;
constexpr size_t kGlobalNormIndex = 12;  // only present in the WithGlobalNorm variant
// Workspace buffer ordering (sizes are set in Resize()).
constexpr size_t kWorkSpaceUpdateIndex = 0;
constexpr size_t kWorkSpaceRFactorIndex = 1;
constexpr size_t kWorkSpaceCFactorIndex = 2;
// Primitive attribute names read in Init().
auto constexpr kEnableScaleParameter = "enable_scale_parameter";
auto constexpr kEnableFirstMoment = "enable_first_moment";
auto constexpr kEnableWeightDecay = "enable_weight_decay";
}  // namespace
54 
Init(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)55 bool FusedAdaFactorCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs,
56                                       const std::vector<KernelTensor *> &outputs) {
57   param_dtype_ = inputs[kParamIndex]->dtype_id();
58 
59   enable_scale_parameter_ = GetValue<bool>(primitive_->GetAttr(kEnableScaleParameter));
60   enable_first_moment_ = GetValue<bool>(primitive_->GetAttr(kEnableFirstMoment));
61   enable_weight_decay_ = GetValue<bool>(primitive_->GetAttr(kEnableWeightDecay));
62   return true;
63 }
64 
Resize(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)65 int FusedAdaFactorCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs,
66                                        const std::vector<KernelTensor *> &outputs) {
67   auto ret = KernelMod::Resize(inputs, outputs);
68   if (ret != 0) {
69     return ret;
70   }
71 
72   auto shape = inputs[kParamIndex]->GetShapeVector();
73   elem_num_ = std::accumulate(shape.begin(), shape.end(), 1UL, std::multiplies<size_t>());
74   if (elem_num_ < 1) {
75     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the elem num of 'param' can not be zero.";
76   }
77   if (shape.size() >= kLastColIndex) {
78     need_factor_ = true;
79     last_row_dim_size_ = LongToSize(shape[shape.size() - kLastRowIndex]);
80     last_col_dim_size_ = LongToSize(shape[shape.size() - kLastColIndex]);
81     if (last_row_dim_size_ < 1 || last_col_dim_size_ < 1) {
82       MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of 'param' can not be zero.";
83     }
84   }
85 
86   workspace_size_list_.clear();
87   (void)workspace_size_list_.emplace_back(elem_num_ * kSizeFloat32);
88   (void)workspace_size_list_.emplace_back((elem_num_ / last_row_dim_size_) * kSizeFloat32);
89   (void)workspace_size_list_.emplace_back((elem_num_ / last_col_dim_size_) * kSizeFloat32);
90   return KRET_OK;
91 }
92 
template <typename T>
// Root-mean-square of `input` over `elem_num` elements, computed in parallel.
// The range is split into per-thread chunks whose partial square-sums are
// combined afterwards, so the result is deterministic for a fixed thread count.
float FusedAdaFactorCpuKernelMod::CalcRMS(const T *input, size_t elem_num) const {
  if (elem_num == 0 || input == nullptr) {
    return 0.0f;
  }
  auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
  // Use fewer threads than the pool maximum for small inputs, so each thread
  // handles at least roughly kBatchSize elements.
  size_t thread_num =
    elem_num < kBatchSize * max_thread_num ? (elem_num + kBatchSize - 1) / kBatchSize : max_thread_num;
  std::vector<common::Task> tasks;
  size_t batch_size = (elem_num + thread_num - 1) / thread_num;
  std::vector<float> block_sum(thread_num, 0.0f);
  for (size_t thread_id = 0; thread_id < thread_num; ++thread_id) {
    size_t start = batch_size * thread_id;
    // The last chunk may be short (or empty) due to rounding; clamp to elem_num.
    size_t end = (start + batch_size) > elem_num ? elem_num : (start + batch_size);
    // Capture start/end/thread_id by value: the loop variables change before the
    // tasks actually run.
    auto block = [&, start, end, thread_id]() {
      float square_sum = 0;
      for (size_t i = start; i < end; ++i) {
        auto tmp = static_cast<float>(input[i]);
        square_sum += tmp * tmp;
      }
      block_sum[thread_id] = square_sum;  // each task writes its own slot: no race
      return common::SUCCESS;
    };
    (void)tasks.emplace_back(block);
  }
  ParallelLaunch(tasks);
  auto rms = std::accumulate(block_sum.begin(), block_sum.end(), 0.0f);
  rms = rms / elem_num;
  return std::sqrt(rms);
}
123 
template <typename T>
// Factored second-moment path, used when 'param' is at least 2-D.
// On entry `update` holds (grad * global_norm_reciprocal_)^2 + eps for every
// element (filled by LaunchKernel). Instead of keeping a full exp_avg_sq
// matrix, this maintains running averages of the row and column means of that
// buffer, and rebuilds the per-element normaliser from the two factors.
// On exit `update[i]` holds grad[i] * global_norm_reciprocal_ / max(norm, kEps).
void FusedAdaFactorCpuKernelMod::FactorUpdate(float *update, const std::vector<KernelTensor *> &inputs,
                                              const std::vector<KernelTensor *> &workspaces) const {
  auto beta2t = reinterpret_cast<float *>(inputs[kBeta2tIndex]->device_ptr())[kScalarIndex];
  auto grad = reinterpret_cast<T *>(inputs[kGradIndex]->device_ptr());
  auto exp_avg_sq_row = reinterpret_cast<T *>(inputs[kExpAvgSQRowIndex]->device_ptr());
  auto exp_avg_sq_col = reinterpret_cast<T *>(inputs[kExpAvgSQColIndex]->device_ptr());
  auto r_factor = reinterpret_cast<float *>(workspaces[kWorkSpaceRFactorIndex]->device_ptr());
  auto c_factor = reinterpret_cast<float *>(workspaces[kWorkSpaceCFactorIndex]->device_ptr());
  auto one_minus_beta2t = 1 - beta2t;

  std::function<void(size_t, size_t)> task;
  // "row" here is the last (innermost) dimension, "col" the second-to-last.
  size_t exp_avg_sq_row_elem_num = elem_num_ / last_row_dim_size_;
  size_t exp_avg_sq_col_elem_num = elem_num_ / last_col_dim_size_;
  size_t last_row_col_size = last_row_dim_size_ * last_col_dim_size_;
  size_t row_dim_size = last_row_dim_size_;
  size_t col_dim_size = last_col_dim_size_;
  // calc exp_avg_sq_row: EMA of the mean of `update` along the last dim.
  task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      float row_reduce = 0;
      size_t reduce_start = i * row_dim_size;
      for (size_t j = 0; j < row_dim_size; ++j) {
        row_reduce += update[reduce_start + j];
      }
      row_reduce = row_reduce / row_dim_size;
      auto tmp = static_cast<float>(exp_avg_sq_row[i]) * beta2t + row_reduce * one_minus_beta2t;
      exp_avg_sq_row[i] = static_cast<T>(tmp);
    }
  };
  CPUKernelUtils::ParallelFor(task, exp_avg_sq_row_elem_num, kBatchSize);

  // calc r_factor: each entry of exp_avg_sq_row normalised by the (clamped)
  // mean of its col_dim_size-sized group, then square-rooted.
  task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      float col_reduce = 0;
      size_t reduce_start = i * col_dim_size;
      for (size_t j = 0; j < col_dim_size; ++j) {
        col_reduce += static_cast<float>(exp_avg_sq_row[reduce_start + j]);
      }
      col_reduce = col_reduce / col_dim_size;
      col_reduce = std::max(col_reduce, kEps);
      for (size_t j = 0; j < col_dim_size; ++j) {
        r_factor[reduce_start + j] = std::sqrt(static_cast<float>(exp_avg_sq_row[reduce_start + j]) / col_reduce);
      }
    }
  };
  CPUKernelUtils::ParallelFor(task, exp_avg_sq_row_elem_num / col_dim_size, kBatchSize);

  // calc exp_avg_sq_col and c_factor: EMA of the mean of `update` along the
  // second-to-last dim (strided access), clamped and square-rooted in place.
  task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      float row_reduce = 0;
      // Map the reduced index back into the flat `update` buffer: slice base
      // plus position within the last dim.
      size_t reduce_start = (i / row_dim_size) * last_row_col_size + i % row_dim_size;
      for (size_t j = 0; j < col_dim_size; ++j) {
        row_reduce += update[reduce_start + j * row_dim_size];
      }
      row_reduce = row_reduce / col_dim_size;
      auto tmp = static_cast<float>(exp_avg_sq_col[i]) * beta2t + row_reduce * one_minus_beta2t;
      tmp = std::max(tmp, kEps);
      exp_avg_sq_col[i] = static_cast<T>(tmp);
      c_factor[i] = std::sqrt(tmp);
    }
  };
  CPUKernelUtils::ParallelFor(task, exp_avg_sq_col_elem_num, kBatchSize);

  // calc update: divide the scaled gradient by the outer product of the two
  // factors. r_factor is indexed by (slice, col), c_factor by (slice, row) —
  // matching how they were laid out above.
  task = [&, this](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      size_t row_i = i % row_dim_size;
      size_t col_i = i / row_dim_size % col_dim_size;
      size_t slice = i / last_row_col_size;
      auto norm = r_factor[slice * col_dim_size + col_i] * c_factor[slice * row_dim_size + row_i];
      update[i] = static_cast<float>(grad[i]) * global_norm_reciprocal_ / std::max(norm, kEps);
    }
  };
  CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);
}
202 
template <typename T>
// One full AdaFactor optimizer step for parameter dtype T (float16 or float32).
// All arithmetic is done in float32 through the `update` workspace buffer;
// results are cast back to T when written into the state/param tensors.
void FusedAdaFactorCpuKernelMod::LaunchKernel(const std::vector<KernelTensor *> &inputs,
                                              const std::vector<KernelTensor *> &workspaces,
                                              const std::vector<KernelTensor *> &) {
  // epsilon packs two floats: [0] regularises the squared gradient, [1] is the
  // floor for the parameter-RMS learning-rate scaling below.
  auto epsilon = reinterpret_cast<float *>(inputs[kEpsIndex]->device_ptr());
  auto clip_threshold = reinterpret_cast<float *>(inputs[kClipThresholdIndex]->device_ptr())[kScalarIndex];
  auto beta1 = reinterpret_cast<float *>(inputs[kBeta1Index]->device_ptr())[kScalarIndex];
  auto beta2t = reinterpret_cast<float *>(inputs[kBeta2tIndex]->device_ptr())[kScalarIndex];
  auto weight_decay = reinterpret_cast<float *>(inputs[kWeightDecayIndex]->device_ptr())[kScalarIndex];
  auto learning_rate = reinterpret_cast<float *>(inputs[kLearningRateIndex]->device_ptr())[kScalarIndex];
  auto grad = reinterpret_cast<T *>(inputs[kGradIndex]->device_ptr());
  auto param = reinterpret_cast<T *>(inputs[kParamIndex]->device_ptr());
  auto exp_avg = reinterpret_cast<T *>(inputs[kExpAvgIndex]->device_ptr());
  auto exp_avg_sq = reinterpret_cast<T *>(inputs[kExpAvgSQIndex]->device_ptr());
  auto update = reinterpret_cast<float *>(workspaces[kWorkSpaceUpdateIndex]->device_ptr());
  auto one_minus_beta1 = 1 - beta1;
  auto one_minus_beta2t = 1 - beta2t;
  // Hyper-parameter sanity checks: betas must lie in [0, 1], thresholds and
  // epsilons must be non-negative.
  if (clip_threshold <= 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', clip threshold " << clip_threshold << " is invalid. ";
  }
  if (beta1 < 0 || one_minus_beta1 < 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', beta1 " << beta1 << " is invalid. ";
  }
  if (beta2t < 0 || one_minus_beta2t < 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', beta2t " << beta2t << " is invalid. ";
  }
  if (epsilon[0] < 0 || epsilon[1] < 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', epsilon (" << epsilon[0] << "," << epsilon[1]
                      << ") is invalid. ";
  }

  std::function<void(size_t, size_t)> task;
  // calc update: squared, globally-scaled gradient plus regulariser.
  task = [&, this](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      auto tmp = static_cast<float>(grad[i]) * global_norm_reciprocal_;
      update[i] = tmp * tmp + epsilon[0];
    }
  };
  CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);

  if (need_factor_) {
    // 2-D+ params: factored second moment (see FactorUpdate).
    FactorUpdate<T>(update, inputs, workspaces);
  } else {
    // no factor: plain EMA of the squared gradient, then normalise by its sqrt.
    task = [&, this](size_t start, size_t end) {
      for (size_t i = start; i < end; ++i) {
        auto tmp = static_cast<float>(exp_avg_sq[i]) * beta2t + update[i] * one_minus_beta2t;
        tmp = std::max(tmp, kEps);
        exp_avg_sq[i] = static_cast<T>(tmp);
        update[i] = static_cast<float>(grad[i]) * global_norm_reciprocal_ / std::sqrt(tmp);
      }
    };
    CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);
  }

  // scale learning rate with rms of param (floored by epsilon[1]).
  if (enable_scale_parameter_) {
    auto rms = CalcRMS(param, elem_num_);
    learning_rate = learning_rate * std::max(epsilon[1], rms);
  }

  // update param: clip the update by its RMS relative to clip_threshold, apply
  // the optional first moment (EMA stored in exp_avg) and weight decay.
  auto update_rms = CalcRMS(update, elem_num_);
  auto update_rms_threshold = update_rms / clip_threshold;
  auto update_coff = learning_rate / std::max(update_rms_threshold, 1.0f);
  task = [&, this](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      update[i] = update[i] * update_coff;
      if (enable_first_moment_) {
        update[i] = static_cast<float>(exp_avg[i]) * beta1 + update[i] * one_minus_beta1;
        exp_avg[i] = static_cast<T>(update[i]);
      }
      if (enable_weight_decay_) {
        auto tmp = update[i] + static_cast<float>(param[i]) * weight_decay * learning_rate;
        param[i] = static_cast<T>(static_cast<float>(param[i]) - tmp);
      } else {
        param[i] = static_cast<T>(static_cast<float>(param[i]) - update[i]);
      }
    }
  };
  CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);
}
286 
Launch(const std::vector<kernel::KernelTensor * > & inputs,const std::vector<kernel::KernelTensor * > & workspaces,const std::vector<kernel::KernelTensor * > & outputs)287 bool FusedAdaFactorCpuKernelMod::Launch(const std::vector<kernel::KernelTensor *> &inputs,
288                                         const std::vector<kernel::KernelTensor *> &workspaces,
289                                         const std::vector<kernel::KernelTensor *> &outputs) {
290   if (inputs.size() == kStandardInputNum + 1) {
291     auto global_norm = reinterpret_cast<float *>(inputs[kGlobalNormIndex]->device_ptr())[kScalarIndex];
292     if (global_norm < kEps) {
293       global_norm_reciprocal_ = 1.0f;
294     } else {
295       global_norm_reciprocal_ = 1.0f / global_norm;
296     }
297   }
298 
299   CheckInputAddresses(inputs);
300   CheckWorkspaceAddresses(workspaces);
301   if (param_dtype_ == kNumberTypeFloat16) {
302     LaunchKernel<float16>(inputs, workspaces, outputs);
303   } else {
304     LaunchKernel<float>(inputs, workspaces, outputs);
305   }
306   return true;
307 }
308 
CheckInputAddresses(const std::vector<kernel::KernelTensor * > & inputs) const309 void FusedAdaFactorCpuKernelMod::CheckInputAddresses(const std::vector<kernel::KernelTensor *> &inputs) const {
310   if (inputs.size() < kStandardInputNum) {
311     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be at least " << kStandardInputNum
312                       << ", but got: " << inputs.size();
313   }
314 
315   if (inputs[kEpsIndex]->size() != kSizeFloat32 << 1) {
316     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'epsilon' must be " << (kSizeFloat32 << 1)
317                       << ", but got " << inputs[kEpsIndex]->size();
318   }
319   if (inputs[kClipThresholdIndex]->size() != kSizeFloat32) {
320     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'clip_threshold' must be " << kSizeFloat32
321                       << ", but got " << inputs[kClipThresholdIndex]->size();
322   }
323   if (inputs[kBeta1Index]->size() != kSizeFloat32) {
324     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta1' must be " << kSizeFloat32
325                       << ", but got " << inputs[kBeta1Index]->size();
326   }
327   if (inputs[kBeta2tIndex]->size() != kSizeFloat32) {
328     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta2t' must be " << kSizeFloat32
329                       << ", but got " << inputs[kBeta2tIndex]->size();
330   }
331   if (inputs[kWeightDecayIndex]->size() != kSizeFloat32) {
332     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'weight_decay' must be " << kSizeFloat32
333                       << ", but got " << inputs[kWeightDecayIndex]->size();
334   }
335   if (inputs[kLearningRateIndex]->size() != kSizeFloat32) {
336     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'lr' must be " << kSizeFloat32
337                       << ", but got " << inputs[kLearningRateIndex]->size();
338   }
339 
340   size_t param_size = param_dtype_ == kNumberTypeFloat16 ? elem_num_ * kSizeFloat16 : elem_num_ * kSizeFloat32;
341   if (inputs[kParamIndex]->size() != param_size) {
342     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'param' must be " << param_size
343                       << ", but got " << inputs[kParamIndex]->size();
344   }
345   if (inputs[kGradIndex]->size() != param_size) {
346     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'gradient' must be " << param_size
347                       << ", but got " << inputs[kGradIndex]->size();
348   }
349 
350   if (enable_first_moment_ && inputs[kExpAvgIndex]->size() != param_size) {
351     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg' must be " << param_size
352                       << ", but got " << inputs[kExpAvgIndex]->size();
353   }
354 
355   if (!need_factor_) {
356     if (inputs[kExpAvgSQIndex]->size() != param_size) {
357       MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg_sq' must be " << param_size
358                         << ", but got " << inputs[kExpAvgSQIndex]->size();
359     }
360     return;
361   }
362 
363   if (inputs[kExpAvgSQRowIndex]->size() != param_size / last_row_dim_size_) {
364     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg_sq_row' must be "
365                       << param_size / last_row_dim_size_ << ", but got " << inputs[kExpAvgSQRowIndex]->size();
366   }
367   if (inputs[kExpAvgSQColIndex]->size() != param_size / last_col_dim_size_) {
368     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg_sq_col' must be "
369                       << param_size / last_col_dim_size_ << ", but got " << inputs[kExpAvgSQColIndex]->size();
370   }
371 }
372 
CheckWorkspaceAddresses(const std::vector<kernel::KernelTensor * > & workspaces) const373 void FusedAdaFactorCpuKernelMod::CheckWorkspaceAddresses(const std::vector<kernel::KernelTensor *> &workspaces) const {
374   if (workspaces.size() != kWorkSpaceNum) {
375     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of workspaces must be " << kWorkSpaceNum
376                       << ", but got: " << workspaces.size();
377   }
378 
379   size_t update_size = elem_num_ * kSizeFloat32;
380   if (workspaces[kWorkSpaceUpdateIndex]->size() != elem_num_ * kSizeFloat32) {
381     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'update ' must be " << update_size
382                       << ", but got " << workspaces[kWorkSpaceUpdateIndex]->size();
383   }
384 
385   if (workspaces[kWorkSpaceRFactorIndex]->size() != update_size / last_row_dim_size_) {
386     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'r_factor' must be "
387                       << update_size / last_row_dim_size_ << ", but got " << workspaces[kWorkSpaceRFactorIndex]->size();
388   }
389   if (workspaces[kWorkSpaceCFactorIndex]->size() != update_size / last_col_dim_size_) {
390     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'c_factor' must be "
391                       << update_size / last_col_dim_size_ << ", but got " << workspaces[kWorkSpaceCFactorIndex]->size();
392   }
393 }
394 
GetOpSupport()395 std::vector<KernelAttr> FusedAdaFactorCpuKernelMod::GetOpSupport() {
396   static std::map<std::string, std::vector<KernelAttr>> support_list_map = {{kFusedAdaFactor,
397                                                                              {KernelAttr()
398                                                                                 .AddInputAttr(kNumberTypeFloat32)
399                                                                                 .AddInputAttr(kNumberTypeFloat32)
400                                                                                 .AddInputAttr(kNumberTypeFloat32)
401                                                                                 .AddInputAttr(kNumberTypeFloat32)
402                                                                                 .AddInputAttr(kNumberTypeFloat32)
403                                                                                 .AddInputAttr(kNumberTypeFloat32)
404                                                                                 .AddInputAttr(kNumberTypeFloat32)
405                                                                                 .AddInputAttr(kNumberTypeFloat32)
406                                                                                 .AddInputAttr(kNumberTypeFloat32)
407                                                                                 .AddInputAttr(kNumberTypeFloat32)
408                                                                                 .AddInputAttr(kNumberTypeFloat32)
409                                                                                 .AddInputAttr(kNumberTypeFloat32)
410                                                                                 .AddOutputAttr(kNumberTypeFloat32),
411                                                                               KernelAttr()
412                                                                                 .AddInputAttr(kNumberTypeFloat32)
413                                                                                 .AddInputAttr(kNumberTypeFloat32)
414                                                                                 .AddInputAttr(kNumberTypeFloat32)
415                                                                                 .AddInputAttr(kNumberTypeFloat32)
416                                                                                 .AddInputAttr(kNumberTypeFloat32)
417                                                                                 .AddInputAttr(kNumberTypeFloat32)
418                                                                                 .AddInputAttr(kNumberTypeFloat16)
419                                                                                 .AddInputAttr(kNumberTypeFloat16)
420                                                                                 .AddInputAttr(kNumberTypeFloat16)
421                                                                                 .AddInputAttr(kNumberTypeFloat16)
422                                                                                 .AddInputAttr(kNumberTypeFloat16)
423                                                                                 .AddInputAttr(kNumberTypeFloat16)
424                                                                                 .AddOutputAttr(kNumberTypeFloat16)}},
425                                                                             {kFusedAdaFactorWithGlobalNorm,
426                                                                              {KernelAttr()
427                                                                                 .AddInputAttr(kNumberTypeFloat32)
428                                                                                 .AddInputAttr(kNumberTypeFloat32)
429                                                                                 .AddInputAttr(kNumberTypeFloat32)
430                                                                                 .AddInputAttr(kNumberTypeFloat32)
431                                                                                 .AddInputAttr(kNumberTypeFloat32)
432                                                                                 .AddInputAttr(kNumberTypeFloat32)
433                                                                                 .AddInputAttr(kNumberTypeFloat32)
434                                                                                 .AddInputAttr(kNumberTypeFloat32)
435                                                                                 .AddInputAttr(kNumberTypeFloat32)
436                                                                                 .AddInputAttr(kNumberTypeFloat32)
437                                                                                 .AddInputAttr(kNumberTypeFloat32)
438                                                                                 .AddInputAttr(kNumberTypeFloat32)
439                                                                                 .AddInputAttr(kNumberTypeFloat32)
440                                                                                 .AddOutputAttr(kNumberTypeFloat32),
441                                                                               KernelAttr()
442                                                                                 .AddInputAttr(kNumberTypeFloat32)
443                                                                                 .AddInputAttr(kNumberTypeFloat32)
444                                                                                 .AddInputAttr(kNumberTypeFloat32)
445                                                                                 .AddInputAttr(kNumberTypeFloat32)
446                                                                                 .AddInputAttr(kNumberTypeFloat32)
447                                                                                 .AddInputAttr(kNumberTypeFloat32)
448                                                                                 .AddInputAttr(kNumberTypeFloat16)
449                                                                                 .AddInputAttr(kNumberTypeFloat16)
450                                                                                 .AddInputAttr(kNumberTypeFloat16)
451                                                                                 .AddInputAttr(kNumberTypeFloat16)
452                                                                                 .AddInputAttr(kNumberTypeFloat16)
453                                                                                 .AddInputAttr(kNumberTypeFloat16)
454                                                                                 .AddInputAttr(kNumberTypeFloat32)
455                                                                                 .AddOutputAttr(kNumberTypeFloat16)}}};
456   auto iter = support_list_map.find(kernel_type_);
457   if (iter == support_list_map.end()) {
458     MS_LOG(EXCEPTION) << "Does not support " << kernel_type_ << "!";
459   }
460 
461   return iter->second;
462 }
463 
// Register both flavours of the kernel with the CPU kernel factory. Each creator
// constructs the mod with its kernel-type string (presumably stored as
// kernel_type_, which GetOpSupport() uses to pick the matching attribute list —
// confirm against the class constructor).
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FusedAdaFactor,
                                 []() { return std::make_shared<FusedAdaFactorCpuKernelMod>(kFusedAdaFactor); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FusedAdaFactorWithGlobalNorm, []() {
  return std::make_shared<FusedAdaFactorCpuKernelMod>(kFusedAdaFactorWithGlobalNorm);
});
469 }  // namespace kernel
470 }  // namespace mindspore
471