1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CORE_OPS_OP_UTILS_H
18 #define MINDSPORE_CORE_OPS_OP_UTILS_H
#include <algorithm>
#include <array>
#include <cctype>
#include <climits>
#include <memory>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#include "op_name.h"
#include "include/api/visible.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/base/shape_vector.h"
#include "mindapi/base/shared_ptr.h"
#include "mindspore/core/ops/math_ops.h"
33
// Branch-prediction hints. MSVC has no __builtin_expect, so there both hints expand to the bare condition.
// The whole group is skipped when MS_UNLIKELY has already been defined by another header.
#if !defined(MS_UNLIKELY)
#if defined(_MSC_VER)
#define MS_UNLIKELY(x) (x)
#define MS_LIKELY(x) (x)
#else
#define MS_LIKELY(x) __builtin_expect(!!(x), 1)
#define MS_UNLIKELY(x) __builtin_expect(!!(x), 0)
#endif  // defined(_MSC_VER)
#endif  // !defined(MS_UNLIKELY)

// Raises MS_EXCEPTION(ValueError) carrying `msg` whenever `cond` evaluates to false.
#define MS_CHECK_VALUE(cond, msg)          \
  do {                                     \
    if (MS_UNLIKELY(!(cond))) {            \
      MS_EXCEPTION(ValueError) << (msg);   \
    }                                      \
  } while (false)
49
50 namespace mindspore::ops {
51 constexpr auto kBitSize = 64;
52 const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
53 kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kBFloat16};
54 // ArrayValue functions as a std::vector that verifies unknown values. ArrayValue uses std::vector<T> to hold the
55 // contents of the Sequence or Tensor flattened elements and provides an interface to determine whether each element is
56 // ValueAny.
57 template <typename T>
58 class ArrayValue {
59 public:
ArrayValue(std::vector<T> && data,std::set<size_t> && unknown_value_indexes)60 ArrayValue(std::vector<T> &&data, std::set<size_t> &&unknown_value_indexes)
61 : data_(std::move(data)), unknown_value_indexes_(std::move(unknown_value_indexes)) {}
62
63 ArrayValue(const ArrayValue &) = default;
64 ArrayValue &operator=(const ArrayValue &) = default;
65
ArrayValue(ArrayValue && other)66 ArrayValue(ArrayValue &&other) {
67 data_ = std::move(other.data_);
68 unknown_value_indexes_ = std::move(other.unknown_value_indexes_);
69 }
70
71 ArrayValue &operator=(ArrayValue &&other) {
72 data_ = std::move(other.data_);
73 unknown_value_indexes_ = std::move(other.unknown_value_indexes_);
74 return *this;
75 }
76
77 ~ArrayValue() = default;
78
79 // Access the value of Array at the index position.
80 // Note: The value at position index can not be unknown, otherwise throw an exception.
81 const T &operator[](size_t index) const {
82 if (index >= data_.size()) {
83 MS_LOG(EXCEPTION) << "The index[" << index << "] is out of range, element size is: " << data_.size();
84 }
85 if (IsValueUnknown(index)) {
86 MS_LOG(EXCEPTION) << "Try to get unknown value.";
87 }
88 return data_[index];
89 }
90
91 // Verify that the value at position index in ArrayValue is unknown.
IsValueUnknown(size_t index)92 bool IsValueUnknown(size_t index) const { return unknown_value_indexes_.find(index) != unknown_value_indexes_.end(); }
93
94 // Verify whether exist unknown value in ArrayValue.
HasUnknownValue()95 bool HasUnknownValue() const { return !unknown_value_indexes_.empty(); }
96
97 // Convert the ArrayValue to std::vector, only work when there is no unknown value in ArrayValue.
ToVector()98 const std::vector<T> &ToVector() const {
99 if (HasUnknownValue()) {
100 MS_LOG(EXCEPTION) << "Can not convert vector, there is unknown value in ArrayValue.";
101 }
102 return data_;
103 }
104
105 // Convert the ArrayValue to a string which contains all element in ArrayValue.
ToString()106 std::string ToString() const {
107 std::ostringstream oss;
108 size_t element_size = size();
109 oss << "{ ";
110 for (size_t i = 0; i < element_size; i++) {
111 oss << (!IsValueUnknown(i) ? std::to_string(data_[i]) : "ValueUnknown");
112 if (i < element_size - 1) {
113 oss << ", ";
114 }
115 }
116 oss << " }";
117 return oss.str();
118 }
119
120 // Get element number in ArrayValue.
size()121 size_t size() const { return data_.size(); }
122
123 private:
124 // Use vector to hold the contents parsed from Sequence or Tensor Value.
125 std::vector<T> data_;
126 // Records the index whose value is unknown (ValueAny) in the data_ vector.
127 std::set<size_t> unknown_value_indexes_;
128 };
129
130 // This interface is only used to get value for scalar data.
131 template <typename T>
132 MS_CORE_API std::optional<T> GetScalarValue(const ValuePtr &value);
133
134 // This interface is only used to convert values of type Sequence or Tensor to std::vector.
135 // Input can be AbstractTensor/AbstractSequence from frontend or KernelTensor from backend.
136 template <typename T>
137 MS_CORE_API std::optional<ArrayValue<T>> GetArrayValue(const AbstractBasePtr &abs_base);
138
139 template <typename T>
140 MS_CORE_API std::optional<ArrayValue<T>> GetArrayValue(const ValuePtr &value);
141
142 // Get the scalar/std::string value with check
143 template <typename T, typename std::enable_if<std::is_scalar<std::decay_t<T>>::value ||
144 std::is_same_v<std::decay_t<T>, std::string>>::type * = nullptr>
GetValueWithCheck(const ValuePtr & value)145 T GetValueWithCheck(const ValuePtr &value) {
146 auto opt = GetScalarValue<T>(value);
147 if (!opt.has_value()) {
148 MS_LOG(EXCEPTION) << "Get scalar or string value from " << value->ToString() << " with check failed.";
149 }
150 return opt.value();
151 }
152
// IsVectorImpl<T>::value is true only when T is exactly std::vector<U> with the default allocator
// (no cv-qualifiers or references are stripped at this level).
template <typename T>
struct IsVectorImpl : public std::false_type {};

