1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/operator/prim_to_function.h"
18
19 #include "mindspore/core/ops/arithmetic_ops.h"
20 #include "mindspore/core/ops/comparison_ops.h"
21 #include "mindspore/core/ops/structure_ops.h"
22 namespace mindspore {
23 // namespace to support prim related definition
24 namespace prim {
25
PrimToFunction()26 PrimToFunction::PrimToFunction()
27 : prim_func_type_map_({{"bool_not", kPrimTypeNumOneArg},
28 {"scalar_cos", kPrimTypeNumOneArg},
29 {"scalar_exp", kPrimTypeNumOneArg},
30 {kScalarFloorOpName, kPrimTypeNumOneArg},
31 {"ScalarLog", kPrimTypeNumOneArg},
32 {"scalar_sin", kPrimTypeNumOneArg},
33 {"scalar_tan", kPrimTypeNumOneArg},
34 {kScalarTruncOpName, kPrimTypeNumOneArg},
35 {"typeof", kPrimTypeNumOneArg},
36 {"ScalarUadd", kPrimTypeNumOneArg},
37 {"ScalarUsub", kPrimTypeNumOneArg},
38 {"ScalarAdd", kPrimTypeNumTwoArgs},
39 {"bool_and", kPrimTypeNumTwoArgs},
40 {"bool_eq", kPrimTypeNumTwoArgs},
41 {"bool_or", kPrimTypeNumTwoArgs},
42 {"ScalarDiv", kPrimTypeNumTwoArgs},
43 {"ScalarEq", kPrimTypeNumTwoArgs},
44 {"ScalarGe", kPrimTypeNumTwoArgs},
45 {"ScalarGt", kPrimTypeNumTwoArgs},
46 {"ScalarLe", kPrimTypeNumTwoArgs},
47 {"ScalarLt", kPrimTypeNumTwoArgs},
48 {"scalar_ne", kPrimTypeNumTwoArgs},
49 {"ScalarMod", kPrimTypeNumTwoArgs},
50 {"ScalarMul", kPrimTypeNumTwoArgs},
51 {"ScalarPow", kPrimTypeNumTwoArgs},
52 {"ScalarSub", kPrimTypeNumTwoArgs},
53 {"ScalarFloorDiv", kPrimTypeNumTwoArgs},
54 {kScalarBitwiseAndOpName, kPrimTypeNumTwoArgs},
55 {kScalarBitwiseOrOpName, kPrimTypeNumTwoArgs},
56 {"bit_xor", kPrimTypeNumTwoArgs},
57 {"bit_left_shift", kPrimTypeNumTwoArgs},
58 {"bit_right_shift", kPrimTypeNumTwoArgs},
59 {kStringNotOpName, kPrimTypeStrOneArg},
60 {kStringConcatOpName, kPrimTypeStrTwoArgs},
61 {kStringInOpName, kPrimTypeStrTwoArgs},
62 {kStringEqOpName, kPrimTypeStrTwoArgs},
63 {kStringLtOpName, kPrimTypeStrTwoArgs},
64 {kStringGtOpName, kPrimTypeStrTwoArgs},
65 {kStringLeOpName, kPrimTypeStrTwoArgs},
66 {kStringGeOpName, kPrimTypeStrTwoArgs}}) {}
67
GetFunction(const PrimitivePtr & prim,FunctionPtr * func) const68 bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *func) const {
69 if (func != nullptr) {
70 int64_t args_num = GetPrimType(prim);
71 switch (args_num) {
72 case kPrimTypeNumOneArg: {
73 std::vector<TypePtr> num_one_arg{std::make_shared<Number>()};
74 *func = Function(num_one_arg, std::make_shared<Number>()).DeepCopy()->cast<FunctionPtr>();
75 return true;
76 }
77 case kPrimTypeNumTwoArgs: {
78 std::vector<TypePtr> num_two_args{std::make_shared<Number>(), std::make_shared<Number>()};
79 *func = Function(num_two_args, std::make_shared<Number>()).DeepCopy()->cast<FunctionPtr>();
80 return true;
81 }
82 case kPrimTypeStrOneArg: {
83 std::vector<TypePtr> str_one_arg{std::make_shared<String>()};
84 *func = Function(str_one_arg, std::make_shared<String>()).DeepCopy()->cast<FunctionPtr>();
85 return true;
86 }
87 case kPrimTypeStrTwoArgs: {
88 std::vector<TypePtr> str_two_args{std::make_shared<String>(), std::make_shared<String>()};
89 *func = Function(str_two_args, std::make_shared<String>()).DeepCopy()->cast<FunctionPtr>();
90 return true;
91 }
92 default:
93 return false;
94 }
95 }
96 return false;
97 }
98
GetPrimType(const PrimitivePtr & prim) const99 int64_t PrimToFunction::GetPrimType(const PrimitivePtr &prim) const {
100 MS_EXCEPTION_IF_NULL(prim);
101 int64_t prim_type = static_cast<int64_t>(kPrimTypeUnknown);
102
103 auto value = prim_func_type_map_.find(prim->name());
104 if (value != prim_func_type_map_.end()) {
105 prim_type = value->second;
106 }
107 return prim_type;
108 }
109 } // namespace prim
110 } // namespace mindspore
111