• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_OPS_OP_UTILS_H
18 #define MINDSPORE_CORE_OPS_OP_UTILS_H
#include <algorithm>
#include <cctype>
#include <climits>
#include <memory>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "op_name.h"
#include "include/api/visible.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/base/shape_vector.h"
#include "mindapi/base/shared_ptr.h"
#include "mindspore/core/ops/math_ops.h"
33 
// Branch-prediction hints. They are only provided when no previously included header already defines MS_UNLIKELY.
// MSVC has no __builtin_expect, so there both hints collapse to the plain condition.
#if !defined(MS_UNLIKELY)
#if defined(_MSC_VER)
#define MS_LIKELY(x) (x)
#define MS_UNLIKELY(x) (x)
#else
#define MS_UNLIKELY(x) __builtin_expect(!!(x), 0)
#define MS_LIKELY(x) __builtin_expect(!!(x), 1)
#endif  // defined(_MSC_VER)
#endif  // !defined(MS_UNLIKELY)

// Raises a ValueError whose message is `msg` whenever `cond` evaluates to false.
// The do/while wrapper makes the expansion behave as a single statement (safe inside un-braced if/else).
#define MS_CHECK_VALUE(cond, msg)        \
  do {                                   \
    if (MS_UNLIKELY(!(cond))) {          \
      MS_EXCEPTION(ValueError) << (msg); \
    }                                    \
  } while (false)