template <typename Elem>
struct IsVectorImpl<std::vector<Elem>> : public std::true_type {};

// IsVector<T>::value first decays T (dropping references and cv-qualifiers), then defers to IsVectorImpl.
template <typename T>
struct IsVector {
  static constexpr bool value = IsVectorImpl<typename std::decay<T>::type>::value;
};
162
163 // Get the std::vector value with check
164 template <typename T, typename std::enable_if<IsVector<T>::value>::type * = nullptr>
165 T GetValueWithCheck(const ValuePtr &value) {
166 auto opt = GetArrayValue<typename T::value_type>(value);
167 if (!opt.has_value()) {
168 MS_LOG(EXCEPTION) << "Get array value from " << value->ToString() << " with check failed.";
169 }
170 return opt.value().ToVector();
171 }
172
173 const std::set<TypePtr> common_valid_types_with_bool = {
174 kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kBool, kBFloat16};
175
176 const std::set<TypePtr> common_valid_types_with_complex = {kInt8, kInt16, kInt32, kInt64, kUInt8,
177 kUInt16, kUInt32, kUInt64, kFloat16, kFloat32,
178 kFloat64, kComplex64, kComplex128, kBFloat16};
179
180 const std::set<TypePtr> common_valid_types_with_complex_and_bool = {
181 kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64,
182 kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBool, kBFloat16};
183
184 const std::set<TypePtr> common_integral_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64};
185 const std::set<TypePtr> common_float_types = {kFloat16, kFloat32, kFloat64, kBFloat16};
186 const std::set<TypePtr> all_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64,
187 kUInt, kUInt8, kUInt16, kUInt32, kUInt64, kFloat,
188 kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBFloat16};
189 std::vector<int64_t> CalBroadCastShape(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape,
190 const std::string &op_name, const std::string &op_x_name = "input1",
191 const std::string &op_y_name = "input2");
192 abstract::ShapePtr BroadCastInferShape(const std::string &op_name,
193 const std::vector<abstract::AbstractBasePtr> &input_args);
194 bool IsBroadcastable(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape);
195 ShapeVector BroadCastInferShape(const std::string &op_name, const ValuePtrList &input_values);
196 BaseShapePtr EltwiseGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args);
197 TypePtr EltwiseGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args);
198 TypePtrList EltwiseGradSimpleInferType(const PrimitivePtr &primitive, const ValuePtrList &input_values);
199 ShapeArray EltwiseGradSimpleInferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values);
200 void ReduceFuncCheckAxisInferImpl(const PrimitivePtr &prim, std::vector<int64_t> *axis, const size_t dim);
201 bool CheckAndGetAxisValue(const std::vector<abstract::AbstractBasePtr> &input_args, std::vector<int64_t> *axis_value,
202 int64_t *axis_shape_v, const PrimitivePtr &primitive);
203 ShapeVector ReduceFuncCalShapeAxisDyn(const ShapeVector &x_shape, bool keep_dims = false);
204 ShapeVector ReduceFuncCalShapeInferImpl(const PrimitivePtr &primitive, const ShapeVector &x_shape,
205 const std::vector<int64_t> &axis, bool keep_dims_value = false);
206 abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive,
207 const std::vector<abstract::AbstractBasePtr> &input_args,
208 const std::string &prim_name);
209 TypePtr ReduceBaseInferType(const PrimitivePtr &prim, const std::vector<abstract::AbstractBasePtr> &input_args,
210 const std::set<TypePtr> &check_list);
211 abstract::ShapePtr ReduceExtInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args);
212 TypePtr ReduceExtInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args);
213
214 BaseShapePtr SetPadShape(const ShapeVector &x_shape, const ArrayValue<int64_t> &paddings);
215 BaseShapePtr PadInferShapeBase(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
216 const size_t pad_dim);
217
218 template <typename T>
219 api::SharedPtr<T> GetOperator(const AnfNodePtr &node) {
220 auto prim = GetValueNode<PrimitivePtr>(node);
221 if (prim == nullptr) {
222 return nullptr;
223 }
224 return api::MakeShared<T>(prim);
225 }
226
227 bool ObscureShapeEqual(const ShapeVector &lhs, const ShapeVector &rhs);
228
229 // Get the shape value from abstract input arg
230 // Ops like DynamicBroadcastTo or Reshape can directly get the shape value
231 // from input which represents shape by invoking this function
232 // Do not support input with type of AbstractTuple of AbstractTensor
233 ShapeVector GetShapeValue(const PrimitivePtr &primitive, const AbstractBasePtr &input_arg);
234
235 inline ShapeVector ConvertBaseShapeToTensorShape(const BaseShapePtr &base) {
236 auto shape_ptr = base->cast<abstract::ShapePtr>();
237 MS_EXCEPTION_IF_NULL(shape_ptr);
238 return shape_ptr->shape();
239 }
240
241 inline ShapeVector GetShapeFromTensor(const AbstractBasePtr &abs) {
242 auto base_shape = abs->GetShape();
243 return ConvertBaseShapeToTensorShape(base_shape);
244 }
245
246 void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp);
247
248 void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name);
249
250 void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_name);
251
252 void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name);
253
254 inline void CheckInputShapeEmpty(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args) {
255 for (size_t i = 0; i < input_args.size(); ++i) {
256 MS_EXCEPTION_IF_NULL(input_args[i]->GetShape());
257 if (input_args[i]->GetShape()->IsDimZero()) {
258 MS_LOG(EXCEPTION) << "For '" << prim_name << "', input " << i << "'s shape should not be empty!";
259 }
260 }
261 }
262
263 ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape);
264
265 template <typename T>
266 std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_abs_list);
267
268 template <typename T>
269 AbstractBasePtr TensorToSequenceInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args);
270
271 template <typename T>
272 AbstractBasePtr InferSequenceSetItem(const PrimitivePtr &primitive, const AbstractBasePtrList &args_abs_list);
273
274 template <typename T>
275 T GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
276
277 TypePtr HighPriorityType(const TypePtr &x_type, const TypePtr &y_type, const std::string &op_name);
278
279 inline bool IsValueKnown(const ValuePtr &value) {
280 MS_EXCEPTION_IF_NULL(value);
281 return !value->isa<ValueAny>() && !value->isa<None>();
282 }
283
284 inline bool IsValueKnown(const AbstractBasePtr &abs) {
285 MS_EXCEPTION_IF_NULL(abs);
286 return IsValueKnown(abs->GetValue());
287 }
288
289 MS_CORE_API size_t GetInputIndexByName(const std::string &op_name, const std::string &input_name);
290 MS_CORE_API std::string GetInputNameByIndex(const std::string &op_name, size_t index);
291 MS_CORE_API size_t GetOpInputsNum(const std::string &op_name);
292 MS_CORE_API std::set<int64_t> GetInputDependValueList(const PrimitivePtr &op_prim);
293 MS_CORE_API CNodePtr ConvertArgsToAttr(const CNodePtr &cnode);
294 MS_CORE_API bool HasOpDef(const std::string &op_name);
295
// Attribute-name string constants.
constexpr const char *kCSRAvgRows = "csr_avg_rows";
constexpr const char *kIsCSR = "is_csr";
constexpr const char *kCSRDenseShape = "dense_shape";
constexpr const char *kCSRAxis = "axis";
constexpr const char *kHasDynamicValue = "has_dynamic_value";
301
302 inline int64_t get_batch_rank(const PrimitivePtr &prim) {
303 if (prim->HasAttr(kBatchRank)) {
304 auto value_ptr = prim->GetAttr(kBatchRank);
305 return GetValue<int64_t>(value_ptr);
306 }
307 return 0;
308 }
309
310 inline int64_t PadModeStringToInt(const std::string &pad) {
311 std::string pad_mode = pad;
312 (void)std::transform(pad_mode.begin(), pad_mode.end(), pad_mode.begin(), toupper);
313 if (pad_mode == "VALID") {
314 return static_cast<int64_t>(2);
315 } else if (pad_mode == "SAME") {
316 return static_cast<int64_t>(1);
317 } else if (pad_mode == "PAD") {
318 return static_cast<int64_t>(0);
319 } else if (pad_mode == "CALCULATED") {
320 return static_cast<int64_t>(0);
321 } else {
322 MS_LOG(EXCEPTION) << "Got an invalid pad_mode string: " << pad_mode << ".";
323 }
324 }
325
326 static inline TypePtr PromoteType(TypePtr a, TypePtr b, const std::string &op_name) {
327 const auto f32 = kNumberTypeFloat32;
328 const auto f16 = kNumberTypeFloat16;
329 const auto f64 = kNumberTypeFloat64;
330 const auto bf16 = kNumberTypeBFloat16;
331 const auto s8 = kNumberTypeInt8;
332 const auto u8 = kNumberTypeUInt8;
333 const auto s16 = kNumberTypeInt16;
334 const auto u16 = kNumberTypeUInt16;
335 const auto s32 = kNumberTypeInt32;
336 const auto u32 = kNumberTypeUInt32;
337 const auto s64 = kNumberTypeInt64;
338 const auto u64 = kNumberTypeUInt64;
339 const auto b1 = kNumberTypeBool;
340 const auto c64 = kNumberTypeComplex64;
341 const auto c128 = kNumberTypeComplex128;
342 const auto ud = kTypeUnknown;
343
344 static std::unordered_map<TypeId, size_t> typeid_idx = {{f32, 0}, {f16, 1}, {f64, 2}, {bf16, 3}, {s8, 4},
345 {u8, 5}, {s16, 6}, {u16, 7}, {s32, 8}, {u32, 9},
346 {s64, 10}, {u64, 11}, {b1, 12}, {c64, 13}, {c128, 14}};
347 static std::unordered_map<TypeId, TypePtr> typeid_typeptr = {
348 {f32, kFloat32}, {f16, kFloat16}, {f64, kFloat64}, {bf16, kBFloat16}, {s8, kInt8},
349 {u8, kUInt8}, {s16, kInt16}, {u16, kUInt16}, {s32, kInt32}, {u32, kUInt32},
350 {s64, kInt64}, {u64, kUInt64}, {b1, kBool}, {c64, kComplex64}, {c128, kComplex128}};
351
352 auto a_tensor_type = a->cast<TensorTypePtr>();
353 MS_EXCEPTION_IF_NULL(a_tensor_type);
354 auto a_element = a_tensor_type->element();
355 MS_EXCEPTION_IF_NULL(a_element);
356 const TypeId &a_type_id = a_element->type_id();
357
358 auto b_tensor_type = b->cast<TensorTypePtr>();
359 MS_EXCEPTION_IF_NULL(b_tensor_type);
360 auto b_element = b_tensor_type->element();
361 MS_EXCEPTION_IF_NULL(b_element);
362 const TypeId &b_type_id = b_element->type_id();
363
364 if (typeid_idx.find(a_type_id) == typeid_idx.end()) {
365 MS_EXCEPTION(TypeError) << "For Op[" << op_name << "], the type " << a->ToString() << "is invalid";
366 }
367
368 if (typeid_idx.find(b_type_id) == typeid_idx.end()) {
369 MS_EXCEPTION(TypeError) << "For Op[" << op_name << "], the type " << b->ToString() << "is invalid";
370 }
371
372 if (a_type_id == b_type_id) {
373 return a->Clone();
374 }
375
376 static const std::vector<std::vector<TypeId>> promote_types_lookup = {
377 /* f32 f16 f64 bf16 s8 u8 s16 u16 s32 u32 s64 u64 b1 c64 c128 */
378 /* f32 */ {f32, f32, f64, f32, f32, f32, f32, ud, f32, ud, f32, ud, f32, c64, c128},
379 /* f16 */ {f32, f16, f64, f32, f16, f16, f16, ud, f16, ud, f16, ud, f16, c64, c128},
380 /* f64 */ {f64, f64, f64, f64, f64, f64, f64, ud, f64, ud, f64, ud, f64, c64, c128},
381 /* bf16*/ {f32, f64, f64, bf16, bf16, bf16, bf16, ud, bf16, ud, bf16, ud, bf16, c64, c128},
382 /* s8 */ {f32, f16, f64, bf16, s8, s16, s16, ud, s32, ud, s64, ud, s8, c64, c128},
383 /* u8 */ {f32, f16, f64, bf16, s16, u8, s16, ud, s32, ud, s64, ud, u8, c64, c128},
384 /* s16 */ {f32, f16, f64, bf16, s16, s16, s16, ud, s32, ud, s64, ud, s16, c64, c128},
385 /* u16 */ {ud, ud, ud, ud, ud, ud, ud, u16, ud, ud, ud, ud, ud, ud, ud},
386 /* s32 */ {f32, f16, f64, bf16, s32, s32, s32, ud, s32, ud, s64, ud, s32, c64, c128},
387 /* u32 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, u32, ud, ud, ud, ud, ud},
388 /* s64 */ {f32, f16, f64, bf16, s64, s64, s64, ud, s64, ud, s64, ud, s64, c64, c128},
389 /* u64 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, u64, ud, ud, ud},
390 /* b1 */ {f32, f16, f64, bf16, s8, u8, s16, ud, s32, ud, s64, ud, b1, c64, c128},
391 /* c64 */ {c64, c64, c64, c64, c64, c64, c64, ud, c64, ud, c64, ud, c64, c64, c128},
392 /* c128*/ {c128, c128, c128, c128, c128, c128, c128, ud, c128, ud, c128, ud, c128, c128, c128},
393 };
394
395 auto return_type_id = promote_types_lookup[typeid_idx[a_type_id]][typeid_idx[b_type_id]];
396
397 if (return_type_id == ud) {
398 MS_EXCEPTION(TypeError) << "For Op[" << op_name << "], the promote output type is invalid";
399 }
400
401 return std::make_shared<TensorType>(typeid_typeptr[return_type_id]);
402 }
403
404 void CheckTensorScalarRank(const PrimitivePtr &primitive, const AbstractBasePtr input_arg, const std::string &arg_name);
405 } // namespace mindspore::ops
406 #endif // MINDSPORE_CORE_OPS_OP_UTILS_H
407