1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "abstract/param_validator.h"
18 #include "abstract/infer_functions.h"
19 #include "abstract/abstract_function.h"
20 #include "abstract/utils.h"
21 #include "utils/symbolic.h"
22
23 namespace mindspore {
24 namespace abstract {
InferImplReturn(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)25 AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
26 const AbstractBasePtrList &args_spec_list) {
27 // Inputs: a pointer to an AbstractBase object
28 if (args_spec_list.size() != 1) {
29 MS_LOG(INFO) << "Return evaluator requires 1 parameter, is this the default value attached? "
30 "while the input size is "
31 << args_spec_list.size() << ".";
32 }
33 AbstractBasePtr abs_base = args_spec_list[0];
34 return abs_base;
35 }
36
InferImplSwitch(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)37 AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
38 const AbstractBasePtrList &args_spec_list) {
39 // Inputs: condition, true branch, false branch
40 if (args_spec_list.size() != 3) {
41 MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
42 << ".";
43 }
44
45 auto cond = args_spec_list[0];
46 auto tb = args_spec_list[1];
47 auto fb = args_spec_list[2];
48 MS_EXCEPTION_IF_NULL(cond);
49
50 ValuePtr v = cond->GetValueTrack();
51 MS_EXCEPTION_IF_NULL(v);
52 // for tensor as condition, keeps both true and false branch.
53 if (v->isa<AnyValue>() || cond->isa<AbstractTensor>()) {
54 MS_EXCEPTION_IF_NULL(tb);
55 return tb->Join(fb);
56 }
57
58 if (v->isa<Scalar>()) {
59 if (v->cast<ScalarPtr>()->IsOne()) {
60 return tb;
61 } else {
62 return fb;
63 }
64 }
65
66 MS_LOG(EXCEPTION) << "Not support this condition value: " << cond->GetValueTrack()->ToString();
67 }
68
InferImplSwitchLayer(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)69 AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
70 const AbstractBasePtrList &args_spec_list) {
71 // Inputs: {index, MakeTuple{branch1,branch2,branch3....}}
72 constexpr auto kSwitchLayerInputNum = 2;
73 const std::string op_name = primitive->name();
74 abstract::CheckArgsSize(op_name, args_spec_list, kSwitchLayerInputNum);
75 auto index = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
76 auto &input_shape = index->shape()->shape();
77 if (!input_shape.empty()) {
78 MS_EXCEPTION(ValueError) << op_name << " index must be a 0 dimension tensor, but got a " << input_shape.size()
79 << " dimension tensor";
80 }
81 auto dtype = index->element()->BuildType();
82 if (dtype->type_id() != kInt32->type_id()) {
83 MS_EXCEPTION(ValueError) << op_name << " index must be a int32, but got " << dtype->ToString();
84 }
85
86 AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
87 AbstractBasePtrList branches = branches_abs->elements();
88 const size_t maximum_layer_num = 1000;
89 if (branches.empty() || branches.size() > maximum_layer_num) {
90 MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got "
91 << branches.size() << " branches.";
92 }
93
94 for (size_t i = 0; i < branches.size(); i++) {
95 MS_EXCEPTION_IF_NULL(branches[i]);
96 if (!branches[i]->isa<FuncGraphAbstractClosure>() && !branches[i]->isa<PartialAbstractClosure>()) {
97 MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
98 << branches[i]->ToString() << " as the " << i << "th element.";
99 }
100 }
101
102 auto b = branches[0];
103 // Return AbstractFuncUnion, otherwise the switch_layer will be replaced by branches[0]
104 // which will cancel the out of bound checking for index
105 if (branches.size() == 1) {
106 AbstractFuncAtomPtrList func_list{b->cast<AbstractFuncAtomPtr>()};
107 return std::make_shared<AbstractFuncUnion>(func_list);
108 }
109 for (size_t i = 1; i < branches.size(); i++) {
110 b = b->Join(branches[i]);
111 }
112 return b;
113 }
114
GetSupportedTargetValue()115 std::vector<ValuePtr> GetSupportedTargetValue() {
116 std::vector<ValuePtr> list = {kNone, MakeValue(false), MakeValue(true)};
117 return list;
118 }
119
SupportedIsTargetValue(const ValuePtr t)120 bool SupportedIsTargetValue(const ValuePtr t) {
121 auto list = GetSupportedTargetValue();
122 auto match = std::any_of(list.begin(), list.end(), [&t](const ValuePtr &v) { return *v == *t; });
123 return match;
124 }
125
InferImplIs_(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)126 AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
127 const AbstractBasePtrList &args_spec_list) {
128 // statement: x is t
129 // Inputs: x, t
130 const std::string op_name = primitive->name();
131 CheckArgsSize(op_name, args_spec_list, 2);
132 ValuePtr t = args_spec_list[1]->BuildValue();
133 if (!SupportedIsTargetValue(t)) {
134 MS_LOG(EXCEPTION) << "This comparator '" << t->ToString()
135 << "' is not supported. For statement 'is', only support compare with 'None', 'False' or 'True'";
136 }
137 ValuePtr x = args_spec_list[0]->BuildValue();
138
139 return std::make_shared<AbstractScalar>(*t == *x);
140 }
141
InferImplIsNot(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)142 AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
143 const AbstractBasePtrList &args_spec_list) {
144 // statement: x is not t
145 // Inputs: x, t
146 const std::string op_name = primitive->name();
147 CheckArgsSize(op_name, args_spec_list, 2);
148 ValuePtr t = args_spec_list[1]->BuildValue();
149 if (!SupportedIsTargetValue(t)) {
150 MS_LOG(EXCEPTION)
151 << "This comparator '" << t->ToString()
152 << "' is not supported. For statement 'is not' , only support compare with 'None', 'False' or 'True'";
153 }
154 ValuePtr x = args_spec_list[0]->BuildValue();
155
156 return std::make_shared<AbstractScalar>(!(*t == *x));
157 }
158
IsInDict(const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)159 bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
160 const std::string op_name = primitive->name();
161 CheckArgsSize(op_name, args_spec_list, 2);
162 auto key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
163 auto dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 1);
164
165 ValuePtr key_value = key->BuildValue();
166 if (!key_value->isa<StringImm>()) {
167 MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
168 }
169 auto key_str = GetValue<std::string>(key_value);
170 std::vector<AbstractAttribute> dict_elems = dict->elements();
171 auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
172 [key_str](const AbstractAttribute &item) { return item.first == key_str; });
173 return it != dict_elems.end();
174 }
175
InferImplInDict(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)176 AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
177 const AbstractBasePtrList &args_spec_list) {
178 // statement: x in t
179 // Inputs: x, t
180 return std::make_shared<AbstractScalar>(IsInDict(primitive, args_spec_list));
181 }
182
InferImplNotInDict(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)183 AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
184 const AbstractBasePtrList &args_spec_list) {
185 // statement: x not in t
186 // Inputs: x, t
187 return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
188 }
189
InferImplIsConstant(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)190 AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
191 const AbstractBasePtrList &args_spec_list) {
192 // statement: isconstant(x)
193 // Inputs: x
194 if (args_spec_list.size() != 1) {
195 MS_LOG(EXCEPTION) << "IsConstant requires args input size = 1";
196 }
197 ValuePtr v = args_spec_list[0]->BuildValue();
198 return std::make_shared<AbstractScalar>(!v->isa<AnyValue>());
199 }
200 } // namespace abstract
201 } // namespace mindspore
202