1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/operator/cc_implementations.h"
18 #include <limits>
19 #include <algorithm>
20 #include <cmath>
21 #include <cfloat>
22 #include "utils/log_adapter.h"
23 #include "utils/convert_utils.h"
24 #include "utils/ms_utils.h"
25
26 namespace mindspore {
// Namespace holding the implementations of primitive operators.
28 namespace prim {
// Numeric category of a value list as reported by InferType; kUnknown when none of the known types is present.
enum class DataType {
  kInt,
  kInt64,
  kFloat,
  kDouble,
  kUnknown,
};
30
31 // Whether has a T type data in AnyPtrList.
32 template <class T>
HasType(const AnyPtrList & list)33 bool HasType(const AnyPtrList &list) {
34 bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); });
35 return ret;
36 }
37
InferType(const AnyPtrList & list)38 DataType InferType(const AnyPtrList &list) {
39 if (HasType<double>(list)) {
40 return DataType::kDouble;
41 } else if (HasType<float>(list)) {
42 return DataType::kFloat;
43 } else if (HasType<int64_t>(list)) {
44 return DataType::kInt64;
45 } else if (HasType<int>(list)) {
46 return DataType::kInt;
47 }
48 return DataType::kUnknown;
49 }
50
// Returns true when x + y would fall outside [min, max]. The bound is compared against `max - y` / `min - y`,
// which are always representable for the sign of y being tested, so the check never overflows itself.
template <typename T>
bool IsAddOverflow(const T &x, const T &y, const T &max, const T &min) {
  if (y > 0) {
    return x > max - y;
  }
  if (y < 0) {
    return x < min - y;
  }
  return false;
}
55
// Returns true when x - y would fall outside [min, max]. Compares x with `max + y` / `min + y`, which are
// representable for the sign of y being tested, so the subtraction itself is never performed.
template <typename T>
bool IsSubOverflow(const T &x, const T &y, const T &max, const T &min) {
  if (y < 0) {
    return x > max + y;
  }
  if (y > 0) {
    return x < min + y;
  }
  return false;
}
60
// Returns true when x * y would fall outside [min, max], decided per sign combination without multiplying.
// Every division divides a limit by a value that cannot make it overflow. In particular, when x > 0 and y < 0,
// the check divides `min` by the positive x. Dividing by y would evaluate `min / -1` for y == -1, which is
// signed-overflow undefined behavior (SIGFPE on x86). Integer division truncates toward zero, which makes
// each comparison exact.
template <typename T>
bool IsMulOverflow(const T &x, const T &y, const T &max, const T &min) {
  return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) || (x > 0 && y < 0 && (min / x) > y) ||
         (x < 0 && y > 0 && (min / y) > x);
}
66
// Returns true for the single signed-integer division/modulo that overflows: x equal to `min` with y == -1.
template <typename T>
bool IsDivOverflow(const T &x, const T &y, const T &min) {
  if (x != min) {
    return false;
  }
  return static_cast<int64_t>(y) == -1;
}
71
// Arithmetic operation selector for IsSignedIntOverflow.
enum class OpType {
  ADD,
  SUB,
  MUL,
  DIV,
  MOD,
};
73
74 template <typename T>
IsSignedIntOverflow(T x,T y,OpType opType)75 bool IsSignedIntOverflow(T x, T y, OpType opType) {
76 auto max = std::numeric_limits<T>::max();
77 auto min = std::numeric_limits<T>::min();
78
79 if (opType == OpType::ADD) {
80 return IsAddOverflow<T>(x, y, max, min);
81 }
82
83 if (opType == OpType::SUB) {
84 return IsSubOverflow<T>(x, y, max, min);
85 }
86
87 if (opType == OpType::MUL) {
88 return IsMulOverflow<T>(x, y, max, min);
89 }
90
91 if (opType == OpType::DIV || opType == OpType::MOD) {
92 return IsDivOverflow<T>(x, y, min);
93 }
94
95 MS_EXCEPTION(NotSupportError) << "Unsupported operation type.";
96 }
97
98 template <typename T>
InnerScalarAdd(T x,T y)99 T InnerScalarAdd(T x, T y) {
100 if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::ADD)) {
101 MS_EXCEPTION(ValueError) << "Overflow of the sum of two signed number x: " << std::to_string(x)
102 << ", y: " << std::to_string(y) << ".";
103 }
104 return x + y;
105 }
106
107 template <typename T>
InnerScalarSub(T x,T y)108 T InnerScalarSub(T x, T y) {
109 if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::SUB)) {
110 MS_EXCEPTION(ValueError) << "Overflow of the sub of two signed number x: " << std::to_string(x)
111 << ", y: " << std::to_string(y) << ".";
112 }
113 return x - y;
114 }
115
116 template <typename T>
InnerScalarMul(T x,T y)117 T InnerScalarMul(T x, T y) {
118 if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MUL)) {
119 MS_EXCEPTION(ValueError) << "Overflow of the mul of two signed number x: " << std::to_string(x)
120 << ", y: " << std::to_string(y) << ".";
121 }
122 return x * y;
123 }
124
125 template <typename T>
InnerScalarDiv(T x,T y)126 float InnerScalarDiv(T x, T y) {
127 if (y == 0) {
128 MS_EXCEPTION(ValueError) << "Divisor could not be zero";
129 }
130 if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::DIV)) {
131 MS_EXCEPTION(ValueError) << "Overflow of the div of two signed number x: " << std::to_string(x)
132 << ", y: " << std::to_string(y) << ".";
133 }
134 return static_cast<float>(x) / static_cast<float>(y);
135 }
136
137 template <typename T>
InnerScalarFloordiv(T x,T y)138 T InnerScalarFloordiv(T x, T y) {
139 auto ret = std::floor(InnerScalarDiv(x, y));
140 if (std::is_integral<T>::value) {
141 return static_cast<int64_t>(ret);
142 }
143 return ret;
144 }
145
146 template <typename T>
InnerScalarMod(T x,T y)147 T InnerScalarMod(T x, T y) {
148 if (y == 0) {
149 MS_EXCEPTION(ValueError) << "Could not mod to zero.";
150 }
151 if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MOD)) {
152 MS_EXCEPTION(ValueError) << "Overflow of the mod of two signed number x: " << std::to_string(x)
153 << ", y: " << std::to_string(y) << ".";
154 }
155 if (std::is_integral<T>::value) {
156 return static_cast<int64_t>(x) % static_cast<int64_t>(y);
157 }
158 return x - y * std::floor(x / y);
159 }
160
// x raised to the power y via std::pow. The floating-point result is converted back to T, so for integer T
// it is truncated toward zero.
template <typename T, typename U>
T InnerScalarPow(T x, U y) {
  const auto power = std::pow(x, y);
  return static_cast<T>(power);
}
165
// Scalar equality.
// - Two integer operands are compared exactly. Converting them to double would make distinct int64 values
//   above 2^53 compare equal.
// - Any other pair is compared in double with an absolute tolerance of DBL_EPSILON.
// - Infinities are handled explicitly, because inf - inf is NaN: same-signed infinities are equal.
// - NaN is never equal to anything.
template <typename T, typename U>
bool InnerScalarEq(T x, U y) {
  if (std::is_integral<T>::value && std::is_integral<U>::value) {
    return static_cast<int64_t>(x) == static_cast<int64_t>(y);
  }
  const double dx = static_cast<double>(x);
  const double dy = static_cast<double>(y);
  if (std::isinf(dx) || std::isinf(dy)) {
    return std::isinf(dx) && std::isinf(dy) && std::signbit(dx) == std::signbit(dy);
  }
  return std::fabs(dx - dy) < DBL_EPSILON;
}
172
// x < y under the usual arithmetic conversions (false whenever either side is NaN).
template <typename T, typename U>
bool InnerScalarLt(T x, U y) {
  return y > x;
}
177
// x > y under the usual arithmetic conversions (false whenever either side is NaN).
template <typename T, typename U>
bool InnerScalarGt(T x, U y) {
  return y < x;
}
182
// Logical negation of InnerScalarEq, so it uses exactly the same notion of equality.
template <typename T, typename U>
bool InnerScalarNe(T x, U y) {
  const bool equal = InnerScalarEq(x, y);
  return !equal;
}
187
// x <= y under the usual arithmetic conversions (false whenever either side is NaN).
template <typename T, typename U>
bool InnerScalarLe(T x, U y) {
  return y >= x;
}
192
// x >= y under the usual arithmetic conversions (false whenever either side is NaN).
template <typename T, typename U>
bool InnerScalarGe(T x, U y) {
  return y <= x;
}
197
// SCALAR_OP(op_t) defines `ValuePtr Scalar<op_t>(const ValuePtrList &list)`. The generated function applies
// InnerScalar<op_t> to two scalar immediates and wraps the result with MakeValue.
// Every pair of Int32Imm / Int64Imm / FP32Imm / FP64Imm is accepted, with these operand promotions:
//   int32 with int32                          -> int
//   float32 with float32 or int32             -> float
//   int64 with int64 or int32                 -> int64_t
//   any double operand, or int64 with float32 -> double
// The declared type of `res` fixes the result type, so e.g. ScalarDiv on two int32 values yields a truncated int.
// Throws NotSupportError when `list` does not hold exactly two values, and TypeError for other operand types.
// (No comments inside the macro body: a `//` comment would swallow the line continuation.)
#define SCALAR_OP(op_t)                                                                                              \
  ValuePtr Scalar##op_t(const ValuePtrList &list) {                                                                  \
    if (list.size() != 2) {                                                                                          \
      MS_EXCEPTION(NotSupportError) << "Input number of Scalar" << #op_t << " should be 2, but got " << list.size(); \
    }                                                                                                                \
    ValuePtr x = list[0];                                                                                            \
    ValuePtr y = list[1];                                                                                            \
    MS_EXCEPTION_IF_NULL(x);                                                                                         \
    MS_EXCEPTION_IF_NULL(y);                                                                                         \
    if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) {                                                                    \
      double res = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y));                                      \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) {                                                                    \
      float res = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y));                                         \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP64Imm>() && y->isa<FP32Imm>()) {                                                                    \
      double res = InnerScalar##op_t(GetValue<double>(x), FloatToDouble(GetValue<float>(y)));                        \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP32Imm>() && y->isa<FP64Imm>()) {                                                                    \
      double res = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), GetValue<double>(y));                        \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) {                                                                  \
      int res = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y));                                               \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) {                                                                   \
      float res = InnerScalar##op_t(IntToFloat(GetValue<int>(x)), GetValue<float>(y));                               \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) {                                                                   \
      float res = InnerScalar##op_t(GetValue<float>(x), IntToFloat(GetValue<int>(y)));                               \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int32Imm>() && y->isa<FP64Imm>()) {                                                                   \
      double res = InnerScalar##op_t(static_cast<double>(GetValue<int>(x)), GetValue<double>(y));                   \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP64Imm>() && y->isa<Int32Imm>()) {                                                                   \
      double res = InnerScalar##op_t(GetValue<double>(x), static_cast<double>(GetValue<int>(y)));                    \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) {                                                                 \
      int64_t res = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y));                                   \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) {                                                                   \
      double res = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), GetValue<double>(y));                       \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) {                                                                   \
      double res = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y)));         \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) {                                                                 \
      int64_t res = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y)));                            \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) {                                                                   \
      double res = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y)));         \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) {                                                                   \
      double res = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y)));                       \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) {                                                                 \
      int64_t res = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y));                            \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    MS_EXCEPTION(TypeError) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name()     \
                            << ", value of x:" << x->ToString() << ", type of y:" << y->type_name()                  \
                            << ", value of y:" << y->ToString();                                                     \
  }
