1 /** 2 * Copyright 2021-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_ADA_FACTOR_CPU_KERNEL_H_ 17 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_ADA_FACTOR_CPU_KERNEL_H_ 18 19 #include <vector> 20 #include <memory> 21 #include <map> 22 #include <string> 23 #include "plugin/device/cpu/kernel/cpu_kernel.h" 24 #include "plugin/factory/ms_factory.h" 25 26 namespace mindspore { 27 namespace kernel { 28 constexpr auto kFusedAdaFactor = "FusedAdaFactor"; 29 constexpr auto kFusedAdaFactorWithGlobalNorm = "FusedAdaFactorWithGlobalNorm"; 30 constexpr auto kUnknown = "Unknown"; 31 class FusedAdaFactorCpuKernelMod : public NativeCpuKernelMod { 32 public: 33 FusedAdaFactorCpuKernelMod() = default; FusedAdaFactorCpuKernelMod(const std::string & kernel_type)34 explicit FusedAdaFactorCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {} 35 ~FusedAdaFactorCpuKernelMod() override = default; 36 bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspaces, 37 const std::vector<KernelTensor *> &outputs) override; 38 bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override; 39 int Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override; 40 41 std::vector<KernelAttr> GetOpSupport() override; 42 43 private: 44 void CheckInputAddresses(const std::vector<KernelTensor *> &inputs) const; 45 void CheckWorkspaceAddresses(const std::vector<KernelTensor *> &workspaces) const; 46 47 template <typename T> 48 void LaunchKernel(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspaces, 49 const std::vector<KernelTensor *> &outputs); 50 51 template <typename T> 52 float CalcRMS(const T *input, size_t elem_num) const; 53 54 template <typename T> 55 void FactorUpdate(float *update, const std::vector<KernelTensor *> &inputs, 56 const std::vector<KernelTensor *> &workspaces) const; 57 58 bool enable_scale_parameter_{false}; 59 bool enable_first_moment_{false}; 60 bool enable_weight_decay_{false}; 61 bool need_factor_{false}; 62 size_t elem_num_{0}; 63 size_t last_row_dim_size_{1}; 64 size_t last_col_dim_size_{1}; 65 TypeId param_dtype_{kTypeUnknown}; 66 float global_norm_reciprocal_{1.0f}; 67 std::string kernel_type_{kUnknown}; 68 }; 69 } // namespace kernel 70 } // namespace mindspore 71 72 #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_ADA_FACTOR_CPU_KERNEL_H_ 73