• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/operator/cc_implementations.h"
18 #include <limits>
19 #include <algorithm>
20 #include <cmath>
21 #include <cfloat>
22 #include "utils/log_adapter.h"
23 #include "utils/convert_utils.h"
24 #include "utils/ms_utils.h"
25 
26 namespace mindspore {
27 // namespace to support primitive operators definition
28 namespace prim {
// Scalar type categories reported by InferType, listed from narrowest integer to kUnknown.
enum class DataType {
  kInt,
  kInt64,
  kFloat,
  kDouble,
  kUnknown,
};
30 
31 // Whether has a T type data in AnyPtrList.
32 template <class T>
HasType(const AnyPtrList & list)33 bool HasType(const AnyPtrList &list) {
34   bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); });
35   return ret;
36 }
37 
InferType(const AnyPtrList & list)38 DataType InferType(const AnyPtrList &list) {
39   if (HasType<double>(list)) {
40     return DataType::kDouble;
41   } else if (HasType<float>(list)) {
42     return DataType::kFloat;
43   } else if (HasType<int64_t>(list)) {
44     return DataType::kInt64;
45   } else if (HasType<int>(list)) {
46     return DataType::kInt;
47   }
48   return DataType::kUnknown;
49 }
50 
// Returns true if x + y would leave the range [min, max]. The bounds are shifted by y
// (max - y for positive y, min - y for negative y), which can never overflow themselves.
template <typename T>
bool IsAddOverflow(const T &x, const T &y, const T &max, const T &min) {
  if (y > 0) {
    return x > max - y;
  }
  if (y < 0) {
    return x < min - y;
  }
  return false;
}
55 
// Returns true if x - y would leave the range [min, max]. The bounds are shifted by y
// (max + y for negative y, min + y for positive y), which can never overflow themselves.
template <typename T>
bool IsSubOverflow(const T &x, const T &y, const T &max, const T &min) {
  if (y < 0) {
    return x > max + y;
  }
  if (y > 0) {
    return x < min + y;
  }
  return false;
}
60 
// Returns true if x * y would leave the range [min, max].
// Every bound is derived with a division whose own result is representable. In particular,
// `min / -1` is never evaluated: that quotient is itself signed-integer overflow (undefined
// behaviour, SIGFPE on x86), and it used to be hit for any positive x times y == -1.
template <typename T>
bool IsMulOverflow(const T &x, const T &y, const T &max, const T &min) {
  if (x > 0 && y > 0) {
    return x > max / y;
  }
  if (x < 0 && y < 0) {
    // Positive product; y is negative, so max / y is always representable.
    return x < max / y;
  }
  if (x > 0 && y < 0) {
    // Negative product; divide by the positive operand so the quotient cannot overflow.
    return y < min / x;
  }
  if (x < 0 && y > 0) {
    // Negative product; y is positive, so min / y is always representable.
    return x < min / y;
  }
  // Otherwise at least one operand is zero and the product is zero.
  return false;
}
66 
// Signed integer division (and remainder) can only overflow for `min / -1`.
template <typename T>
bool IsDivOverflow(const T &x, const T &y, const T &min) {
  if (x != min) {
    return false;
  }
  return static_cast<int64_t>(y) == -1;
}
71 
// Arithmetic operations that IsSignedIntOverflow knows how to range-check.
enum class OpType {
  ADD,
  SUB,
  MUL,
  DIV,
  MOD,
};
73 
74 template <typename T>
IsSignedIntOverflow(T x,T y,OpType opType)75 bool IsSignedIntOverflow(T x, T y, OpType opType) {
76   auto max = std::numeric_limits<T>::max();
77   auto min = std::numeric_limits<T>::min();
78 
79   if (opType == OpType::ADD) {
80     return IsAddOverflow<T>(x, y, max, min);
81   }
82 
83   if (opType == OpType::SUB) {
84     return IsSubOverflow<T>(x, y, max, min);
85   }
86 
87   if (opType == OpType::MUL) {
88     return IsMulOverflow<T>(x, y, max, min);
89   }
90 
91   if (opType == OpType::DIV || opType == OpType::MOD) {
92     return IsDivOverflow<T>(x, y, min);
93   }
94 
95   MS_EXCEPTION(NotSupportError) << "Unsupported operation type.";
96 }
97 
98 template <typename T>
InnerScalarAdd(T x,T y)99 T InnerScalarAdd(T x, T y) {
100   if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::ADD)) {
101     MS_EXCEPTION(ValueError) << "Overflow of the sum of two signed number x: " << std::to_string(x)
102                              << ", y: " << std::to_string(y) << ".";
103   }
104   return x + y;
105 }
106 
107 template <typename T>
InnerScalarSub(T x,T y)108 T InnerScalarSub(T x, T y) {
109   if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::SUB)) {
110     MS_EXCEPTION(ValueError) << "Overflow of the sub of two signed number x: " << std::to_string(x)
111                              << ", y: " << std::to_string(y) << ".";
112   }
113   return x - y;
114 }
115 
116 template <typename T>
InnerScalarMul(T x,T y)117 T InnerScalarMul(T x, T y) {
118   if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MUL)) {
119     MS_EXCEPTION(ValueError) << "Overflow of the mul of two signed number x: " << std::to_string(x)
120                              << ", y: " << std::to_string(y) << ".";
121   }
122   return x * y;
123 }
124 
125 template <typename T>
InnerScalarDiv(T x,T y)126 float InnerScalarDiv(T x, T y) {
127   if (y == 0) {
128     MS_EXCEPTION(ValueError) << "Divisor could not be zero";
129   }
130   if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::DIV)) {
131     MS_EXCEPTION(ValueError) << "Overflow of the div of two signed number x: " << std::to_string(x)
132                              << ", y: " << std::to_string(y) << ".";
133   }
134   return static_cast<float>(x) / static_cast<float>(y);
135 }
136 
137 template <typename T>
InnerScalarFloordiv(T x,T y)138 T InnerScalarFloordiv(T x, T y) {
139   auto ret = std::floor(InnerScalarDiv(x, y));
140   if (std::is_integral<T>::value) {
141     return static_cast<int64_t>(ret);
142   }
143   return ret;
144 }
145 
146 template <typename T>
InnerScalarMod(T x,T y)147 T InnerScalarMod(T x, T y) {
148   if (y == 0) {
149     MS_EXCEPTION(ValueError) << "Could not mod to zero.";
150   }
151   if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MOD)) {
152     MS_EXCEPTION(ValueError) << "Overflow of the mod of two signed number x: " << std::to_string(x)
153                              << ", y: " << std::to_string(y) << ".";
154   }
155   if (std::is_integral<T>::value) {
156     return static_cast<int64_t>(x) % static_cast<int64_t>(y);
157   }
158   return x - y * std::floor(x / y);
159 }
160 
// Raises x to the power y with std::pow and converts the result back to the type of x.
template <typename T, typename U>
T InnerScalarPow(T x, U y) {
  const auto power = std::pow(x, y);
  return static_cast<T>(power);
}
165 
// Scalar equality used by ScalarEq and, through InnerScalarNe, by ScalarNe.
//  - Two integer operands are compared exactly after widening to int64_t. Distinct int64 values
//    that round to the same double (e.g. 2^60 and 2^60 + 1) are therefore not reported equal.
//  - Otherwise both operands are widened to double. Values that compare equal match, including
//    equal infinities: inf - inf is NaN, so the tolerance test alone reported inf != inf.
//    Other values match when their absolute difference is below DBL_EPSILON. NaN matches nothing.
template <typename T, typename U>
bool InnerScalarEq(T x, U y) {
  if (std::is_integral<T>::value && std::is_integral<U>::value) {
    return static_cast<int64_t>(x) == static_cast<int64_t>(y);
  }
  const double lhs = static_cast<double>(x);
  const double rhs = static_cast<double>(y);
  if (lhs == rhs) {
    return true;
  }
  return std::fabs(lhs - rhs) < DBL_EPSILON;
}
172 
// Strict less-than under the usual arithmetic conversions (false whenever a NaN is involved).
template <typename T, typename U>
bool InnerScalarLt(T x, U y) {
  return y > x;
}
177 
// Strict greater-than under the usual arithmetic conversions (false whenever a NaN is involved).
template <typename T, typename U>
bool InnerScalarGt(T x, U y) {
  return y < x;
}
182 
// Inequality, defined as the exact negation of InnerScalarEq so that the two never disagree.
template <typename T, typename U>
bool InnerScalarNe(T x, U y) {
  const bool same = InnerScalarEq(x, y);
  return !same;
}
187 
// Less-than-or-equal under the usual arithmetic conversions (false whenever a NaN is involved).
template <typename T, typename U>
bool InnerScalarLe(T x, U y) {
  return y >= x;
}
192 
// Greater-than-or-equal under the usual arithmetic conversions (false whenever a NaN is involved).
template <typename T, typename U>
bool InnerScalarGe(T x, U y) {
  return y <= x;
}
197 
198 #define SCALAR_OP(op_t)                                                                                                \
199   ValuePtr Scalar##op_t(const ValuePtrList &list) {                                                                    \
200     do {                                                                                                               \
201       if (list.size() != 2) {                                                                                          \
202         MS_EXCEPTION(NotSupportError) << "Input number of Scalar" << #op_t << " should be 2, but got " << list.size(); \
203       }                                                                                                                \
204       ValuePtr x = list[0];                                                                                            \
205       ValuePtr y = list[1];                                                                                            \
206       MS_EXCEPTION_IF_NULL(x);                                                                                         \
207       MS_EXCEPTION_IF_NULL(y);                                                                                         \
208       if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) {                                                                    \
209         double sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y));                                      \
210         return MakeValue(sum);                                                                                         \
211       }                                                                                                                \
212       if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) {                                                                    \
213         float sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y));                                         \
214         return MakeValue(sum);                                                                                         \
215       }                                                                                                                \
216       if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) {                                                                  \
217         int sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y));                                               \
218         return MakeValue(sum);                                                                                         \
219       }                                                                                                                \
220       if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) {                                                                   \
221         float sum = InnerScalar##op_t(IntToFloat(GetValue<int>(x)), GetValue<float>(y));                               \
222         return MakeValue(sum);                                                                                         \
223       }                                                                                                                \
224       if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) {                                                                   \
225         float sum = InnerScalar##op_t(GetValue<float>(x), IntToFloat(GetValue<int>(y)));                               \
226         return MakeValue(sum);                                                                                         \
227       }                                                                                                                \
228       if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) {                                                                  \
229         int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y));                                   \
230         return MakeValue(sum);                                                                                         \
231       }                                                                                                                \
232       if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) {                                                                   \
233         double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), GetValue<double>(y));                       \
234         return MakeValue(sum);                                                                                         \
235       }                                                                                                                \
236       if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) {                                                                   \
237         double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y)));         \
238         return MakeValue(sum);                                                                                         \
239       }                                                                                                                \
240       if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) {                                                                  \
241         int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y)));                            \
242         return MakeValue(sum);                                                                                         \
243       }                                                                                                                \
244       if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) {                                                                   \
245         double sum = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y)));         \
246         return MakeValue(sum);                                                                                         \
247       }                                                                                                                \
248       if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) {                                                                   \
249         double sum = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y)));                       \
250         return MakeValue(sum);                                                                                         \
251       }                                                                                                                \
252       if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) {                                                                  \
253         int64_t sum = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y));                            \
254         return MakeValue(sum);                                                                                         \
255       }                                                                                                                \
256       MS_EXCEPTION(TypeError) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name()      \
257                               << ", value of x:" << x->ToString() << ", type of y:" << y->type_name()                  \
258                               << ", value of y:" << y->ToString();                                                     \
259     } while (0);                                                                                                       \
260   }
261 
262 SCALAR_OP(Add)
SCALAR_OP(Sub)263 SCALAR_OP(Sub)
264 SCALAR_OP(Mul)
265 SCALAR_OP(Div)
266 SCALAR_OP(Mod)
267 SCALAR_OP(Pow)
268 SCALAR_OP(Floordiv)
269 
270 #define LOGIC_OP(op_t)                                                                                               \
271   ValuePtr Scalar##op_t(const ValuePtrList &list) {                                                                  \
272     if (list.size() != 2) {                                                                                          \
273       MS_EXCEPTION(NotSupportError) << "Input number of Scalar" << #op_t << " should be 2, but got " << list.size(); \
274     }                                                                                                                \
275     ValuePtr x = list[0];                                                                                            \
276     ValuePtr y = list[1];                                                                                            \
277     MS_EXCEPTION_IF_NULL(x);                                                                                         \
278     MS_EXCEPTION_IF_NULL(y);                                                                                         \
279     if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) {                                                                    \
280       bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y));                                        \
281       return MakeValue(sum);                                                                                         \
282     }                                                                                                                \
283     if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) {                                                                    \
284       bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y));                                          \
285       return MakeValue(sum);                                                                                         \
286     }                                                                                                                \
287     if (x->isa<FP64Imm>() && y->isa<FP32Imm>()) {                                                                    \
288       bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<float>(y));                                         \
289       return MakeValue(sum);                                                                                         \
290     }                                                                                                                \
291     if (x->isa<FP32Imm>() && y->isa<FP64Imm>()) {                                                                    \
292       bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<double>(y));                                         \
293       return MakeValue(sum);                                                                                         \
294     }                                                                                                                \
295     if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) {                                                                  \
296       bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y));                                              \
297       return MakeValue(sum);                                                                                         \
298     }                                                                                                                \
299     if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) {                                                                   \
300       bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y));                                            \
301       return MakeValue(sum);                                                                                         \
302     }                                                                                                                \
303     if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) {                                                                   \
304       bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int64_t>(y));                                        \
305       return MakeValue(sum);                                                                                         \
306     }                                                                                                                \
307     if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) {                                                                   \
308       bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<float>(y));                                            \
309       return MakeValue(sum);                                                                                         \
310     }                                                                                                                \
311     if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) {                                                                   \
312       bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<float>(y));                                        \
313       return MakeValue(sum);                                                                                         \
314     }                                                                                                                \
315     if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) {                                                                  \
316       bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y));                                      \
317       return MakeValue(sum);                                                                                         \
318     }                                                                                                                \
319     if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) {                                                                   \
320       bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<int64_t>(y));                                       \
321       return MakeValue(sum);                                                                                         \
322     }                                                                                                                \
323     if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) {                                                                   \
324       bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<double>(y));                                       \
325       return MakeValue(sum);                                                                                         \
326     }                                                                                                                \
327     if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) {                                                                  \
328       bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y));                                          \
329       return MakeValue(sum);                                                                                         \
330     }                                                                                                                \
331     if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) {                                                                  \
332       bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int64_t>(y));                                          \
333       return MakeValue(sum);                                                                                         \
334     }                                                                                                                \
335     MS_EXCEPTION(TypeError) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name()      \
336                             << ", value of x:" << x->ToString() << ", type of y:" << y->type_name()                  \
337                             << ", value of y:" << y->ToString();                                                     \
338   }
339 
340 LOGIC_OP(Eq)
341 LOGIC_OP(Lt)
342 LOGIC_OP(Gt)
343 LOGIC_OP(Ne)
344 LOGIC_OP(Le)
345 LOGIC_OP(Ge)
346 
347 ValuePtr ScalarUAdd(const ValuePtrList &list) {
348   if (list.size() != 1) {
349     MS_EXCEPTION(NotSupportError) << "Input number of ScalarUAdd should be 1, but got " << list.size();
350   }
351   ValuePtr x = list[0];
352   MS_EXCEPTION_IF_NULL(x);
353   return x;
354 }
355 
ScalarUSub(const ValuePtrList & list)356 ValuePtr ScalarUSub(const ValuePtrList &list) {
357   if (list.size() != 1) {
358     MS_EXCEPTION(NotSupportError) << "Input number of ScalarUSub should be 1, but got " << list.size();
359   }
360   ValuePtr x = list[0];
361   MS_EXCEPTION_IF_NULL(x);
362 
363   if (x->isa<Int32Imm>()) {
364     int32_t sum = -1 * GetValue<int32_t>(x);
365     return MakeValue(sum);
366   }
367   if (x->isa<Int64Imm>()) {
368     int64_t sum = -1 * GetValue<int64_t>(x);
369     return MakeValue(sum);
370   }
371   if (x->isa<FP32Imm>()) {
372     float sum = -1.0f * GetValue<float>(x);
373     return MakeValue(sum);
374   }
375 
376   MS_EXCEPTION(NotSupportError) << "Not support ScalarUSub [x:" << x->ToString() << "].";
377 }
378 
ScalarLog(const ValuePtrList & list)379 ValuePtr ScalarLog(const ValuePtrList &list) {
380   if (list.size() != 1) {
381     MS_EXCEPTION(NotSupportError) << "Input number of ScalarLog must be 1, but got " << list.size();
382   }
383   ValuePtr x = list[0];
384   MS_EXCEPTION_IF_NULL(x);
385 
386   if (x->isa<FP64Imm>()) {
387     double v = log(GetValue<double>(x));
388     return MakeValue(v);
389   }
390   if (x->isa<FP32Imm>()) {
391     auto v = static_cast<float>(log(GetValue<float>(x)));
392     return MakeValue(v);
393   }
394 
395   MS_EXCEPTION(NotSupportError) << "Not support ScalarLog [x:" << x->ToString() << "].";
396 }
397 
BoolNot(const ValuePtrList & list)398 ValuePtr BoolNot(const ValuePtrList &list) {
399   if (list.size() != 1) {
400     MS_EXCEPTION(NotSupportError) << "Input number of BoolNot must be 1, but got " << list.size();
401   }
402   ValuePtr x = list[0];
403   MS_EXCEPTION_IF_NULL(x);
404   bool convert = false;
405 
406   if (ValueToBool(x, &convert)) {
407     auto res = !convert;
408     return MakeValue(res);
409   }
410 
411   MS_EXCEPTION(NotSupportError) << "Not support BoolNot [x:" << x->ToString() << "].";
412 }
413 
BoolAnd(const ValuePtrList & list)414 ValuePtr BoolAnd(const ValuePtrList &list) {
415   if (list.size() != 2) {
416     MS_EXCEPTION(NotSupportError) << "Input number of BoolAnd must be 2, but got " << list.size();
417   }
418   ValuePtr x = list[0];
419   ValuePtr y = list[1];
420   MS_EXCEPTION_IF_NULL(x);
421   MS_EXCEPTION_IF_NULL(y);
422   bool x_b = false;
423   bool y_b = false;
424 
425   if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
426     auto res = x_b && y_b;
427     return MakeValue(res);
428   }
429 
430   MS_EXCEPTION(NotSupportError) << "Not support [x:" << x->ToString() << "] BoolAnd [y:" << y->ToString();
431 }
432 
BoolOr(const ValuePtrList & list)433 ValuePtr BoolOr(const ValuePtrList &list) {
434   if (list.size() != 2) {
435     MS_EXCEPTION(NotSupportError) << "Input number of BoolOr must be 2, but got " << list.size();
436   }
437   ValuePtr x = list[0];
438   ValuePtr y = list[1];
439   MS_EXCEPTION_IF_NULL(x);
440   MS_EXCEPTION_IF_NULL(y);
441   bool x_b = false;
442   bool y_b = false;
443 
444   if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
445     auto res = x_b || y_b;
446     return MakeValue(res);
447   }
448 
449   MS_EXCEPTION(NotSupportError) << "Not support [x:" << x->ToString() << "] BoolOr [y:" << y->ToString() << "].";
450 }
451 
BoolEq(const ValuePtrList & list)452 ValuePtr BoolEq(const ValuePtrList &list) {
453   if (list.size() != 2) {
454     MS_EXCEPTION(NotSupportError) << "Input number of BoolEq must be 2, but got " << list.size();
455   }
456   ValuePtr x = list[0];
457   ValuePtr y = list[1];
458   MS_EXCEPTION_IF_NULL(x);
459   MS_EXCEPTION_IF_NULL(y);
460   bool x_b = false;
461   bool y_b = false;
462 
463   if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
464     auto res = x_b == y_b;
465     return MakeValue(res);
466   }
467 
468   MS_EXCEPTION(NotSupportError) << "Not support [x:" << x->ToString() << "] BoolEq [y:" << y->ToString() << "].";
469 }
470 }  // namespace prim
471 }  // namespace mindspore
472