261
262 SCALAR_OP(Add)
SCALAR_OP(Sub)263 SCALAR_OP(Sub)
264 SCALAR_OP(Mul)
265 SCALAR_OP(Div)
266 SCALAR_OP(Mod)
267 SCALAR_OP(Pow)
268 SCALAR_OP(Floordiv)
269
// LOGIC_OP(op_t) defines `ValuePtr Scalar<op_t>(const ValuePtrList &list)`. The generated function compares two
// scalar immediates with InnerScalar<op_t> and wraps the bool result with MakeValue.
// Operands are forwarded with their own C++ types (int, int64_t, float, double); the InnerScalar<op_t> template
// decides how mixed types are compared. Every pair of Int32Imm / Int64Imm / FP32Imm / FP64Imm is accepted.
// Throws NotSupportError when `list` does not hold exactly two values, and TypeError for other operand types.
// (No comments inside the macro body: a `//` comment would swallow the line continuation.)
#define LOGIC_OP(op_t)                                                                                               \
  ValuePtr Scalar##op_t(const ValuePtrList &list) {                                                                  \
    if (list.size() != 2) {                                                                                          \
      MS_EXCEPTION(NotSupportError) << "Input number of Scalar" << #op_t << " should be 2, but got " << list.size(); \
    }                                                                                                                \
    ValuePtr x = list[0];                                                                                            \
    ValuePtr y = list[1];                                                                                            \
    MS_EXCEPTION_IF_NULL(x);                                                                                         \
    MS_EXCEPTION_IF_NULL(y);                                                                                         \
    if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) {                                                                    \
      bool res = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y));                                        \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) {                                                                    \
      bool res = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y));                                          \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP64Imm>() && y->isa<FP32Imm>()) {                                                                    \
      bool res = InnerScalar##op_t(GetValue<double>(x), GetValue<float>(y));                                         \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP32Imm>() && y->isa<FP64Imm>()) {                                                                    \
      bool res = InnerScalar##op_t(GetValue<float>(x), GetValue<double>(y));                                         \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) {                                                                  \
      bool res = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y));                                              \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) {                                                                   \
      bool res = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y));                                            \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) {                                                                   \
      bool res = InnerScalar##op_t(GetValue<float>(x), GetValue<int64_t>(y));                                        \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) {                                                                   \
      bool res = InnerScalar##op_t(GetValue<int>(x), GetValue<float>(y));                                            \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) {                                                                   \
      bool res = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<float>(y));                                        \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP64Imm>() && y->isa<Int32Imm>()) {                                                                   \
      bool res = InnerScalar##op_t(GetValue<double>(x), GetValue<int>(y));                                           \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int32Imm>() && y->isa<FP64Imm>()) {                                                                   \
      bool res = InnerScalar##op_t(GetValue<int>(x), GetValue<double>(y));                                           \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) {                                                                 \
      bool res = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y));                                      \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) {                                                                   \
      bool res = InnerScalar##op_t(GetValue<double>(x), GetValue<int64_t>(y));                                       \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) {                                                                   \
      bool res = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<double>(y));                                       \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) {                                                                 \
      bool res = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y));                                          \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) {                                                                 \
      bool res = InnerScalar##op_t(GetValue<int>(x), GetValue<int64_t>(y));                                          \
      return MakeValue(res);                                                                                         \
    }                                                                                                                \
    MS_EXCEPTION(TypeError) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name()     \
                            << ", value of x:" << x->ToString() << ", type of y:" << y->type_name()                  \
                            << ", value of y:" << y->ToString();                                                     \
  }
