• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "ops/op_enum.h"

#include <algorithm>
#include <cctype>
#include <string>
#include <unordered_map>
#include <utility>

#include "mindapi/base/types.h"
#include "utils/check_convert_utils.h"
#include "mindapi/base/format.h"
24 
25 namespace mindspore {
26 namespace ops {
27 
28 namespace {
29 using StrToEnumMap = std::unordered_map<std::string, int64_t>;
30 
31 class RegStringToEnumHelper {
32  public:
33   template <typename T>
AddValues(T && string_to_enum,const std::string & key="")34   std::string AddValues(T &&string_to_enum, const std::string &key = "") {
35     auto &string_to_enum_target = string_to_enum_memory_[key];
36     for (const auto &kv : string_to_enum) {
37       if (string_to_enum_target.find(kv.first) != string_to_enum_target.end()) {
38         MS_LOG(EXCEPTION) << kv.first << " has been registered!";
39       }
40     }
41     string_to_enum_target.merge(std::move(string_to_enum));
42     return "";
43   }
44 
GetValues(const std::string & key="")45   const StrToEnumMap &GetValues(const std::string &key = "") {
46     auto it = string_to_enum_memory_.find(key);
47     if (it != string_to_enum_memory_.end()) {
48       return it->second;
49     }
50     return string_to_enum_memory_[""];
51   }
52 
53  private:
54   std::unordered_map<std::string, StrToEnumMap> string_to_enum_memory_;
55 };
56 RegStringToEnumHelper reg_string_to_enum_helper;
57 
58 #define REG_STRING_TO_ENUM_COMMON(enum_type, ...) \
59   const auto op_enum_##enum_type = reg_string_to_enum_helper.AddValues(__VA_ARGS__);
60 
61 #define REG_STRING_TO_ENUM_SPECIAL(enum_type, ...) \
62   const auto op_enum_##enum_type = reg_string_to_enum_helper.AddValues(__VA_ARGS__, #enum_type);
63 
// Convert to uppercase uniformly.
// Returns a copy of `str` with every character passed through std::toupper
// (current C locale), so enum spellings can be matched case-insensitively.
//
// std::toupper has undefined behavior for negative arguments other than EOF.
// Plain char may be signed, making non-ASCII bytes negative, so each character
// is converted to unsigned char before the call.
inline std::string StrToUpper(const std::string &str) {
  std::string res = str;
  std::transform(res.begin(), res.end(), res.begin(),
                 [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
  return res;
}
72 
73 // Format
GetStringToFormatMap()74 inline std::unordered_map<std::string, int64_t> GetStringToFormatMap() {
75   const auto &names = GetFormatNames();
76   std::unordered_map<std::string, int64_t> map{{"DEFAULT_FORMAT", static_cast<int64_t>(Format::DEFAULT_FORMAT)}};
77   for (size_t i = 0; i < names.size(); ++i) {
78     map[StrToUpper(names[i])] = static_cast<int64_t>(i);
79   }
80   return map;
81 }
82 REG_STRING_TO_ENUM_COMMON(format, GetStringToFormatMap())
83 
84 // RoundingMode
85 StrToEnumMap StrToRoundingModeMap = {{"ROUND", RoundingMode::ROUND},
86                                      {"TRUNC", RoundingMode::TRUNC},
87                                      {"FLOOR", RoundingMode::FLOOR},
88                                      {"CEIL", RoundingMode::CEIL}};
89 REG_STRING_TO_ENUM_COMMON(rounding_mode, StrToRoundingModeMap)
90 
91 // PadMode
92 StrToEnumMap StrToPadModeMap = {
93   {"PAD", PadMode::PAD}, {"SAME", PadMode::SAME}, {"VALID", PadMode::VALID}, {"FULL", PadMode::FULL}};
94 REG_STRING_TO_ENUM_COMMON(pad_mode, StrToPadModeMap)
95 
96 // Reduction
97 StrToEnumMap StrToReductionMap = {{"SUM", Reduction::REDUCTION_SUM},
98                                   {"MEAN", Reduction::MEAN},
99                                   {"NONE", Reduction::NONE},
100                                   {"UPDATE", Reduction::UPDATE}};
101 REG_STRING_TO_ENUM_COMMON(reduction, StrToReductionMap)
102 
103 // Activation
104 StrToEnumMap StrToActivationMap = {{"NO_ACTIVATION", ActivationType::NO_ACTIVATION},
105                                    {"RELU", ActivationType::RELU},
106                                    {"SIGMOID", ActivationType::SIGMOID},
107                                    {"RELU6", ActivationType::RELU6},
108                                    {"ELU", ActivationType::ELU},
109                                    {"LEAKY_RELU", ActivationType::LEAKY_RELU},
110                                    {"ABS", ActivationType::ABS},
111                                    {"RELU1", ActivationType::RELU1},
112                                    {"SOFTSIGN", ActivationType::SOFTSIGN},
113                                    {"SOFTPLUS", ActivationType::SOFTPLUS},
114                                    {"TANH", ActivationType::TANH},
115                                    {"SELU", ActivationType::SELU},
116                                    {"HSWISH", ActivationType::HSWISH},
117                                    {"HSIGMOID", ActivationType::HSIGMOID},
118                                    {"THRESHOLDRELU", ActivationType::THRESHOLDRELU},
119                                    {"LINEAR", ActivationType::LINEAR},
120                                    {"HARD_TANH", ActivationType::HARD_TANH},
121                                    {"SIGN", ActivationType::SIGN},
122                                    {"SWISH", ActivationType::SWISH},
123                                    {"GELU", ActivationType::GELU},
124                                    {"GLU", ActivationType::GLU},
125                                    {"UNKNOWN", ActivationType::UNKNOWN},
126                                    {"FASTGELU", ActivationType::FASTGELU},
127                                    {"SILU", ActivationType::SILU},
128                                    {"GEGLU", ActivationType::GEGLU},
129                                    {"SWIGLU", ActivationType::SWIGLU},
130                                    {"REGLU", ActivationType::REGLU}};
131 REG_STRING_TO_ENUM_COMMON(activation, StrToActivationMap)
132 
133 // GateOrder
134 REG_STRING_TO_ENUM_COMMON(gate_order, StrToEnumMap{{"RZH", GateOrderMode::RZH}, {"ZRH", GateOrderMode::ZRH}})
135 
136 // CoordinateTransformationMode
137 StrToEnumMap StrToCoordinateTransformationModeMap = {{"ASYMMETRIC", CoordinateTransformMode::ASYMMETRIC},
138                                                      {"ALIGN_CORNERS", CoordinateTransformMode::ALIGN_CORNERS},
139                                                      {"HALF_PIXEL", CoordinateTransformMode::HALF_PIXEL},
140                                                      {"CROP_AND_RESIZE", CoordinateTransformMode::CROP_AND_RESIZE}};
141 REG_STRING_TO_ENUM_COMMON(coordinate_transformation_mode, StrToCoordinateTransformationModeMap)
142 
143 // PaddingMode
144 StrToEnumMap StrToPaddingModeMap = {{"CONSTANT", PaddingMode::CONSTANT},
145                                     {"REFLECT", PaddingMode::REFLECT},
146                                     {"SYMMETRIC", PaddingMode::SYMMETRIC},
147                                     {"MODE_RESERVED", PaddingMode::MODE_RESERVED}};
148 REG_STRING_TO_ENUM_COMMON(padding_mode, StrToPaddingModeMap)
149 
150 // Direction
151 REG_STRING_TO_ENUM_COMMON(direction, StrToEnumMap{{"UNIDIRECTIONAL", Direction::UNIDIRECTIONAL}})
152 
153 // CellType
154 REG_STRING_TO_ENUM_COMMON(cell_type, StrToEnumMap{{"LSTM", CellType::CELL_TYPE_LSTM}})
155 
156 // Group
157 REG_STRING_TO_ENUM_COMMON(group, StrToEnumMap{{"SYNC_BN_GROUP0", Group::SYNC_BN_GROUP0}})
158 
159 // InterpolationMode
160 REG_STRING_TO_ENUM_COMMON(interpolation_mode, StrToEnumMap{{"BILINEAR", InterpolationMode::BILINEAR},
161                                                            {"NEAREST", InterpolationMode::NEAREST}})
162 
163 // NormMode
164 StrToEnumMap StrToNormModeMap = {
165   {"BACKWARD", NormMode::BACKWARD}, {"FORWARD", NormMode::FORWARD}, {"ORTHO", NormMode::ORTHO}};
166 REG_STRING_TO_ENUM_COMMON(norm_mode, StrToNormModeMap)
167 
168 // GridSamplerPaddingMode
169 StrToEnumMap StrToGridSamplerPaddingMode = {{"ZEROS", GridSamplerPaddingMode::ZEROS},
170                                             {"BORDER", GridSamplerPaddingMode::BORDER},
171                                             {"REFLECTION", GridSamplerPaddingMode::REFLECTION}};
172 REG_STRING_TO_ENUM_COMMON(grid_sampler_padding_mode, StrToGridSamplerPaddingMode)
173 
174 // KVCacheAlignMode
175 REG_STRING_TO_ENUM_COMMON(k_v_cache_align_mode,
176                           StrToEnumMap{{"LEFT", KVCacheAlignMode::LEFT}, {"RIGHT", KVCacheAlignMode::RIGHT}})
177 
178 REG_STRING_TO_ENUM_COMMON(fas_input_layout_mode, StrToEnumMap{{"BSH", FASInputLayoutMode::BSH},
179                                                               {"BNSD", FASInputLayoutMode::BNSD},
180                                                               {"SBH", FASInputLayoutMode::SBH},
181                                                               {"BSND", FASInputLayoutMode::BSND},
182                                                               {"TND", FASInputLayoutMode::TND}})
183 }  // namespace
184 
StringToEnumImpl(const std::string & op_name,const std::string & arg_name,const std::string & enum_string)185 int64_t StringToEnumImpl(const std::string &op_name, const std::string &arg_name, const std::string &enum_string) {
186   const auto &string_to_enum_map = reg_string_to_enum_helper.GetValues(arg_name);
187   const auto enum_val_iter = string_to_enum_map.find(StrToUpper(enum_string));
188   if (enum_val_iter == string_to_enum_map.end()) {
189     MS_EXCEPTION(ValueError) << "Failed to convert the value \"" << enum_string << "\" of input '" << arg_name
190                              << "' of '" << op_name << "' to enum.";
191   }
192   return enum_val_iter->second;
193 }
194 }  // namespace ops
195 }  // namespace mindspore
196