1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "ops/op_enum.h"

#include <algorithm>
#include <cctype>
#include <string>
#include <unordered_map>
#include <utility>

#include "mindapi/base/types.h"
#include "utils/check_convert_utils.h"
#include "mindapi/base/format.h"
24
25 namespace mindspore {
26 namespace ops {
27
28 namespace {
29 using StrToEnumMap = std::unordered_map<std::string, int64_t>;
30
31 class RegStringToEnumHelper {
32 public:
33 template <typename T>
AddValues(T && string_to_enum,const std::string & key="")34 std::string AddValues(T &&string_to_enum, const std::string &key = "") {
35 auto &string_to_enum_target = string_to_enum_memory_[key];
36 for (const auto &kv : string_to_enum) {
37 if (string_to_enum_target.find(kv.first) != string_to_enum_target.end()) {
38 MS_LOG(EXCEPTION) << kv.first << " has been registered!";
39 }
40 }
41 string_to_enum_target.merge(std::move(string_to_enum));
42 return "";
43 }
44
GetValues(const std::string & key="")45 const StrToEnumMap &GetValues(const std::string &key = "") {
46 auto it = string_to_enum_memory_.find(key);
47 if (it != string_to_enum_memory_.end()) {
48 return it->second;
49 }
50 return string_to_enum_memory_[""];
51 }
52
53 private:
54 std::unordered_map<std::string, StrToEnumMap> string_to_enum_memory_;
55 };
56 RegStringToEnumHelper reg_string_to_enum_helper;
57
58 #define REG_STRING_TO_ENUM_COMMON(enum_type, ...) \
59 const auto op_enum_##enum_type = reg_string_to_enum_helper.AddValues(__VA_ARGS__);
60
61 #define REG_STRING_TO_ENUM_SPECIAL(enum_type, ...) \
62 const auto op_enum_##enum_type = reg_string_to_enum_helper.AddValues(__VA_ARGS__, #enum_type);
63
// Convert to uppercase uniformly.
// Returns a copy of `str` with every character passed through std::toupper (current C locale).
// Each char is converted to unsigned char first: std::toupper has undefined behavior for negative
// arguments other than EOF, and bytes of non-ASCII text (e.g. UTF-8 in a user-supplied enum string)
// are negative when char is signed.
inline std::string StrToUpper(const std::string &str) {
  auto res = str;
  for (auto &c : res) {
    c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
  }
  return res;
}
72
73 // Format
GetStringToFormatMap()74 inline std::unordered_map<std::string, int64_t> GetStringToFormatMap() {
75 const auto &names = GetFormatNames();
76 std::unordered_map<std::string, int64_t> map{{"DEFAULT_FORMAT", static_cast<int64_t>(Format::DEFAULT_FORMAT)}};
77 for (size_t i = 0; i < names.size(); ++i) {
78 map[StrToUpper(names[i])] = static_cast<int64_t>(i);
79 }
80 return map;
81 }
82 REG_STRING_TO_ENUM_COMMON(format, GetStringToFormatMap())
83
84 // RoundingMode
85 StrToEnumMap StrToRoundingModeMap = {{"ROUND", RoundingMode::ROUND},
86 {"TRUNC", RoundingMode::TRUNC},
87 {"FLOOR", RoundingMode::FLOOR},
88 {"CEIL", RoundingMode::CEIL}};
89 REG_STRING_TO_ENUM_COMMON(rounding_mode, StrToRoundingModeMap)
90
91 // PadMode
92 StrToEnumMap StrToPadModeMap = {
93 {"PAD", PadMode::PAD}, {"SAME", PadMode::SAME}, {"VALID", PadMode::VALID}, {"FULL", PadMode::FULL}};
94 REG_STRING_TO_ENUM_COMMON(pad_mode, StrToPadModeMap)
95
96 // Reduction
97 StrToEnumMap StrToReductionMap = {{"SUM", Reduction::REDUCTION_SUM},
98 {"MEAN", Reduction::MEAN},
99 {"NONE", Reduction::NONE},
100 {"UPDATE", Reduction::UPDATE}};
101 REG_STRING_TO_ENUM_COMMON(reduction, StrToReductionMap)
102
103 // Activation
104 StrToEnumMap StrToActivationMap = {{"NO_ACTIVATION", ActivationType::NO_ACTIVATION},
105 {"RELU", ActivationType::RELU},
106 {"SIGMOID", ActivationType::SIGMOID},
107 {"RELU6", ActivationType::RELU6},
108 {"ELU", ActivationType::ELU},
109 {"LEAKY_RELU", ActivationType::LEAKY_RELU},
110 {"ABS", ActivationType::ABS},
111 {"RELU1", ActivationType::RELU1},
112 {"SOFTSIGN", ActivationType::SOFTSIGN},
113 {"SOFTPLUS", ActivationType::SOFTPLUS},
114 {"TANH", ActivationType::TANH},
115 {"SELU", ActivationType::SELU},
116 {"HSWISH", ActivationType::HSWISH},
117 {"HSIGMOID", ActivationType::HSIGMOID},
118 {"THRESHOLDRELU", ActivationType::THRESHOLDRELU},
119 {"LINEAR", ActivationType::LINEAR},
120 {"HARD_TANH", ActivationType::HARD_TANH},
121 {"SIGN", ActivationType::SIGN},
122 {"SWISH", ActivationType::SWISH},
123 {"GELU", ActivationType::GELU},
124 {"GLU", ActivationType::GLU},
125 {"UNKNOWN", ActivationType::UNKNOWN},
126 {"FASTGELU", ActivationType::FASTGELU},
127 {"SILU", ActivationType::SILU},
128 {"GEGLU", ActivationType::GEGLU},
129 {"SWIGLU", ActivationType::SWIGLU},
130 {"REGLU", ActivationType::REGLU}};
131 REG_STRING_TO_ENUM_COMMON(activation, StrToActivationMap)
132
133 // GateOrder
134 REG_STRING_TO_ENUM_COMMON(gate_order, StrToEnumMap{{"RZH", GateOrderMode::RZH}, {"ZRH", GateOrderMode::ZRH}})
135
136 // CoordinateTransformationMode
137 StrToEnumMap StrToCoordinateTransformationModeMap = {{"ASYMMETRIC", CoordinateTransformMode::ASYMMETRIC},
138 {"ALIGN_CORNERS", CoordinateTransformMode::ALIGN_CORNERS},
139 {"HALF_PIXEL", CoordinateTransformMode::HALF_PIXEL},
140 {"CROP_AND_RESIZE", CoordinateTransformMode::CROP_AND_RESIZE}};
141 REG_STRING_TO_ENUM_COMMON(coordinate_transformation_mode, StrToCoordinateTransformationModeMap)
142
143 // PaddingMode
144 StrToEnumMap StrToPaddingModeMap = {{"CONSTANT", PaddingMode::CONSTANT},
145 {"REFLECT", PaddingMode::REFLECT},
146 {"SYMMETRIC", PaddingMode::SYMMETRIC},
147 {"MODE_RESERVED", PaddingMode::MODE_RESERVED}};
148 REG_STRING_TO_ENUM_COMMON(padding_mode, StrToPaddingModeMap)
149
150 // Direction
151 REG_STRING_TO_ENUM_COMMON(direction, StrToEnumMap{{"UNIDIRECTIONAL", Direction::UNIDIRECTIONAL}})
152
153 // CellType
154 REG_STRING_TO_ENUM_COMMON(cell_type, StrToEnumMap{{"LSTM", CellType::CELL_TYPE_LSTM}})
155
156 // Group
157 REG_STRING_TO_ENUM_COMMON(group, StrToEnumMap{{"SYNC_BN_GROUP0", Group::SYNC_BN_GROUP0}})
158
159 // InterpolationMode
160 REG_STRING_TO_ENUM_COMMON(interpolation_mode, StrToEnumMap{{"BILINEAR", InterpolationMode::BILINEAR},
161 {"NEAREST", InterpolationMode::NEAREST}})
162
163 // NormMode
164 StrToEnumMap StrToNormModeMap = {
165 {"BACKWARD", NormMode::BACKWARD}, {"FORWARD", NormMode::FORWARD}, {"ORTHO", NormMode::ORTHO}};
166 REG_STRING_TO_ENUM_COMMON(norm_mode, StrToNormModeMap)
167
168 // GridSamplerPaddingMode
169 StrToEnumMap StrToGridSamplerPaddingMode = {{"ZEROS", GridSamplerPaddingMode::ZEROS},
170 {"BORDER", GridSamplerPaddingMode::BORDER},
171 {"REFLECTION", GridSamplerPaddingMode::REFLECTION}};
172 REG_STRING_TO_ENUM_COMMON(grid_sampler_padding_mode, StrToGridSamplerPaddingMode)
173
174 // KVCacheAlignMode
175 REG_STRING_TO_ENUM_COMMON(k_v_cache_align_mode,
176 StrToEnumMap{{"LEFT", KVCacheAlignMode::LEFT}, {"RIGHT", KVCacheAlignMode::RIGHT}})
177
178 REG_STRING_TO_ENUM_COMMON(fas_input_layout_mode, StrToEnumMap{{"BSH", FASInputLayoutMode::BSH},
179 {"BNSD", FASInputLayoutMode::BNSD},
180 {"SBH", FASInputLayoutMode::SBH},
181 {"BSND", FASInputLayoutMode::BSND},
182 {"TND", FASInputLayoutMode::TND}})
183 } // namespace
184
StringToEnumImpl(const std::string & op_name,const std::string & arg_name,const std::string & enum_string)185 int64_t StringToEnumImpl(const std::string &op_name, const std::string &arg_name, const std::string &enum_string) {
186 const auto &string_to_enum_map = reg_string_to_enum_helper.GetValues(arg_name);
187 const auto enum_val_iter = string_to_enum_map.find(StrToUpper(enum_string));
188 if (enum_val_iter == string_to_enum_map.end()) {
189 MS_EXCEPTION(ValueError) << "Failed to convert the value \"" << enum_string << "\" of input '" << arg_name
190 << "' of '" << op_name << "' to enum.";
191 }
192 return enum_val_iter->second;
193 }
194 } // namespace ops
195 } // namespace mindspore
196