339
340 LOGIC_OP(Eq)
341 LOGIC_OP(Lt)
342 LOGIC_OP(Gt)
343 LOGIC_OP(Ne)
344 LOGIC_OP(Le)
345 LOGIC_OP(Ge)
346
347 ValuePtr ScalarUAdd(const ValuePtrList &list) {
348 if (list.size() != 1) {
349 MS_EXCEPTION(NotSupportError) << "Input number of ScalarUAdd should be 1, but got " << list.size();
350 }
351 ValuePtr x = list[0];
352 MS_EXCEPTION_IF_NULL(x);
353 return x;
354 }
355
ScalarUSub(const ValuePtrList & list)356 ValuePtr ScalarUSub(const ValuePtrList &list) {
357 if (list.size() != 1) {
358 MS_EXCEPTION(NotSupportError) << "Input number of ScalarUSub should be 1, but got " << list.size();
359 }
360 ValuePtr x = list[0];
361 MS_EXCEPTION_IF_NULL(x);
362
363 if (x->isa<Int32Imm>()) {
364 int32_t sum = -1 * GetValue<int32_t>(x);
365 return MakeValue(sum);
366 }
367 if (x->isa<Int64Imm>()) {
368 int64_t sum = -1 * GetValue<int64_t>(x);
369 return MakeValue(sum);
370 }
371 if (x->isa<FP32Imm>()) {
372 float sum = -1.0f * GetValue<float>(x);
373 return MakeValue(sum);
374 }
375
376 MS_EXCEPTION(NotSupportError) << "Not support ScalarUSub [x:" << x->ToString() << "].";
377 }
378
ScalarLog(const ValuePtrList & list)379 ValuePtr ScalarLog(const ValuePtrList &list) {
380 if (list.size() != 1) {
381 MS_EXCEPTION(NotSupportError) << "Input number of ScalarLog must be 1, but got " << list.size();
382 }
383 ValuePtr x = list[0];
384 MS_EXCEPTION_IF_NULL(x);
385
386 if (x->isa<FP64Imm>()) {
387 double v = log(GetValue<double>(x));
388 return MakeValue(v);
389 }
390 if (x->isa<FP32Imm>()) {
391 auto v = static_cast<float>(log(GetValue<float>(x)));
392 return MakeValue(v);
393 }
394
395 MS_EXCEPTION(NotSupportError) << "Not support ScalarLog [x:" << x->ToString() << "].";
396 }
397
BoolNot(const ValuePtrList & list)398 ValuePtr BoolNot(const ValuePtrList &list) {
399 if (list.size() != 1) {
400 MS_EXCEPTION(NotSupportError) << "Input number of BoolNot must be 1, but got " << list.size();
401 }
402 ValuePtr x = list[0];
403 MS_EXCEPTION_IF_NULL(x);
404 bool convert = false;
405
406 if (ValueToBool(x, &convert)) {
407 auto res = !convert;
408 return MakeValue(res);
409 }
410
411 MS_EXCEPTION(NotSupportError) << "Not support BoolNot [x:" << x->ToString() << "].";
412 }
413
BoolAnd(const ValuePtrList & list)414 ValuePtr BoolAnd(const ValuePtrList &list) {
415 if (list.size() != 2) {
416 MS_EXCEPTION(NotSupportError) << "Input number of BoolAnd must be 2, but got " << list.size();
417 }
418 ValuePtr x = list[0];
419 ValuePtr y = list[1];
420 MS_EXCEPTION_IF_NULL(x);
421 MS_EXCEPTION_IF_NULL(y);
422 bool x_b = false;
423 bool y_b = false;
424
425 if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
426 auto res = x_b && y_b;
427 return MakeValue(res);
428 }
429
430 MS_EXCEPTION(NotSupportError) << "Not support [x:" << x->ToString() << "] BoolAnd [y:" << y->ToString();
431 }
432
BoolOr(const ValuePtrList & list)433 ValuePtr BoolOr(const ValuePtrList &list) {
434 if (list.size() != 2) {
435 MS_EXCEPTION(NotSupportError) << "Input number of BoolOr must be 2, but got " << list.size();
436 }
437 ValuePtr x = list[0];
438 ValuePtr y = list[1];
439 MS_EXCEPTION_IF_NULL(x);
440 MS_EXCEPTION_IF_NULL(y);
441 bool x_b = false;
442 bool y_b = false;
443
444 if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
445 auto res = x_b || y_b;
446 return MakeValue(res);
447 }
448
449 MS_EXCEPTION(NotSupportError) << "Not support [x:" << x->ToString() << "] BoolOr [y:" << y->ToString() << "].";
450 }
451
BoolEq(const ValuePtrList & list)452 ValuePtr BoolEq(const ValuePtrList &list) {
453 if (list.size() != 2) {
454 MS_EXCEPTION(NotSupportError) << "Input number of BoolEq must be 2, but got " << list.size();
455 }
456 ValuePtr x = list[0];
457 ValuePtr y = list[1];
458 MS_EXCEPTION_IF_NULL(x);
459 MS_EXCEPTION_IF_NULL(y);
460 bool x_b = false;
461 bool y_b = false;
462
463 if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
464 auto res = x_b == y_b;
465 return MakeValue(res);
466 }
467
468 MS_EXCEPTION(NotSupportError) << "Not support [x:" << x->ToString() << "] BoolEq [y:" << y->ToString() << "].";
469 }
470 } // namespace prim
471 } // namespace mindspore
472