49 
50 namespace mindspore::ops {
// Bit-width constant (64). Presumably the width of the largest (64-bit) element type — confirm at call sites.
constexpr int kBitSize = 64;
52 const std::set<TypePtr> common_valid_types = {kInt8,   kInt16,  kInt32,   kInt64,   kUInt8,   kUInt16,
53                                               kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kBFloat16};
54 // ArrayValue functions as a std::vector that verifies unknown values. ArrayValue uses std::vector<T> to hold the
55 // contents of the Sequence or Tensor flattened elements and provides an interface to determine whether each element is
56 // ValueAny.
57 template <typename T>
58 class ArrayValue {
59  public:
ArrayValue(std::vector<T> && data,std::set<size_t> && unknown_value_indexes)60   ArrayValue(std::vector<T> &&data, std::set<size_t> &&unknown_value_indexes)
61       : data_(std::move(data)), unknown_value_indexes_(std::move(unknown_value_indexes)) {}
62 
63   ArrayValue(const ArrayValue &) = default;
64   ArrayValue &operator=(const ArrayValue &) = default;
65 
ArrayValue(ArrayValue && other)66   ArrayValue(ArrayValue &&other) {
67     data_ = std::move(other.data_);
68     unknown_value_indexes_ = std::move(other.unknown_value_indexes_);
69   }
70 
71   ArrayValue &operator=(ArrayValue &&other) {
72     data_ = std::move(other.data_);
73     unknown_value_indexes_ = std::move(other.unknown_value_indexes_);
74     return *this;
75   }
76 
77   ~ArrayValue() = default;
78 
79   // Access the value of Array at the index position.
80   // Note: The value at position index can not be unknown, otherwise throw an exception.
81   const T &operator[](size_t index) const {
82     if (index >= data_.size()) {
83       MS_LOG(EXCEPTION) << "The index[" << index << "] is out of range, element size is: " << data_.size();
84     }
85     if (IsValueUnknown(index)) {
86       MS_LOG(EXCEPTION) << "Try to get unknown value.";
87     }
88     return data_[index];
89   }
90 
91   // Verify that the value at position index in ArrayValue is unknown.
IsValueUnknown(size_t index)92   bool IsValueUnknown(size_t index) const { return unknown_value_indexes_.find(index) != unknown_value_indexes_.end(); }
93 
94   // Verify whether exist unknown value in ArrayValue.
HasUnknownValue()95   bool HasUnknownValue() const { return !unknown_value_indexes_.empty(); }
96 
97   // Convert the ArrayValue to std::vector, only work when there is no unknown value in ArrayValue.
ToVector()98   const std::vector<T> &ToVector() const {
99     if (HasUnknownValue()) {
100       MS_LOG(EXCEPTION) << "Can not convert vector, there is unknown value in ArrayValue.";
101     }
102     return data_;
103   }
104 
105   // Convert the ArrayValue to a string which contains all element in ArrayValue.
ToString()106   std::string ToString() const {
107     std::ostringstream oss;
108     size_t element_size = size();
109     oss << "{ ";
110     for (size_t i = 0; i < element_size; i++) {
111       oss << (!IsValueUnknown(i) ? std::to_string(data_[i]) : "ValueUnknown");
112       if (i < element_size - 1) {
113         oss << ", ";
114       }
115     }
116     oss << " }";
117     return oss.str();
118   }
119 
120   // Get element number in ArrayValue.
size()121   size_t size() const { return data_.size(); }
122 
123  private:
124   // Use vector to hold the contents parsed from Sequence or Tensor Value.
125   std::vector<T> data_;
126   // Records the index whose value is unknown (ValueAny) in the data_ vector.
127   std::set<size_t> unknown_value_indexes_;
128 };
129 
130 // This interface is only used to get value for scalar data.
131 template <typename T>
132 MS_CORE_API std::optional<T> GetScalarValue(const ValuePtr &value);
133 
134 // This interface is only used to convert values of type Sequence or Tensor to std::vector.
135 // Input can be AbstractTensor/AbstractSequence from frontend or KernelTensor from backend.
136 template <typename T>
137 MS_CORE_API std::optional<ArrayValue<T>> GetArrayValue(const AbstractBasePtr &abs_base);
138 
139 template <typename T>
140 MS_CORE_API std::optional<ArrayValue<T>> GetArrayValue(const ValuePtr &value);
141 
142 // Get the scalar/std::string value with check
143 template <typename T, typename std::enable_if<std::is_scalar<std::decay_t<T>>::value ||
144                                               std::is_same_v<std::decay_t<T>, std::string>>::type * = nullptr>
GetValueWithCheck(const ValuePtr & value)145 T GetValueWithCheck(const ValuePtr &value) {
146   auto opt = GetScalarValue<T>(value);
147   if (!opt.has_value()) {
148     MS_LOG(EXCEPTION) << "Get scalar or string value from " << value->ToString() << " with check failed.";
149   }
150   return opt.value();
151 }
152 
// Compile-time trait: IsVector<T>::value is true when T, after std::decay (dropping references and cv-qualifiers),
// is a std::vector with the default allocator.
template <typename T>
struct IsVectorImpl : std::false_type {};

template <typename Elem>
struct IsVectorImpl<std::vector<Elem>> : std::true_type {};

template <typename T>
struct IsVector {
  static constexpr bool value = IsVectorImpl<typename std::decay<T>::type>::value;
};
162 
163 // Get the std::vector value with check
164 template <typename T, typename std::enable_if<IsVector<T>::value>::type * = nullptr>
165 T GetValueWithCheck(const ValuePtr &value) {
166   auto opt = GetArrayValue<typename T::value_type>(value);
167   if (!opt.has_value()) {
168     MS_LOG(EXCEPTION) << "Get array value from " << value->ToString() << " with check failed.";
169   }
170   return opt.value().ToVector();
171 }
172 
173 const std::set<TypePtr> common_valid_types_with_bool = {
174   kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kBool, kBFloat16};
175 
176 const std::set<TypePtr> common_valid_types_with_complex = {kInt8,    kInt16,     kInt32,      kInt64,   kUInt8,
177                                                            kUInt16,  kUInt32,    kUInt64,     kFloat16, kFloat32,
178                                                            kFloat64, kComplex64, kComplex128, kBFloat16};
179 
180 const std::set<TypePtr> common_valid_types_with_complex_and_bool = {
181   kInt8,    kInt16,   kInt32,   kInt64,     kUInt8,      kUInt16, kUInt32,  kUInt64,
182   kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBool,   kBFloat16};
183 
184 const std::set<TypePtr> common_integral_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64};
185 const std::set<TypePtr> common_float_types = {kFloat16, kFloat32, kFloat64, kBFloat16};
186 const std::set<TypePtr> all_types = {kBool,    kInt,     kInt8,    kInt16,     kInt32,      kInt64,
187                                      kUInt,    kUInt8,   kUInt16,  kUInt32,    kUInt64,     kFloat,
188                                      kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBFloat16};
189 std::vector<int64_t> CalBroadCastShape(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape,
190                                        const std::string &op_name, const std::string &op_x_name = "input1",
191                                        const std::string &op_y_name = "input2");
192 abstract::ShapePtr BroadCastInferShape(const std::string &op_name,
193                                        const std::vector<abstract::AbstractBasePtr> &input_args);
194 bool IsBroadcastable(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape);
195 ShapeVector BroadCastInferShape(const std::string &op_name, const ValuePtrList &input_values);
196 BaseShapePtr EltwiseGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args);
197 TypePtr EltwiseGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args);
198 TypePtrList EltwiseGradSimpleInferType(const PrimitivePtr &primitive, const ValuePtrList &input_values);
199 ShapeArray EltwiseGradSimpleInferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values);
200 void ReduceFuncCheckAxisInferImpl(const PrimitivePtr &prim, std::vector<int64_t> *axis, const size_t dim);
201 bool CheckAndGetAxisValue(const std::vector<abstract::AbstractBasePtr> &input_args, std::vector<int64_t> *axis_value,
202                           int64_t *axis_shape_v, const PrimitivePtr &primitive);
203 ShapeVector ReduceFuncCalShapeAxisDyn(const ShapeVector &x_shape, bool keep_dims = false);
204 ShapeVector ReduceFuncCalShapeInferImpl(const PrimitivePtr &primitive, const ShapeVector &x_shape,
205                                         const std::vector<int64_t> &axis, bool keep_dims_value = false);
206 abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive,
207                                         const std::vector<abstract::AbstractBasePtr> &input_args,
208                                         const std::string &prim_name);
209 TypePtr ReduceBaseInferType(const PrimitivePtr &prim, const std::vector<abstract::AbstractBasePtr> &input_args,
210                             const std::set<TypePtr> &check_list);
211 abstract::ShapePtr ReduceExtInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args);
212 TypePtr ReduceExtInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args);
213 
214 BaseShapePtr SetPadShape(const ShapeVector &x_shape, const ArrayValue<int64_t> &paddings);
215 BaseShapePtr PadInferShapeBase(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
216                                const size_t pad_dim);
217 
218 template <typename T>
219 api::SharedPtr<T> GetOperator(const AnfNodePtr &node) {
220   auto prim = GetValueNode<PrimitivePtr>(node);
221   if (prim == nullptr) {
222     return nullptr;
223   }
224   return api::MakeShared<T>(prim);
225 }
226 
227 bool ObscureShapeEqual(const ShapeVector &lhs, const ShapeVector &rhs);
228 
229 // Get the shape value from abstract input arg
230 // Ops like DynamicBroadcastTo or Reshape can directly get the shape value
231 // from input which represents shape by invoking this function
232 // Do not support input with type of AbstractTuple of AbstractTensor
233 ShapeVector GetShapeValue(const PrimitivePtr &primitive, const AbstractBasePtr &input_arg);
234 
235 inline ShapeVector ConvertBaseShapeToTensorShape(const BaseShapePtr &base) {
236   auto shape_ptr = base->cast<abstract::ShapePtr>();
237   MS_EXCEPTION_IF_NULL(shape_ptr);
238   return shape_ptr->shape();
239 }
240 
241 inline ShapeVector GetShapeFromTensor(const AbstractBasePtr &abs) {
242   auto base_shape = abs->GetShape();
243   return ConvertBaseShapeToTensorShape(base_shape);
244 }
245 
246 void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp);
247 
248 void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name);
249 
250 void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_name);
251 
252 void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name);
253 
254 inline void CheckInputShapeEmpty(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args) {
255   for (size_t i = 0; i < input_args.size(); ++i) {
256     MS_EXCEPTION_IF_NULL(input_args[i]->GetShape());
257     if (input_args[i]->GetShape()->IsDimZero()) {
258       MS_LOG(EXCEPTION) << "For '" << prim_name << "', input " << i << "'s shape should not be empty!";
259     }
260   }
261 }
262 
263 ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape);
264 
265 template <typename T>
266 std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_abs_list);
267 
268 template <typename T>
269 AbstractBasePtr TensorToSequenceInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args);
270 
271 template <typename T>
272 AbstractBasePtr InferSequenceSetItem(const PrimitivePtr &primitive, const AbstractBasePtrList &args_abs_list);
273 
274 template <typename T>
275 T GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
276 
277 TypePtr HighPriorityType(const TypePtr &x_type, const TypePtr &y_type, const std::string &op_name);
278 
279 inline bool IsValueKnown(const ValuePtr &value) {
280   MS_EXCEPTION_IF_NULL(value);
281   return !value->isa<ValueAny>() && !value->isa<None>();
282 }
283 
284 inline bool IsValueKnown(const AbstractBasePtr &abs) {
285   MS_EXCEPTION_IF_NULL(abs);
286   return IsValueKnown(abs->GetValue());
287 }
288 
289 MS_CORE_API size_t GetInputIndexByName(const std::string &op_name, const std::string &input_name);
290 MS_CORE_API std::string GetInputNameByIndex(const std::string &op_name, size_t index);
291 MS_CORE_API size_t GetOpInputsNum(const std::string &op_name);
292 MS_CORE_API std::set<int64_t> GetInputDependValueList(const PrimitivePtr &op_prim);
293 MS_CORE_API CNodePtr ConvertArgsToAttr(const CNodePtr &cnode);
294 MS_CORE_API bool HasOpDef(const std::string &op_name);
295 
// Attribute-name string constants.
constexpr const char *kCSRAvgRows = "csr_avg_rows";
constexpr const char *kIsCSR = "is_csr";
constexpr const char *kCSRDenseShape = "dense_shape";
constexpr const char *kCSRAxis = "axis";
constexpr const char *kHasDynamicValue = "has_dynamic_value";
301 
302 inline int64_t get_batch_rank(const PrimitivePtr &prim) {
303   if (prim->HasAttr(kBatchRank)) {
304     auto value_ptr = prim->GetAttr(kBatchRank);
305     return GetValue<int64_t>(value_ptr);
306   }
307   return 0;
308 }
309 
310 inline int64_t PadModeStringToInt(const std::string &pad) {
311   std::string pad_mode = pad;
312   (void)std::transform(pad_mode.begin(), pad_mode.end(), pad_mode.begin(), toupper);
313   if (pad_mode == "VALID") {
314     return static_cast<int64_t>(2);
315   } else if (pad_mode == "SAME") {
316     return static_cast<int64_t>(1);
317   } else if (pad_mode == "PAD") {
318     return static_cast<int64_t>(0);
319   } else if (pad_mode == "CALCULATED") {
320     return static_cast<int64_t>(0);
321   } else {
322     MS_LOG(EXCEPTION) << "Got an invalid pad_mode string: " << pad_mode << ".";
323   }
324 }
325 
326 static inline TypePtr PromoteType(TypePtr a, TypePtr b, const std::string &op_name) {
327   const auto f32 = kNumberTypeFloat32;
328   const auto f16 = kNumberTypeFloat16;
329   const auto f64 = kNumberTypeFloat64;
330   const auto bf16 = kNumberTypeBFloat16;
331   const auto s8 = kNumberTypeInt8;
332   const auto u8 = kNumberTypeUInt8;
333   const auto s16 = kNumberTypeInt16;
334   const auto u16 = kNumberTypeUInt16;
335   const auto s32 = kNumberTypeInt32;
336   const auto u32 = kNumberTypeUInt32;
337   const auto s64 = kNumberTypeInt64;
338   const auto u64 = kNumberTypeUInt64;
339   const auto b1 = kNumberTypeBool;
340   const auto c64 = kNumberTypeComplex64;
341   const auto c128 = kNumberTypeComplex128;
342   const auto ud = kTypeUnknown;
343 
344   static std::unordered_map<TypeId, size_t> typeid_idx = {{f32, 0},  {f16, 1},  {f64, 2}, {bf16, 3}, {s8, 4},
345                                                           {u8, 5},   {s16, 6},  {u16, 7}, {s32, 8},  {u32, 9},
346                                                           {s64, 10}, {u64, 11}, {b1, 12}, {c64, 13}, {c128, 14}};
347   static std::unordered_map<TypeId, TypePtr> typeid_typeptr = {
348     {f32, kFloat32}, {f16, kFloat16}, {f64, kFloat64}, {bf16, kBFloat16}, {s8, kInt8},
349     {u8, kUInt8},    {s16, kInt16},   {u16, kUInt16},  {s32, kInt32},     {u32, kUInt32},
350     {s64, kInt64},   {u64, kUInt64},  {b1, kBool},     {c64, kComplex64}, {c128, kComplex128}};
351 
352   auto a_tensor_type = a->cast<TensorTypePtr>();
353   MS_EXCEPTION_IF_NULL(a_tensor_type);
354   auto a_element = a_tensor_type->element();
355   MS_EXCEPTION_IF_NULL(a_element);
356   const TypeId &a_type_id = a_element->type_id();
357 
358   auto b_tensor_type = b->cast<TensorTypePtr>();
359   MS_EXCEPTION_IF_NULL(b_tensor_type);
360   auto b_element = b_tensor_type->element();
361   MS_EXCEPTION_IF_NULL(b_element);
362   const TypeId &b_type_id = b_element->type_id();
363 
364   if (typeid_idx.find(a_type_id) == typeid_idx.end()) {
365     MS_EXCEPTION(TypeError) << "For Op[" << op_name << "], the type " << a->ToString() << "is invalid";
366   }
367 
368   if (typeid_idx.find(b_type_id) == typeid_idx.end()) {
369     MS_EXCEPTION(TypeError) << "For Op[" << op_name << "], the type " << b->ToString() << "is invalid";
370   }
371 
372   if (a_type_id == b_type_id) {
373     return a->Clone();
374   }
375 
376   static const std::vector<std::vector<TypeId>> promote_types_lookup = {
377     /*         f32  f16  f64  bf16  s8  u8  s16  u16  s32  u32  s64  u64  b1 c64  c128 */
378     /* f32 */ {f32, f32, f64, f32, f32, f32, f32, ud, f32, ud, f32, ud, f32, c64, c128},
379     /* f16 */ {f32, f16, f64, f32, f16, f16, f16, ud, f16, ud, f16, ud, f16, c64, c128},
380     /* f64 */ {f64, f64, f64, f64, f64, f64, f64, ud, f64, ud, f64, ud, f64, c64, c128},
381     /* bf16*/ {f32, f64, f64, bf16, bf16, bf16, bf16, ud, bf16, ud, bf16, ud, bf16, c64, c128},
382     /* s8  */ {f32, f16, f64, bf16, s8, s16, s16, ud, s32, ud, s64, ud, s8, c64, c128},
383     /* u8  */ {f32, f16, f64, bf16, s16, u8, s16, ud, s32, ud, s64, ud, u8, c64, c128},
384     /* s16 */ {f32, f16, f64, bf16, s16, s16, s16, ud, s32, ud, s64, ud, s16, c64, c128},
385     /* u16 */ {ud, ud, ud, ud, ud, ud, ud, u16, ud, ud, ud, ud, ud, ud, ud},
386     /* s32 */ {f32, f16, f64, bf16, s32, s32, s32, ud, s32, ud, s64, ud, s32, c64, c128},
387     /* u32 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, u32, ud, ud, ud, ud, ud},
388     /* s64 */ {f32, f16, f64, bf16, s64, s64, s64, ud, s64, ud, s64, ud, s64, c64, c128},
389     /* u64 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, u64, ud, ud, ud},
390     /* b1  */ {f32, f16, f64, bf16, s8, u8, s16, ud, s32, ud, s64, ud, b1, c64, c128},
391     /* c64 */ {c64, c64, c64, c64, c64, c64, c64, ud, c64, ud, c64, ud, c64, c64, c128},
392     /* c128*/ {c128, c128, c128, c128, c128, c128, c128, ud, c128, ud, c128, ud, c128, c128, c128},
393   };
394 
395   auto return_type_id = promote_types_lookup[typeid_idx[a_type_id]][typeid_idx[b_type_id]];
396 
397   if (return_type_id == ud) {
398     MS_EXCEPTION(TypeError) << "For Op[" << op_name << "], the promote output type is invalid";
399   }
400 
401   return std::make_shared<TensorType>(typeid_typeptr[return_type_id]);
402 }
403 
404 void CheckTensorScalarRank(const PrimitivePtr &primitive, const AbstractBasePtr input_arg, const std::string &arg_name);
405 }  // namespace mindspore::ops
406 #endif  // MINDSPORE_CORE_OPS_OP_UTILS_H
407