• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/operator/prim_to_function.h"
18 
19 #include "mindspore/core/ops/arithmetic_ops.h"
20 #include "mindspore/core/ops/comparison_ops.h"
21 #include "mindspore/core/ops/structure_ops.h"
22 namespace mindspore {
// Namespace for primitive-related definitions.
24 namespace prim {
25 
PrimToFunction()26 PrimToFunction::PrimToFunction()
27     : prim_func_type_map_({{"bool_not", kPrimTypeNumOneArg},
28                            {"scalar_cos", kPrimTypeNumOneArg},
29                            {"scalar_exp", kPrimTypeNumOneArg},
30                            {kScalarFloorOpName, kPrimTypeNumOneArg},
31                            {"ScalarLog", kPrimTypeNumOneArg},
32                            {"scalar_sin", kPrimTypeNumOneArg},
33                            {"scalar_tan", kPrimTypeNumOneArg},
34                            {kScalarTruncOpName, kPrimTypeNumOneArg},
35                            {"typeof", kPrimTypeNumOneArg},
36                            {"ScalarUadd", kPrimTypeNumOneArg},
37                            {"ScalarUsub", kPrimTypeNumOneArg},
38                            {"ScalarAdd", kPrimTypeNumTwoArgs},
39                            {"bool_and", kPrimTypeNumTwoArgs},
40                            {"bool_eq", kPrimTypeNumTwoArgs},
41                            {"bool_or", kPrimTypeNumTwoArgs},
42                            {"ScalarDiv", kPrimTypeNumTwoArgs},
43                            {"ScalarEq", kPrimTypeNumTwoArgs},
44                            {"ScalarGe", kPrimTypeNumTwoArgs},
45                            {"ScalarGt", kPrimTypeNumTwoArgs},
46                            {"ScalarLe", kPrimTypeNumTwoArgs},
47                            {"ScalarLt", kPrimTypeNumTwoArgs},
48                            {"scalar_ne", kPrimTypeNumTwoArgs},
49                            {"ScalarMod", kPrimTypeNumTwoArgs},
50                            {"ScalarMul", kPrimTypeNumTwoArgs},
51                            {"ScalarPow", kPrimTypeNumTwoArgs},
52                            {"ScalarSub", kPrimTypeNumTwoArgs},
53                            {"ScalarFloorDiv", kPrimTypeNumTwoArgs},
54                            {kScalarBitwiseAndOpName, kPrimTypeNumTwoArgs},
55                            {kScalarBitwiseOrOpName, kPrimTypeNumTwoArgs},
56                            {"bit_xor", kPrimTypeNumTwoArgs},
57                            {"bit_left_shift", kPrimTypeNumTwoArgs},
58                            {"bit_right_shift", kPrimTypeNumTwoArgs},
59                            {kStringNotOpName, kPrimTypeStrOneArg},
60                            {kStringConcatOpName, kPrimTypeStrTwoArgs},
61                            {kStringInOpName, kPrimTypeStrTwoArgs},
62                            {kStringEqOpName, kPrimTypeStrTwoArgs},
63                            {kStringLtOpName, kPrimTypeStrTwoArgs},
64                            {kStringGtOpName, kPrimTypeStrTwoArgs},
65                            {kStringLeOpName, kPrimTypeStrTwoArgs},
66                            {kStringGeOpName, kPrimTypeStrTwoArgs}}) {}
67 
GetFunction(const PrimitivePtr & prim,FunctionPtr * func) const68 bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *func) const {
69   if (func != nullptr) {
70     int64_t args_num = GetPrimType(prim);
71     switch (args_num) {
72       case kPrimTypeNumOneArg: {
73         std::vector<TypePtr> num_one_arg{std::make_shared<Number>()};
74         *func = Function(num_one_arg, std::make_shared<Number>()).DeepCopy()->cast<FunctionPtr>();
75         return true;
76       }
77       case kPrimTypeNumTwoArgs: {
78         std::vector<TypePtr> num_two_args{std::make_shared<Number>(), std::make_shared<Number>()};
79         *func = Function(num_two_args, std::make_shared<Number>()).DeepCopy()->cast<FunctionPtr>();
80         return true;
81       }
82       case kPrimTypeStrOneArg: {
83         std::vector<TypePtr> str_one_arg{std::make_shared<String>()};
84         *func = Function(str_one_arg, std::make_shared<String>()).DeepCopy()->cast<FunctionPtr>();
85         return true;
86       }
87       case kPrimTypeStrTwoArgs: {
88         std::vector<TypePtr> str_two_args{std::make_shared<String>(), std::make_shared<String>()};
89         *func = Function(str_two_args, std::make_shared<String>()).DeepCopy()->cast<FunctionPtr>();
90         return true;
91       }
92       default:
93         return false;
94     }
95   }
96   return false;
97 }
98 
GetPrimType(const PrimitivePtr & prim) const99 int64_t PrimToFunction::GetPrimType(const PrimitivePtr &prim) const {
100   MS_EXCEPTION_IF_NULL(prim);
101   int64_t prim_type = static_cast<int64_t>(kPrimTypeUnknown);
102 
103   auto value = prim_func_type_map_.find(prim->name());
104   if (value != prim_func_type_map_.end()) {
105     prim_type = value->second;
106   }
107   return prim_type;
108 }
109 }  // namespace prim
110 }  // namespace mindspore
111