1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "abstract/param_validator.h"
18 #include "abstract/ops/infer_functions.h"
19 #include "abstract/abstract_function.h"
20 #include "abstract/utils.h"
21 #include "utils/symbolic.h"
22
23 namespace mindspore {
24 namespace abstract {
InferImplReturn(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_abs_list)25 AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
26 const AbstractBasePtrList &args_abs_list) {
27 // Inputs: a pointer to an AbstractBase object
28 if (args_abs_list.size() != 1) {
29 MS_LOG(INFO) << "Return evaluator requires 1 parameter, is this the default value attached? "
30 "while the input size is "
31 << args_abs_list.size() << ".";
32 }
33 AbstractBasePtr abs_base = args_abs_list[0];
34 return abs_base;
35 }
36
CheckTensorCondValid(const AbstractBasePtr & cond)37 void CheckTensorCondValid(const AbstractBasePtr &cond) {
38 // Tensor condition must be one element or dynamic shape.
39 auto base_shape = cond->GetShape();
40 MS_EXCEPTION_IF_NULL(base_shape);
41 ShapeVector cond_shape = base_shape->cast<ShapePtr>()->shape();
42 if (cond_shape.empty()) {
43 return;
44 }
45 constexpr auto num_one = 1;
46 for (size_t i = 0; i < cond_shape.size(); i++) {
47 if (cond_shape[i] != num_one && cond_shape[i] != Shape::kShapeDimAny && cond_shape[i] != Shape::kShapeRankAny) {
48 MS_LOG(ERROR) << "The condition value of control flow can be a tensor with one element, "
49 << "but got tensor with shape " << base_shape->ToString();
50 MS_EXCEPTION(ValueError) << "The truth value of an array with more than one element is ambiguous.";
51 }
52 }
53 }
54
InferImplSwitch(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_abs_list)55 AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
56 const AbstractBasePtrList &args_abs_list) {
57 // Inputs: condition, true branch, false branch
58 constexpr auto switch_input_size = 3;
59 if (args_abs_list.size() != switch_input_size) {
60 MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_abs_list.size()
61 << ".";
62 }
63
64 auto cond_abstract = args_abs_list[0];
65 auto true_branch = args_abs_list[1];
66 auto false_branch = args_abs_list[2];
67 MS_EXCEPTION_IF_NULL(cond_abstract);
68
69 ValuePtr cond_value = cond_abstract->GetValueTrack();
70 MS_EXCEPTION_IF_NULL(cond_value);
71 // If the value of condition is ValueAny or the abstract of condition is AbstractTensor,
72 // keeps both true and false branch.
73 if (cond_value->isa<ValueAny>() || cond_abstract->isa<AbstractTensor>()) {
74 if (cond_abstract->isa<AbstractTensor>()) {
75 CheckTensorCondValid(cond_abstract);
76 }
77 MS_EXCEPTION_IF_NULL(true_branch);
78 // Need record two func_graph
79 SetVariableFlag(true_branch);
80 SetVariableFlag(false_branch);
81 return true_branch->Join(false_branch);
82 }
83
84 if (cond_value->isa<Scalar>()) {
85 if (cond_value->cast<ScalarPtr>()->IsOne()) {
86 return true_branch;
87 } else {
88 return false_branch;
89 }
90 }
91 MS_LOG(EXCEPTION) << "Not support this condition value: " << cond_value->ToString();
92 }
93
InferImplSwitchLayer(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)94 AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
95 const AbstractBasePtrList &args_abs_list) {
96 // Inputs: {index, MakeTuple{branch1,branch2,branch3....}}
97 constexpr auto kSwitchLayerInputNum = 2;
98 const std::string op_name = primitive->name();
99 abstract::CheckArgsSize(op_name, args_abs_list, kSwitchLayerInputNum);
100 auto index = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
101 auto &input_shape = index->shape()->shape();
102 if (!input_shape.empty() && (input_shape.size() != 1 || input_shape[0] != 1)) {
103 MS_EXCEPTION(ValueError) << op_name << " index must be a 0 dimension tensor, but got a " << input_shape.size()
104 << " dimension tensor";
105 }
106 auto dtype = index->element()->BuildType();
107 if (dtype->type_id() != kInt32->type_id()) {
108 MS_EXCEPTION(ValueError) << op_name << " index must be an int32, but got " << dtype->ToString();
109 }
110
111 AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_abs_list, 1);
112 AbstractBasePtrList branches = branches_abs->elements();
113 const size_t maximum_layer_num = 1000;
114 if (branches.empty() || branches.size() > maximum_layer_num) {
115 MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got "
116 << branches.size() << " branches.";
117 }
118
119 auto b = branches[0];
120 SetVariableFlag(b);
121 // Return AbstractFuncUnion, otherwise the switch_layer will be replaced by branches[0]
122 // which will cancel the out of bound checking for index
123 if (branches.size() == 1) {
124 AbstractFuncAtomPtrList func_list{b->cast<AbstractFuncAtomPtr>()};
125 return std::make_shared<AbstractFuncUnion>(func_list);
126 }
127 for (size_t i = 1; i < branches.size(); i++) {
128 SetVariableFlag(branches[i]);
129 b = b->Join(branches[i]);
130 }
131 return b;
132 }
133
SupportedIsTargetValue(const TypePtr t)134 bool SupportedIsTargetValue(const TypePtr t) {
135 if (t->isa<TypeNone>() || t->isa<Int>() || t->isa<Bool>() || t->isa<String>() || t->isa<TypeType>()) {
136 return true;
137 }
138 return false;
139 }
140
CheckIfDataIsTarget(const std::string & op_name,const AbstractBasePtr & data_abs,const AbstractBasePtr & target_abs)141 std::pair<bool, bool> CheckIfDataIsTarget(const std::string &op_name, const AbstractBasePtr &data_abs,
142 const AbstractBasePtr &target_abs) {
143 MS_EXCEPTION_IF_NULL(target_abs);
144 // Check if data and target are both None.
145 if (data_abs->isa<AbstractNone>() || target_abs->isa<AbstractNone>()) {
146 return {data_abs->isa<AbstractNone>() && target_abs->isa<AbstractNone>(), false};
147 }
148 const auto &target_value = target_abs->BuildValue();
149 const auto &target_type = target_abs->BuildType();
150 MS_EXCEPTION_IF_NULL(target_value);
151 MS_EXCEPTION_IF_NULL(target_type);
152 const auto &data_value = data_abs->BuildValue();
153 MS_EXCEPTION_IF_NULL(data_value);
154 if (data_value != kValueAny && target_value != kValueAny && !SupportedIsTargetValue(target_type)) {
155 MS_LOG(EXCEPTION) << "For syntax like 'a " << op_name << " b', b supports Int, Bool, String, None and Type, "
156 << "but got " << target_value->ToString();
157 }
158 const auto &data_type = data_abs->BuildType();
159 MS_EXCEPTION_IF_NULL(data_type);
160 if (*data_type != *target_type) {
161 return {false, false};
162 }
163
164 if (data_value == kValueAny || target_value == kValueAny) {
165 return {false, true};
166 }
167 return {*data_value == *target_value, false};
168 }
169
InferImplIs_(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)170 AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
171 const AbstractBasePtrList &args_abs_list) {
172 // Statement: x is t
173 // Inputs: x, t
174 constexpr size_t kInputsNum = 2;
175 const std::string op_name = primitive->name();
176 CheckArgsSize(op_name, args_abs_list, kInputsNum);
177 constexpr size_t data_index = 0;
178 constexpr size_t target_index = 1;
179 auto res = CheckIfDataIsTarget("is", args_abs_list[data_index], args_abs_list[target_index]);
180 if (res.second) {
181 return std::make_shared<AbstractScalar>(kValueAny, kBool);
182 }
183 return std::make_shared<AbstractScalar>(res.first);
184 }
185
InferImplIsNot(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)186 AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
187 const AbstractBasePtrList &args_abs_list) {
188 // Statement: x is not t
189 // Inputs: x, t
190 constexpr size_t kInputsNum = 2;
191 const std::string op_name = primitive->name();
192 CheckArgsSize(op_name, args_abs_list, kInputsNum);
193 constexpr size_t data_index = 0;
194 constexpr size_t target_index = 1;
195 auto res = CheckIfDataIsTarget("is not", args_abs_list[data_index], args_abs_list[target_index]);
196 if (res.second) {
197 return std::make_shared<AbstractScalar>(kValueAny, kBool);
198 }
199 return std::make_shared<AbstractScalar>(!res.first);
200 }
201
IsInDict(const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)202 bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_abs_list) {
203 constexpr size_t kInputsNum = 2;
204 const std::string op_name = primitive->name();
205 CheckArgsSize(op_name, args_abs_list, kInputsNum);
206 const auto &key = args_abs_list[0];
207 auto dict = CheckArg<AbstractDictionary>(op_name, args_abs_list, 1);
208
209 ValuePtr key_value = key->BuildValue();
210 MS_EXCEPTION_IF_NULL(key_value);
211 std::vector<AbstractElementPair> dict_elems = dict->elements();
212 auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
213 return *key_value == *item.first->BuildValue();
214 });
215 return it != dict_elems.end();
216 }
217
InferImplInDict(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)218 AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
219 const AbstractBasePtrList &args_abs_list) {
220 // Statement: x in t
221 // Inputs: x, t
222 return std::make_shared<AbstractScalar>(IsInDict(primitive, args_abs_list));
223 }
224
InferImplNotInDict(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)225 AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
226 const AbstractBasePtrList &args_abs_list) {
227 // Statement: x not in t
228 // Inputs: x, t
229 return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_abs_list));
230 }
231
InferImplIsConstant(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_abs_list)232 AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
233 const AbstractBasePtrList &args_abs_list) {
234 // Statement: isconstant(x)
235 // Inputs: x
236 if (args_abs_list.size() != 1) {
237 MS_LOG(EXCEPTION) << "IsConstant requires args input size = 1";
238 }
239 ValuePtr v = args_abs_list[0]->BuildValue();
240 return std::make_shared<AbstractScalar>(!v->ContainsValueAny());
241 }
242 } // namespace abstract
243 } // namespace mindspore
244