• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/operator/cc_implementations.h"
18 #include <limits>
19 #include <algorithm>
20 #include <cmath>
21 #include <string>
22 #include <cfloat>
23 #include <memory>
24 #include <type_traits>
25 
26 #include "utils/log_adapter.h"
27 #include "ir/scalar.h"
28 #include "ir/value.h"
29 #include "utils/convert_utils_base.h"
30 
31 namespace mindspore {
32 // namespace to support primitive operators definition
33 namespace prim {
// Scalar kinds that InferType can report for an AnyPtrList.
enum class DataType {
  kInt,
  kInt64,
  kFloat,
  kDouble,
  kUnknown  // none of the checked scalar types is present
};
35 
36 // Whether has a T type data in AnyPtrList.
37 template <class T>
HasType(const AnyPtrList & list)38 bool HasType(const AnyPtrList &list) {
39   bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); });
40   return ret;
41 }
42 
InferType(const AnyPtrList & list)43 DataType InferType(const AnyPtrList &list) {
44   if (HasType<double>(list)) {
45     return DataType::kDouble;
46   } else if (HasType<float>(list)) {
47     return DataType::kFloat;
48   } else if (HasType<int64_t>(list)) {
49     return DataType::kInt64;
50   } else if (HasType<int>(list)) {
51     return DataType::kInt;
52   }
53   return DataType::kUnknown;
54 }
55 
56 template <typename T>
InnerScalarAdd(T x,T y)57 T InnerScalarAdd(T x, T y) {
58 #ifndef _MSC_VER
59   if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
60     T res;
61     if (__builtin_add_overflow(x, y, &res)) {
62       MS_EXCEPTION(ValueError) << "Overflow of the sum of two signed number x: " << std::to_string(x)
63                                << ", y: " << std::to_string(y) << ".";
64     }
65     return res;
66   }
67 #endif
68   return x + y;
69 }
70 
71 template <typename T>
InnerScalarSub(T x,T y)72 T InnerScalarSub(T x, T y) {
73 #ifndef _MSC_VER
74   if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
75     T res;
76     if (__builtin_sub_overflow(x, y, &res)) {
77       MS_EXCEPTION(ValueError) << "Overflow of the sub of two signed number x: " << std::to_string(x)
78                                << ", y: " << std::to_string(y) << ".";
79     }
80     return res;
81   }
82 #endif
83   return x - y;
84 }
85 
86 template <typename T>
InnerScalarMul(T x,T y)87 T InnerScalarMul(T x, T y) {
88 #ifndef _MSC_VER
89   if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
90     T res;
91     if (__builtin_mul_overflow(x, y, &res)) {
92       MS_EXCEPTION(ValueError) << "Overflow of the mul of two signed number x: " << std::to_string(x)
93                                << ", y: " << std::to_string(y) << ".";
94     }
95     return res;
96   }
97 #endif
98   return x * y;
99 }
100 
101 template <typename T>
InnerScalarDiv(T x,T y)102 float InnerScalarDiv(T x, T y) {
103   if (y == 0) {
104     MS_EXCEPTION(ValueError) << "The divisor could not be zero. But the divisor is zero now.";
105   }
106   if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
107     if (x == std::numeric_limits<T>::min() && static_cast<int64_t>(y) == -1) {
108       MS_EXCEPTION(ValueError) << "Overflow of the div of two signed number x: " << std::to_string(x)
109                                << ", y: " << std::to_string(y) << ".";
110     }
111   }
112   return static_cast<float>(x) / static_cast<float>(y);
113 }
114 
115 template <typename T>
InnerScalarFloorDiv(T x,T y)116 T InnerScalarFloorDiv(T x, T y) {
117   auto ret = std::floor(InnerScalarDiv(x, y));
118   return static_cast<T>(ret);
119 }
120 
121 template <typename T>
InnerScalarMod(T x,T y)122 T InnerScalarMod(T x, T y) {
123   if (y == 0) {
124     MS_EXCEPTION(ValueError) << "Cannot perform modulo operation on zero.";
125   }
126   if constexpr (!std::is_integral<T>::value) {
127     return x - y * std::floor(x / y);
128   }
129   if constexpr (std::is_signed<T>::value) {
130     if (x == std::numeric_limits<T>::min() && static_cast<int64_t>(y) == -1) {
131       MS_EXCEPTION(ValueError) << "Overflow of the mod of two signed number x: " << std::to_string(x)
132                                << ", y: " << std::to_string(y) << ".";
133     }
134   }
135   return static_cast<int64_t>(x) % static_cast<int64_t>(y);
136 }
137 
138 template <typename T, typename U>
InnerScalarPow(T x,U y)139 T InnerScalarPow(T x, U y) {
140   return std::pow(x, y);
141 }
142 
143 template <typename T, typename U>
InnerScalarEq(T x,U y)144 bool InnerScalarEq(T x, U y) {
145   if (std::isinf(static_cast<double>(x)) && std::isinf(static_cast<double>(y))) {
146     return (x > 0 && y > 0) || (x < 0 && y < 0);
147   }
148   double error = static_cast<double>(x) - static_cast<double>(y);
149   error = fabs(error);
150   return error < DBL_EPSILON;
151 }
152 
// Strict "less than" on two scalars of possibly different arithmetic types,
// using the built-in comparison (and its usual arithmetic conversions).
template <typename T, typename U>
bool InnerScalarLt(T x, U y) {
  return y > x;
}
157 
// Strict "greater than" on two scalars of possibly different arithmetic types,
// using the built-in comparison (and its usual arithmetic conversions).
template <typename T, typename U>
bool InnerScalarGt(T x, U y) {
  return y < x;
}
162 
// Scalar inequality: the logical negation of InnerScalarEq, so it shares that
// function's tolerance and infinity handling.
template <typename T, typename U>
bool InnerScalarNe(T x, U y) {
  const bool equal = InnerScalarEq(x, y);
  return !equal;
}
167 
// "Less than or equal" on two scalars of possibly different arithmetic types,
// using the built-in comparison (and its usual arithmetic conversions).
template <typename T, typename U>
bool InnerScalarLe(T x, U y) {
  return y >= x;
}
172 
// "Greater than or equal" on two scalars of possibly different arithmetic
// types, using the built-in comparison (and its usual arithmetic conversions).
template <typename T, typename U>
bool InnerScalarGe(T x, U y) {
  return y <= x;
}
177 
178 #define SCALAR_OP(op_t)                                                                                         \
179   ValuePtr Scalar##op_t(const ValuePtrList &list) {                                                             \
180     constexpr size_t scalar_input_size = 2;                                                                     \
181     if (list.size() != scalar_input_size) {                                                                     \
182       MS_EXCEPTION(NotSupportError) << "Input number of Scalar" << #op_t << " should be " << scalar_input_size  \
183                                     << ", but got " << list.size();                                             \
184     }                                                                                                           \
185     const ValuePtr &x = list[0];                                                                                \
186     const ValuePtr &y = list[1];                                                                                \
187     MS_EXCEPTION_IF_NULL(x);                                                                                    \
188     MS_EXCEPTION_IF_NULL(y);                                                                                    \
189     if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) {                                                               \
190       float sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y));                                    \
191       return MakeValue(sum);                                                                                    \
192     }                                                                                                           \
193     if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) {                                                             \
194       int sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y));                                          \
195       return MakeValue(sum);                                                                                    \
196     }                                                                                                           \
197     if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) {                                                              \
198       float sum = InnerScalar##op_t(IntToFloat(GetValue<int>(x)), GetValue<float>(y));                          \
199       return MakeValue(sum);                                                                                    \
200     }                                                                                                           \
201     if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) {                                                              \
202       float sum = InnerScalar##op_t(GetValue<float>(x), IntToFloat(GetValue<int>(y)));                          \
203       return MakeValue(sum);                                                                                    \
204     }                                                                                                           \
205     if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) {                                                             \
206       int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y));                              \
207       return MakeValue(sum);                                                                                    \
208     }                                                                                                           \
209     if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) {                                                              \
210       float sum = InnerScalar##op_t(LongToFloat(GetValue<int64_t>(x)), GetValue<float>(y));                     \
211       return MakeValue(sum);                                                                                    \
212     }                                                                                                           \
213     if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) {                                                             \
214       int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y)));                       \
215       return MakeValue(sum);                                                                                    \
216     }                                                                                                           \
217     if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) {                                                              \
218       float sum = InnerScalar##op_t(GetValue<float>(x), LongToFloat(GetValue<int64_t>(y)));                     \
219       return MakeValue(sum);                                                                                    \
220     }                                                                                                           \
221     if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) {                                                             \
222       int64_t sum = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y));                       \
223       return MakeValue(sum);                                                                                    \
224     }                                                                                                           \
225     if (x->isa<BoolImm>() && y->isa<BoolImm>()) {                                                               \
226       int sum = InnerScalar##op_t(static_cast<int>(GetValue<bool>(x)), static_cast<int>(GetValue<bool>(y)));    \
227       return MakeValue(sum);                                                                                    \
228     }                                                                                                           \
229     MS_EXCEPTION(TypeError) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name() \
230                             << ", value of x:" << x->ToString() << ", type of y:" << y->type_name()             \
231                             << ", value of y:" << y->ToString();                                                \
232   }
233 
234 SCALAR_OP(Add)
SCALAR_OP(Sub)235 SCALAR_OP(Sub)
236 SCALAR_OP(Mul)
237 SCALAR_OP(Div)
238 SCALAR_OP(Mod)
239 SCALAR_OP(Pow)
240 SCALAR_OP(FloorDiv)
241 
242 #define LOGIC_OP(op_t)                                                                                          \
243   ValuePtr Scalar##op_t(const ValuePtrList &list) {                                                             \
244     constexpr size_t scalar_input_size = 2;                                                                     \
245     if (list.size() != scalar_input_size) {                                                                     \
246       MS_EXCEPTION(NotSupportError) << "Input number of Scalar" << #op_t << " should be " << scalar_input_size  \
247                                     << ", but got " << list.size();                                             \
248     }                                                                                                           \
249     const ValuePtr &x = list[0];                                                                                \
250     const ValuePtr &y = list[1];                                                                                \
251     MS_EXCEPTION_IF_NULL(x);                                                                                    \
252     MS_EXCEPTION_IF_NULL(y);                                                                                    \
253     if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) {                                                               \
254       bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y));                                     \
255       return MakeValue(sum);                                                                                    \
256     }                                                                                                           \
257     if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) {                                                             \
258       bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y));                                         \
259       return MakeValue(sum);                                                                                    \
260     }                                                                                                           \
261     if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) {                                                              \
262       bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y));                                       \
263       return MakeValue(sum);                                                                                    \
264     }                                                                                                           \
265     if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) {                                                              \
266       bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int64_t>(y));                                   \
267       return MakeValue(sum);                                                                                    \
268     }                                                                                                           \
269     if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) {                                                              \
270       bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<float>(y));                                       \
271       return MakeValue(sum);                                                                                    \
272     }                                                                                                           \
273     if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) {                                                              \
274       bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<float>(y));                                   \
275       return MakeValue(sum);                                                                                    \
276     }                                                                                                           \
277     if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) {                                                             \
278       bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y));                                 \
279       return MakeValue(sum);                                                                                    \
280     }                                                                                                           \
281     if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) {                                                             \
282       bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y));                                     \
283       return MakeValue(sum);                                                                                    \
284     }                                                                                                           \
285     if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) {                                                             \
286       bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int64_t>(y));                                     \
287       return MakeValue(sum);                                                                                    \
288     }                                                                                                           \
289     MS_EXCEPTION(TypeError) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name() \
290                             << ", value of x:" << x->ToString() << ", type of y:" << y->type_name()             \
291                             << ", value of y:" << y->ToString();                                                \
292   }
293 
294 LOGIC_OP(Eq)
295 LOGIC_OP(Lt)
296 LOGIC_OP(Gt)
297 LOGIC_OP(Ne)
298 LOGIC_OP(Le)
299 LOGIC_OP(Ge)
300 
301 template <typename T>
302 T InnerBitAnd(T x, T y) {
303   return x & y;
304 }
305 
306 template <typename T>
InnerBitOr(T x,T y)307 T InnerBitOr(T x, T y) {
308   return x | y;
309 }
310 
311 template <typename T>
InnerBitXor(T x,T y)312 T InnerBitXor(T x, T y) {
313   return x ^ y;
314 }
315 
316 template <typename T>
InnerBitLeftShift(T x,T y)317 T InnerBitLeftShift(T x, T y) {
318   if (y < 0) {
319     MS_EXCEPTION(ValueError) << "For shift operator, shift count must be a non-negative integer.";
320   }
321 #ifndef _MSC_VER
322   if (x == 0) {
323     return x;
324   }
325   if (x < 0) {
326     if (x == -1) {
327       constexpr T max_bit_count = 64;
328       if (y == max_bit_count - 1) {
329         return std::numeric_limits<T>::min();
330       }
331     }
332     if (x == std::numeric_limits<T>::min() || static_cast<T>(__builtin_clzll(static_cast<uint64_t>(-x))) <= y) {
333       MS_EXCEPTION(RuntimeError) << "Arithmetic left shift causes int64 integer overflow.";
334     }
335   } else if (static_cast<T>(__builtin_clzll(static_cast<uint64_t>(x))) <= y) {
336     MS_EXCEPTION(RuntimeError) << "Arithmetic left shift causes int64 integer overflow.";
337   }
338 #endif
339   return x << y;
340 }
341 
342 template <typename T>
InnerBitRightShift(T x,T y)343 T InnerBitRightShift(T x, T y) {
344   if (y < 0) {
345     MS_EXCEPTION(ValueError) << "For shift operator, shift count must be a non-negative integer.";
346   }
347   return x >> y;
348 }
349 
350 #define BIT_OP(op_t)                                                                                      \
351   ValuePtr Bit##op_t(const ValuePtrList &list) {                                                          \
352     constexpr size_t bit_input_size = 2;                                                                  \
353     if (list.size() != bit_input_size) {                                                                  \
354       MS_EXCEPTION(NotSupportError) << "Input number of Bit" << #op_t << " should be" << bit_input_size   \
355                                     << ", but got " << list.size();                                       \
356     }                                                                                                     \
357     const ValuePtr &x = list[0];                                                                          \
358     const ValuePtr &y = list[1];                                                                          \
359     MS_EXCEPTION_IF_NULL(x);                                                                              \
360     MS_EXCEPTION_IF_NULL(y);                                                                              \
361     if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) {                                                       \
362       int32_t res = InnerBit##op_t(IntToLong(GetValue<int>(x)), IntToLong(GetValue<int>(y)));             \
363       return MakeValue(res);                                                                              \
364     }                                                                                                     \
365     if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) {                                                       \
366       int64_t res = InnerBit##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y)));                    \
367       return MakeValue(res);                                                                              \
368     }                                                                                                     \
369     if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) {                                                       \
370       int64_t res = InnerBit##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y));                    \
371       return MakeValue(res);                                                                              \
372     }                                                                                                     \
373     if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) {                                                       \
374       int64_t res = InnerBit##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y));                           \
375       return MakeValue(res);                                                                              \
376     }                                                                                                     \
377     MS_EXCEPTION(TypeError) << "Unsupported input type. For Bit" << #op_t                                 \
378                             << ", only integer types are supported, but got type of x:" << x->type_name() \
379                             << ", value of x:" << x->ToString() << ", type of y:" << y->type_name()       \
380                             << ", value of y:" << y->ToString();                                          \
381   }
382 
383 BIT_OP(And)
BIT_OP(Or)384 BIT_OP(Or)
385 BIT_OP(Xor)
386 BIT_OP(LeftShift)
387 BIT_OP(RightShift)
388 
389 ValuePtr ScalarUAdd(const ValuePtrList &list) {
390   constexpr size_t scalar_input_size = 1;
391   if (list.size() != scalar_input_size) {
392     MS_EXCEPTION(NotSupportError) << "Input number of ScalarUAdd should be " << scalar_input_size << ", but got "
393                                   << list.size() << ".";
394   }
395   const auto &x = list[0];
396   MS_EXCEPTION_IF_NULL(x);
397   return x;
398 }
399 
ScalarUSub(const ValuePtrList & list)400 ValuePtr ScalarUSub(const ValuePtrList &list) {
401   constexpr size_t scalar_input_size = 1;
402   if (list.size() != scalar_input_size) {
403     MS_EXCEPTION(NotSupportError) << "Input number of ScalarUSub should be " << scalar_input_size << ", but got "
404                                   << list.size() << ".";
405   }
406   const auto &x = list[0];
407   MS_EXCEPTION_IF_NULL(x);
408 
409   if (x->isa<Int32Imm>()) {
410     int32_t sum = -1 * GetValue<int32_t>(x);
411     return MakeValue(sum);
412   }
413   if (x->isa<Int64Imm>()) {
414     int64_t sum = -1 * GetValue<int64_t>(x);
415     return MakeValue(sum);
416   }
417   if (x->isa<FP32Imm>()) {
418     float sum = -1.0f * GetValue<float>(x);
419     return MakeValue(sum);
420   }
421   MS_EXCEPTION(NotSupportError) << "Not support ScalarUSub [x:" << x->ToString() << "].";
422 }
423 
ScalarLog(const ValuePtrList & list)424 ValuePtr ScalarLog(const ValuePtrList &list) {
425   constexpr size_t scalar_input_size = 1;
426   if (list.size() != scalar_input_size) {
427     MS_EXCEPTION(NotSupportError) << "Input number of ScalarLog must be " << scalar_input_size << ", but got "
428                                   << list.size() << ".";
429   }
430   const auto &x = list[0];
431   MS_EXCEPTION_IF_NULL(x);
432 
433   if (x->isa<FP32Imm>()) {
434     auto v = static_cast<float>(log(GetValue<float>(x)));
435     return MakeValue(v);
436   }
437   MS_EXCEPTION(NotSupportError) << "Not support ScalarLog [x:" << x->ToString() << "].";
438 }
439 
GetBooleansFromValueList(const std::string & prim_name,const ValuePtrList & list,bool * val_x,bool * val_y)440 void GetBooleansFromValueList(const std::string &prim_name, const ValuePtrList &list, bool *val_x, bool *val_y) {
441   constexpr size_t boolean_input_size = 2;
442   if (list.size() != boolean_input_size) {
443     MS_EXCEPTION(NotSupportError) << "The input number of " << prim_name << " operator must be " << boolean_input_size
444                                   << ", but got " << list.size() << ".";
445   }
446   const auto &x = list[0];
447   const auto &y = list[1];
448   MS_EXCEPTION_IF_NULL(x);
449   MS_EXCEPTION_IF_NULL(y);
450   if (!x->isa<BoolImm>() || !y->isa<BoolImm>()) {
451     MS_LOG(EXCEPTION) << "The inputs of " << prim_name
452                       << " operator should be two booleans, but got param0: " << x->ToString()
453                       << ", param1: " << y->ToString() << ".";
454   }
455   *val_x = x->cast<BoolImmPtr>()->value();
456   *val_y = y->cast<BoolImmPtr>()->value();
457 }
458 
GetStringsFromValueList(const std::string & prim_name,const ValuePtrList & list,std::string * str_x,std::string * str_y)459 void GetStringsFromValueList(const std::string &prim_name, const ValuePtrList &list, std::string *str_x,
460                              std::string *str_y) {
461   constexpr size_t string_input_size = 2;
462   if (list.size() != string_input_size) {
463     MS_EXCEPTION(NotSupportError) << "The input number of " << prim_name << " operator must be " << string_input_size
464                                   << ", but got " << list.size() << ".";
465   }
466   const auto &x = list[0];
467   const auto &y = list[1];
468   MS_EXCEPTION_IF_NULL(x);
469   MS_EXCEPTION_IF_NULL(y);
470   if (!x->isa<StringImm>() || !y->isa<StringImm>()) {
471     MS_LOG(EXCEPTION) << "The inputs of " << prim_name
472                       << " operator should be two strings, but got param0: " << x->ToString()
473                       << ", param1: " << y->ToString() << ".";
474   }
475   *str_x = GetValue<std::string>(x);
476   *str_y = GetValue<std::string>(y);
477 }
478 
BoolNot(const ValuePtrList & list)479 ValuePtr BoolNot(const ValuePtrList &list) {
480   constexpr size_t boolean_input_size = 1;
481   if (list.size() != boolean_input_size) {
482     MS_EXCEPTION(NotSupportError) << "Input number of BoolNot must be " << boolean_input_size << ", but got "
483                                   << list.size() << ".";
484   }
485   const auto &x = list[0];
486   MS_EXCEPTION_IF_NULL(x);
487   if (!x->isa<BoolImm>()) {
488     MS_LOG(EXCEPTION) << "The input of BoolNot operator should be a boolean, but got " << x->ToString() << ".";
489   }
490   bool val = x->cast<BoolImmPtr>()->value();
491   return MakeValue<bool>(!val);
492 }
493 
StringNot(const ValuePtrList & list)494 ValuePtr StringNot(const ValuePtrList &list) {
495   constexpr size_t string_input_size = 1;
496   if (list.size() != string_input_size) {
497     MS_EXCEPTION(NotSupportError) << "Input number of StringNot must be " << string_input_size << ", but got "
498                                   << list.size() << ".";
499   }
500   const auto &x = list[0];
501   MS_EXCEPTION_IF_NULL(x);
502   if (!x->isa<StringImm>()) {
503     MS_LOG(EXCEPTION) << "The input of BoolNot operator should be a string, but got " << x->ToString() << ".";
504   }
505   std::string str = x->cast<StringImmPtr>()->value();
506   return MakeValue<bool>(str.empty());
507 }
508 
BoolAnd(const ValuePtrList & list)509 ValuePtr BoolAnd(const ValuePtrList &list) {
510   bool x = false;
511   bool y = false;
512   GetBooleansFromValueList("BoolAnd", list, &x, &y);
513   return MakeValue<bool>(x && y);
514 }
515 
BoolOr(const ValuePtrList & list)516 ValuePtr BoolOr(const ValuePtrList &list) {
517   bool x = false;
518   bool y = false;
519   GetBooleansFromValueList("BoolOr", list, &x, &y);
520   return MakeValue<bool>(x || y);
521 }
522 
BoolEq(const ValuePtrList & list)523 ValuePtr BoolEq(const ValuePtrList &list) {
524   bool x = false;
525   bool y = false;
526   GetBooleansFromValueList("BoolEq", list, &x, &y);
527   return MakeValue<bool>(x == y);
528 }
529 
StringEq(const ValuePtrList & list)530 ValuePtr StringEq(const ValuePtrList &list) {
531   std::string str_x;
532   std::string str_y;
533   GetStringsFromValueList("StringEq", list, &str_x, &str_y);
534   return MakeValue<bool>(str_x == str_y);
535 }
536 
StringLt(const ValuePtrList & list)537 ValuePtr StringLt(const ValuePtrList &list) {
538   std::string str_x;
539   std::string str_y;
540   GetStringsFromValueList("StringLt", list, &str_x, &str_y);
541   return MakeValue<bool>(str_x < str_y);
542 }
543 
StringGt(const ValuePtrList & list)544 ValuePtr StringGt(const ValuePtrList &list) {
545   std::string str_x;
546   std::string str_y;
547   GetStringsFromValueList("StringGt", list, &str_x, &str_y);
548   return MakeValue<bool>(str_x > str_y);
549 }
550 
StringLe(const ValuePtrList & list)551 ValuePtr StringLe(const ValuePtrList &list) {
552   std::string str_x;
553   std::string str_y;
554   GetStringsFromValueList("StringLe", list, &str_x, &str_y);
555   return MakeValue<bool>(str_x <= str_y);
556 }
557 
StringGe(const ValuePtrList & list)558 ValuePtr StringGe(const ValuePtrList &list) {
559   std::string str_x;
560   std::string str_y;
561   GetStringsFromValueList("StringGe", list, &str_x, &str_y);
562   return MakeValue<bool>(str_x >= str_y);
563 }
564 
StringIn(const ValuePtrList & list)565 ValuePtr StringIn(const ValuePtrList &list) {
566   std::string str_x;
567   std::string str_y;
568   GetStringsFromValueList("StringIn", list, &str_x, &str_y);
569   return MakeValue<bool>(str_y.find(str_x) != std::string::npos);
570 }
571 
StringConcat(const ValuePtrList & list)572 ValuePtr StringConcat(const ValuePtrList &list) {
573   std::string str_x;
574   std::string str_y;
575   GetStringsFromValueList("StringConcat", list, &str_x, &str_y);
576   return MakeValue<std::string>(str_x + str_y);
577 }
578 }  // namespace prim
579 }  // namespace mindspore
580