/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/cpu/kernel/fused_ada_factor_cpu_kernel.h"
#include <algorithm>
#include <cmath>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include "mindspore/core/ops/fused_ada_factor.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kSizeFloat16 = sizeof(float16);
constexpr size_t kScalarIndex = 0;
constexpr size_t kStandardInputNum = 12;
constexpr size_t kWorkSpaceNum = 3;
constexpr size_t kBatchSize = 1000;
constexpr size_t kLastRowIndex = 1;
constexpr size_t kLastColIndex = 2;
constexpr float kEps = 1e-30;
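// Indices of the kernel inputs. The optional 'global_norm' (index 12) is present only in the
// FusedAdaFactorWithGlobalNorm variant; the plain FusedAdaFactor op takes the first 12 inputs.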
constexpr size_t kEpsIndex = 0;
constexpr size_t kClipThresholdIndex = 1;
constexpr size_t kBeta1Index = 2;
constexpr size_t kBeta2tIndex = 3;
constexpr size_t kWeightDecayIndex = 4;
constexpr size_t kLearningRateIndex = 5;
constexpr size_t kGradIndex = 6;
constexpr size_t kParamIndex = 7;
constexpr size_t kExpAvgIndex = 8;
constexpr size_t kExpAvgSQRowIndex = 9;
constexpr size_t kExpAvgSQColIndex = 10;
constexpr size_t kExpAvgSQIndex = 11;
constexpr size_t kGlobalNormIndex = 12;
constexpr size_t kWorkSpaceUpdateIndex = 0;
constexpr size_t kWorkSpaceRFactorIndex = 1;
constexpr size_t kWorkSpaceCFactorIndex = 2;
auto constexpr kEnableScaleParameter = "enable_scale_parameter";
auto constexpr kEnableFirstMoment = "enable_first_moment";
auto constexpr kEnableWeightDecay = "enable_weight_decay";
}  // namespace

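// Caches the parameter dtype and the three boolean attributes that gate the optional steps
// (parameter-scale lr scaling, first moment, weight decay); shapes are resolved later in Resize.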
bool FusedAdaFactorCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs,
                                      const std::vector<KernelTensor *> &outputs) {
  param_dtype_ = inputs[kParamIndex]->dtype_id();

  enable_scale_parameter_ = GetValue<bool>(primitive_->GetAttr(kEnableScaleParameter));
  enable_first_moment_ = GetValue<bool>(primitive_->GetAttr(kEnableFirstMoment));
  enable_weight_decay_ = GetValue<bool>(primitive_->GetAttr(kEnableWeightDecay));
  return true;
}

int FusedAdaFactorCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs,
                                       const std::vector<KernelTensor *> &outputs) {
  auto ret = KernelMod::Resize(inputs, outputs);
  if (ret != 0) {
    return ret;
  }

  auto shape = inputs[kParamIndex]->GetShapeVector();
  elem_num_ = std::accumulate(shape.begin(), shape.end(), 1UL, std::multiplies<size_t>());
  if (elem_num_ < 1) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the element count of 'param' must not be zero.";
  }
  if (shape.size() >= kLastColIndex) {
    need_factor_ = true;
    last_row_dim_size_ = LongToSize(shape[shape.size() - kLastRowIndex]);
    last_col_dim_size_ = LongToSize(shape[shape.size() - kLastColIndex]);
    if (last_row_dim_size_ < 1 || last_col_dim_size_ < 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the last two dimensions of 'param' must not be zero.";
    }
  }

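  // Three float32 workspaces: a full-size 'update' buffer plus the row/column factor buffers used
  // only when factoring is enabled. When factoring is disabled, the divisions below rely on
  // last_row_dim_size_ and last_col_dim_size_ defaulting to 1 (assumed from their use here; the
  // member initializers live in the header).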
  workspace_size_list_.clear();
  (void)workspace_size_list_.emplace_back(elem_num_ * kSizeFloat32);
  (void)workspace_size_list_.emplace_back((elem_num_ / last_row_dim_size_) * kSizeFloat32);
  (void)workspace_size_list_.emplace_back((elem_num_ / last_col_dim_size_) * kSizeFloat32);
  return KRET_OK;
}

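// Computes the root mean square RMS(x) = sqrt(sum(x_i^2) / n): each task accumulates a partial
// sum of squares for its block, and the partial sums are reduced after ParallelLaunch returns.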
template <typename T>
float FusedAdaFactorCpuKernelMod::CalcRMS(const T *input, size_t elem_num) const {
  if (elem_num == 0 || input == nullptr) {
    return 0.0f;
  }
  auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
  size_t thread_num =
    elem_num < kBatchSize * max_thread_num ? (elem_num + kBatchSize - 1) / kBatchSize : max_thread_num;
  std::vector<common::Task> tasks;
  size_t batch_size = (elem_num + thread_num - 1) / thread_num;
  std::vector<float> block_sum(thread_num, 0.0f);
  for (size_t thread_id = 0; thread_id < thread_num; ++thread_id) {
    size_t start = batch_size * thread_id;
    size_t end = (start + batch_size) > elem_num ? elem_num : (start + batch_size);
    auto block = [&, start, end, thread_id]() {
      float square_sum = 0;
      for (size_t i = start; i < end; ++i) {
        auto tmp = static_cast<float>(input[i]);
        square_sum += tmp * tmp;
      }
      block_sum[thread_id] = square_sum;
      return common::SUCCESS;
    };
    (void)tasks.emplace_back(block);
  }
  ParallelLaunch(tasks);
  auto rms = std::accumulate(block_sum.begin(), block_sum.end(), 0.0f);
  rms = rms / static_cast<float>(elem_num);
  return std::sqrt(rms);
}

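// Factored second-moment estimation (AdaFactor, Shazeer & Stern 2018). Rather than a full
// second-moment accumulator V, the kernel keeps exponential moving means of the regularized
// squared gradients reduced over the last axis (exp_avg_sq_row, "R") and over the second-to-last
// axis (exp_avg_sq_col, "C"), then reconstructs, per trailing 2-D slice,
//   V_hat[c][r] ~= R[c] * C[r] / mean(R),
// so the normalized step is update = scaled_grad / sqrt(V_hat), computed below as
// scaled_grad / (r_factor * c_factor) with r_factor = sqrt(R / mean(R)) and c_factor = sqrt(C).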
template <typename T>
void FusedAdaFactorCpuKernelMod::FactorUpdate(float *update, const std::vector<KernelTensor *> &inputs,
                                              const std::vector<KernelTensor *> &workspaces) const {
  auto beta2t = reinterpret_cast<float *>(inputs[kBeta2tIndex]->device_ptr())[kScalarIndex];
  auto grad = reinterpret_cast<T *>(inputs[kGradIndex]->device_ptr());
  auto exp_avg_sq_row = reinterpret_cast<T *>(inputs[kExpAvgSQRowIndex]->device_ptr());
  auto exp_avg_sq_col = reinterpret_cast<T *>(inputs[kExpAvgSQColIndex]->device_ptr());
  auto r_factor = reinterpret_cast<float *>(workspaces[kWorkSpaceRFactorIndex]->device_ptr());
  auto c_factor = reinterpret_cast<float *>(workspaces[kWorkSpaceCFactorIndex]->device_ptr());
  auto one_minus_beta2t = 1 - beta2t;

  std::function<void(size_t, size_t)> task;
  size_t exp_avg_sq_row_elem_num = elem_num_ / last_row_dim_size_;
  size_t exp_avg_sq_col_elem_num = elem_num_ / last_col_dim_size_;
  size_t last_row_col_size = last_row_dim_size_ * last_col_dim_size_;
  size_t row_dim_size = last_row_dim_size_;
  size_t col_dim_size = last_col_dim_size_;
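  // The parameter is treated as a batch of [col_dim_size, row_dim_size] matrices with the last
  // axis contiguous: "row" names the last axis and "col" the second-to-last axis.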
  // calc exp_avg_sq_row: moving mean of the squared gradients, reduced over the last (row) axis
  task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      float row_reduce = 0;
      size_t reduce_start = i * row_dim_size;
      for (size_t j = 0; j < row_dim_size; ++j) {
        row_reduce += update[reduce_start + j];
      }
      row_reduce = row_reduce / row_dim_size;
      auto tmp = static_cast<float>(exp_avg_sq_row[i]) * beta2t + row_reduce * one_minus_beta2t;
      exp_avg_sq_row[i] = static_cast<T>(tmp);
    }
  };
  CPUKernelUtils::ParallelFor(task, exp_avg_sq_row_elem_num, kBatchSize);

  // calc r_factor = sqrt(exp_avg_sq_row / mean(exp_avg_sq_row)), the mean taken per trailing slice
  task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      float col_reduce = 0;
      size_t reduce_start = i * col_dim_size;
      for (size_t j = 0; j < col_dim_size; ++j) {
        col_reduce += static_cast<float>(exp_avg_sq_row[reduce_start + j]);
      }
      col_reduce = col_reduce / col_dim_size;
      col_reduce = std::max(col_reduce, kEps);
      for (size_t j = 0; j < col_dim_size; ++j) {
        r_factor[reduce_start + j] = std::sqrt(static_cast<float>(exp_avg_sq_row[reduce_start + j]) / col_reduce);
      }
    }
  };
  CPUKernelUtils::ParallelFor(task, exp_avg_sq_row_elem_num / col_dim_size, kBatchSize);

  // calc exp_avg_sq_col (moving mean reduced over the col axis) and c_factor = sqrt(exp_avg_sq_col)
  task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      float row_reduce = 0;
      size_t reduce_start = (i / row_dim_size) * last_row_col_size + i % row_dim_size;
      for (size_t j = 0; j < col_dim_size; ++j) {
        row_reduce += update[reduce_start + j * row_dim_size];
      }
      row_reduce = row_reduce / col_dim_size;
      auto tmp = static_cast<float>(exp_avg_sq_col[i]) * beta2t + row_reduce * one_minus_beta2t;
      tmp = std::max(tmp, kEps);
      exp_avg_sq_col[i] = static_cast<T>(tmp);
      c_factor[i] = std::sqrt(tmp);
    }
  };
  CPUKernelUtils::ParallelFor(task, exp_avg_sq_col_elem_num, kBatchSize);

  // calc update: normalize the scaled gradient by the factored second-moment estimate
  task = [&, this](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      size_t row_i = i % row_dim_size;
      size_t col_i = i / row_dim_size % col_dim_size;
      size_t slice = i / last_row_col_size;
      auto norm = r_factor[slice * col_dim_size + col_i] * c_factor[slice * row_dim_size + row_i];
      update[i] = static_cast<float>(grad[i]) * global_norm_reciprocal_ / std::max(norm, kEps);
    }
  };
  CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);
}

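// One AdaFactor step. In order: (1) update = (grad * global_norm_reciprocal)^2 + eps[0];
// (2) normalize by the second moment, factored (FactorUpdate) or full (exp_avg_sq);
// (3) optionally scale the learning rate by max(eps[1], RMS(param)); (4) clip so that
// RMS(update) <= clip_threshold; (5) optionally blend in the first moment exp_avg;
// (6) apply optional weight decay and subtract the update from param.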
template <typename T>
void FusedAdaFactorCpuKernelMod::LaunchKernel(const std::vector<KernelTensor *> &inputs,
                                              const std::vector<KernelTensor *> &workspaces,
                                              const std::vector<KernelTensor *> &) {
  auto epsilon = reinterpret_cast<float *>(inputs[kEpsIndex]->device_ptr());
  auto clip_threshold = reinterpret_cast<float *>(inputs[kClipThresholdIndex]->device_ptr())[kScalarIndex];
  auto beta1 = reinterpret_cast<float *>(inputs[kBeta1Index]->device_ptr())[kScalarIndex];
  auto beta2t = reinterpret_cast<float *>(inputs[kBeta2tIndex]->device_ptr())[kScalarIndex];
  auto weight_decay = reinterpret_cast<float *>(inputs[kWeightDecayIndex]->device_ptr())[kScalarIndex];
  auto learning_rate = reinterpret_cast<float *>(inputs[kLearningRateIndex]->device_ptr())[kScalarIndex];
  auto grad = reinterpret_cast<T *>(inputs[kGradIndex]->device_ptr());
  auto param = reinterpret_cast<T *>(inputs[kParamIndex]->device_ptr());
  auto exp_avg = reinterpret_cast<T *>(inputs[kExpAvgIndex]->device_ptr());
  auto exp_avg_sq = reinterpret_cast<T *>(inputs[kExpAvgSQIndex]->device_ptr());
  auto update = reinterpret_cast<float *>(workspaces[kWorkSpaceUpdateIndex]->device_ptr());
  auto one_minus_beta1 = 1 - beta1;
  auto one_minus_beta2t = 1 - beta2t;
  if (clip_threshold <= 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', clip threshold " << clip_threshold << " is invalid.";
  }
  if (beta1 < 0 || one_minus_beta1 < 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', beta1 " << beta1 << " is invalid.";
  }
  if (beta2t < 0 || one_minus_beta2t < 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', beta2t " << beta2t << " is invalid.";
  }
  if (epsilon[0] < 0 || epsilon[1] < 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', epsilon (" << epsilon[0] << "," << epsilon[1]
                      << ") is invalid.";
  }

  std::function<void(size_t, size_t)> task;
  // calc update: the squared, globally scaled gradient, regularized by eps[0]
  task = [&, this](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      auto tmp = static_cast<float>(grad[i]) * global_norm_reciprocal_;
      update[i] = tmp * tmp + epsilon[0];
    }
  };
  CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);

  if (need_factor_) {
    FactorUpdate<T>(update, inputs, workspaces);
  } else {
    // no factor: keep the full second moment exp_avg_sq and normalize by its square root
    task = [&, this](size_t start, size_t end) {
      for (size_t i = start; i < end; ++i) {
        auto tmp = static_cast<float>(exp_avg_sq[i]) * beta2t + update[i] * one_minus_beta2t;
        tmp = std::max(tmp, kEps);
        exp_avg_sq[i] = static_cast<T>(tmp);
        update[i] = static_cast<float>(grad[i]) * global_norm_reciprocal_ / std::sqrt(tmp);
      }
    };
    CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);
  }

  // scale the learning rate with the RMS of param, floored at eps[1]
  if (enable_scale_parameter_) {
    auto rms = CalcRMS(param, elem_num_);
    learning_rate = learning_rate * std::max(epsilon[1], rms);
  }

  // clip the update so that RMS(update) <= clip_threshold, then apply momentum, weight decay,
  // and the parameter update
  auto update_rms = CalcRMS(update, elem_num_);
  auto update_rms_threshold = update_rms / clip_threshold;
  auto update_coff = learning_rate / std::max(update_rms_threshold, 1.0f);
  task = [&, this](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      update[i] = update[i] * update_coff;
      if (enable_first_moment_) {
        update[i] = static_cast<float>(exp_avg[i]) * beta1 + update[i] * one_minus_beta1;
        exp_avg[i] = static_cast<T>(update[i]);
      }
      if (enable_weight_decay_) {
        auto tmp = update[i] + static_cast<float>(param[i]) * weight_decay * learning_rate;
        param[i] = static_cast<T>(static_cast<float>(param[i]) - tmp);
      } else {
        param[i] = static_cast<T>(static_cast<float>(param[i]) - update[i]);
      }
    }
  };
  CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);
}

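// The WithGlobalNorm variant passes a 13th input holding the gradients' global norm; gradients
// are multiplied by its reciprocal throughout the step. A norm below kEps leaves them unscaled.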
bool FusedAdaFactorCpuKernelMod::Launch(const std::vector<kernel::KernelTensor *> &inputs,
                                        const std::vector<kernel::KernelTensor *> &workspaces,
                                        const std::vector<kernel::KernelTensor *> &outputs) {
  if (inputs.size() == kStandardInputNum + 1) {
    auto global_norm = reinterpret_cast<float *>(inputs[kGlobalNormIndex]->device_ptr())[kScalarIndex];
    if (global_norm < kEps) {
      global_norm_reciprocal_ = 1.0f;
    } else {
      global_norm_reciprocal_ = 1.0f / global_norm;
    }
  }

  CheckInputAddresses(inputs);
  CheckWorkspaceAddresses(workspaces);
  if (param_dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, workspaces, outputs);
  } else {
    LaunchKernel<float>(inputs, workspaces, outputs);
  }
  return true;
}

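// Validates raw buffer sizes against the param shape. 'epsilon' is a two-element float tensor:
// eps[0] regularizes the squared gradient and eps[1] floors the parameter RMS (see LaunchKernel).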
void FusedAdaFactorCpuKernelMod::CheckInputAddresses(const std::vector<kernel::KernelTensor *> &inputs) const {
  if (inputs.size() < kStandardInputNum) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be at least " << kStandardInputNum
                      << ", but got: " << inputs.size();
  }

  if (inputs[kEpsIndex]->size() != (kSizeFloat32 << 1)) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'epsilon' must be " << (kSizeFloat32 << 1)
                      << ", but got " << inputs[kEpsIndex]->size();
  }
  if (inputs[kClipThresholdIndex]->size() != kSizeFloat32) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'clip_threshold' must be " << kSizeFloat32
                      << ", but got " << inputs[kClipThresholdIndex]->size();
  }
  if (inputs[kBeta1Index]->size() != kSizeFloat32) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta1' must be " << kSizeFloat32
                      << ", but got " << inputs[kBeta1Index]->size();
  }
  if (inputs[kBeta2tIndex]->size() != kSizeFloat32) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta2t' must be " << kSizeFloat32
                      << ", but got " << inputs[kBeta2tIndex]->size();
  }
  if (inputs[kWeightDecayIndex]->size() != kSizeFloat32) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'weight_decay' must be " << kSizeFloat32
                      << ", but got " << inputs[kWeightDecayIndex]->size();
  }
  if (inputs[kLearningRateIndex]->size() != kSizeFloat32) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'lr' must be " << kSizeFloat32
                      << ", but got " << inputs[kLearningRateIndex]->size();
  }

  size_t param_size = param_dtype_ == kNumberTypeFloat16 ? elem_num_ * kSizeFloat16 : elem_num_ * kSizeFloat32;
  if (inputs[kParamIndex]->size() != param_size) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'param' must be " << param_size
                      << ", but got " << inputs[kParamIndex]->size();
  }
  if (inputs[kGradIndex]->size() != param_size) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'gradient' must be " << param_size
                      << ", but got " << inputs[kGradIndex]->size();
  }

  if (enable_first_moment_ && inputs[kExpAvgIndex]->size() != param_size) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg' must be " << param_size
                      << ", but got " << inputs[kExpAvgIndex]->size();
  }

  if (!need_factor_) {
    if (inputs[kExpAvgSQIndex]->size() != param_size) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg_sq' must be " << param_size
                        << ", but got " << inputs[kExpAvgSQIndex]->size();
    }
    return;
  }

  if (inputs[kExpAvgSQRowIndex]->size() != param_size / last_row_dim_size_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg_sq_row' must be "
                      << param_size / last_row_dim_size_ << ", but got " << inputs[kExpAvgSQRowIndex]->size();
  }
  if (inputs[kExpAvgSQColIndex]->size() != param_size / last_col_dim_size_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg_sq_col' must be "
                      << param_size / last_col_dim_size_ << ", but got " << inputs[kExpAvgSQColIndex]->size();
  }
}

void FusedAdaFactorCpuKernelMod::CheckWorkspaceAddresses(const std::vector<kernel::KernelTensor *> &workspaces) const {
  if (workspaces.size() != kWorkSpaceNum) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of workspaces must be " << kWorkSpaceNum
                      << ", but got: " << workspaces.size();
  }

  size_t update_size = elem_num_ * kSizeFloat32;
  if (workspaces[kWorkSpaceUpdateIndex]->size() != update_size) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'update' must be " << update_size
                      << ", but got " << workspaces[kWorkSpaceUpdateIndex]->size();
  }

  if (workspaces[kWorkSpaceRFactorIndex]->size() != update_size / last_row_dim_size_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'r_factor' must be "
                      << update_size / last_row_dim_size_ << ", but got " << workspaces[kWorkSpaceRFactorIndex]->size();
  }
  if (workspaces[kWorkSpaceCFactorIndex]->size() != update_size / last_col_dim_size_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'c_factor' must be "
                      << update_size / last_col_dim_size_ << ", but got " << workspaces[kWorkSpaceCFactorIndex]->size();
  }
}

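// Each op supports two signatures: all-float32, and float16 param/grad/state with float32
// scalars. FusedAdaFactorWithGlobalNorm takes a trailing float32 'global_norm' input.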
std::vector<KernelAttr> FusedAdaFactorCpuKernelMod::GetOpSupport() {
  static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
    {kFusedAdaFactor,
     {KernelAttr()
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddOutputAttr(kNumberTypeFloat32),
      KernelAttr()
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat16)
        .AddInputAttr(kNumberTypeFloat16)
        .AddInputAttr(kNumberTypeFloat16)
        .AddInputAttr(kNumberTypeFloat16)
        .AddInputAttr(kNumberTypeFloat16)
        .AddInputAttr(kNumberTypeFloat16)
        .AddOutputAttr(kNumberTypeFloat16)}},
    {kFusedAdaFactorWithGlobalNorm,
     {KernelAttr()
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddOutputAttr(kNumberTypeFloat32),
      KernelAttr()
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat32)
        .AddInputAttr(kNumberTypeFloat16)
        .AddInputAttr(kNumberTypeFloat16)
        .AddInputAttr(kNumberTypeFloat16)
        .AddInputAttr(kNumberTypeFloat16)
        .AddInputAttr(kNumberTypeFloat16)
        .AddInputAttr(kNumberTypeFloat16)
        .AddInputAttr(kNumberTypeFloat32)
        .AddOutputAttr(kNumberTypeFloat16)}}};
  auto iter = support_list_map.find(kernel_type_);
  if (iter == support_list_map.end()) {
    MS_LOG(EXCEPTION) << "Does not support " << kernel_type_ << "!";
  }

  return iter->second;
}

MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FusedAdaFactor,
                                 []() { return std::make_shared<FusedAdaFactorCpuKernelMod>(kFusedAdaFactor); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FusedAdaFactorWithGlobalNorm, []() {
  return std::make_shared<FusedAdaFactorCpuKernelMod>(kFusedAdaFactorWithGlobalNorm);
});
}  // namespace kernel
}  // namespace mindspore