/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CFG_H
#define MINDSPORE_INCLUDE_API_CFG_H

#include <cstddef>
#include <string>
#include <vector>
#include <memory>
#include "include/api/data_type.h"
#include "include/api/dual_abi_helper.h"
#include "include/api/types.h"

namespace mindspore {
constexpr int iter_th = 1000;

class MS_API MixPrecisionCfg {
 public:
  MixPrecisionCfg() {
    this->dynamic_loss_scale_ = false;
    this->loss_scale_ = 128.0f;
    this->keep_batchnorm_fp32_ = true;
    this->num_of_not_nan_iter_th_ = iter_th;
  }
  MixPrecisionCfg(const MixPrecisionCfg &rhs) {
    this->dynamic_loss_scale_ = rhs.dynamic_loss_scale_;
    this->loss_scale_ = rhs.loss_scale_;
    this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_;
    this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_;
    this->is_raw_mix_precision_ = rhs.is_raw_mix_precision_;  // previously omitted, which silently reset the flag on copy
  }
  ~MixPrecisionCfg() = default;

  bool dynamic_loss_scale_ = false;   /**< Enable/disable dynamic loss scaling during mixed-precision training */
  float loss_scale_;                  /**< Initial loss-scale factor */
  bool keep_batchnorm_fp32_ = true;   /**< Keep batch norm in FP32 while training */
  uint32_t num_of_not_nan_iter_th_;   /**< Threshold of non-NaN iterations used to adjust the loss scale when dynamic loss scaling is enabled */
  bool is_raw_mix_precision_ = false; /**< Whether the raw mixed-precision model was exported from MindSpore */
};

class MS_API TrainCfg {
 public:
  TrainCfg() = default;
  TrainCfg(const TrainCfg &rhs) {
    this->loss_name_ = rhs.loss_name_;
    this->mix_precision_cfg_ = rhs.mix_precision_cfg_;
    this->accumulate_gradients_ = rhs.accumulate_gradients_;
  }
  ~TrainCfg() = default;

  /// \brief Obtain the name substrings used to identify a loss kernel.
  ///
  /// \return loss_name.
  inline std::vector<std::string> GetLossName() const;

  /// \brief Set the name substrings used to identify a loss kernel.
  ///
  /// \param[in] loss_name Name substrings that identify a loss kernel.
  inline void SetLossName(const std::vector<std::string> &loss_name);

  OptimizationLevel optimization_level_ = kO0; /**< Training optimization level */
  MixPrecisionCfg mix_precision_cfg_;          /**< Mixed-precision configuration */
  bool accumulate_gradients_ = false;          /**< Accumulate gradients across iterations instead of applying them every step */

 private:
  std::vector<std::vector<char>> loss_name_ = VectorStringToChar({"loss_fct", "_loss_fn", "SigmoidCrossEntropy"});
};

std::vector<std::string> TrainCfg::GetLossName() const { return VectorCharToString(loss_name_); }
void TrainCfg::SetLossName(const std::vector<std::string> &loss_name) { loss_name_ = VectorStringToChar(loss_name); }
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CFG_H
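
// Usage sketch (illustrative, not part of this header): how TrainCfg and
// MixPrecisionCfg are typically wired into a training session. This assumes
// the MindSpore Lite training flow, where Model::Build accepts an optional
// std::shared_ptr<TrainCfg>; the model path "train_model.ms" and the loss
// kernel name "my_loss" are placeholders.
//
//   #include <memory>
//   #include "include/api/model.h"
//   #include "include/api/context.h"
//   #include "include/api/serialization.h"
//   #include "include/api/cfg.h"
//
//   mindspore::Graph graph;
//   (void)mindspore::Serialization::Load("train_model.ms", mindspore::ModelType::kMindIR, &graph);
//
//   auto train_cfg = std::make_shared<mindspore::TrainCfg>();
//   train_cfg->mix_precision_cfg_.dynamic_loss_scale_ = true;  // adjust the loss scale automatically
//   train_cfg->mix_precision_cfg_.loss_scale_ = 1024.0f;       // larger initial scale
//   train_cfg->SetLossName({"my_loss"});                       // match a custom loss-kernel name
//
//   auto context = std::make_shared<mindspore::Context>();
//   context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
//
//   mindspore::Model model;
//   (void)model.Build(mindspore::GraphCell(graph), context, train_cfg);
//   (void)model.SetTrainMode(true);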