/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "abstract/param_validator.h"
#include "abstract/ops/infer_functions.h"
#include "abstract/abstract_function.h"
#include "abstract/utils.h"
#include "utils/symbolic.h"

namespace mindspore {
namespace abstract {
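// Inference for the Return primitive: the abstract of the returned value is passed
// through unchanged, so Return(x) has exactly the abstract of x.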
AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
                                const AbstractBasePtrList &args_abs_list) {
  // Inputs: a pointer to an AbstractBase object
  if (args_abs_list.size() != 1) {
    MS_LOG(INFO) << "Return evaluator requires 1 parameter, but the input size is " << args_abs_list.size()
                 << ". Is this the default value attached?";
  }
  AbstractBasePtr abs_base = args_abs_list[0];
  return abs_base;
}

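// Validates a tensor used as a control-flow condition: every statically known dimension
// must be 1 (dynamic dimensions and dynamic rank are tolerated). A shape such as (2, 3)
// is rejected with the ambiguous-truth-value error raised below.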
void CheckTensorCondValid(const AbstractBasePtr &cond) {
  // Tensor condition must have a single element or a dynamic shape.
  auto base_shape = cond->GetShape();
  MS_EXCEPTION_IF_NULL(base_shape);
  auto shape_ptr = base_shape->cast<ShapePtr>();
  MS_EXCEPTION_IF_NULL(shape_ptr);
  ShapeVector cond_shape = shape_ptr->shape();
  if (cond_shape.empty()) {
    return;
  }
  constexpr auto num_one = 1;
  for (size_t i = 0; i < cond_shape.size(); i++) {
    if (cond_shape[i] != num_one && cond_shape[i] != Shape::kShapeDimAny && cond_shape[i] != Shape::kShapeRankAny) {
      MS_LOG(ERROR) << "The condition value of control flow must be a tensor with a single element, "
                    << "but got a tensor with shape " << base_shape->ToString();
      MS_EXCEPTION(ValueError) << "The truth value of an array with more than one element is ambiguous.";
    }
  }
}

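// Inference for Switch(cond, true_branch, false_branch). If the condition is unknown at
// compile time (ValueAny) or is a tensor, the abstracts of both branches are joined;
// if it is a constant scalar, only the selected branch's abstract is returned.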
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
                                const AbstractBasePtrList &args_abs_list) {
  // Inputs: condition, true branch, false branch
  constexpr auto switch_input_size = 3;
  if (args_abs_list.size() != switch_input_size) {
    MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, but the input size is " << args_abs_list.size()
                      << ".";
  }

  auto cond_abstract = args_abs_list[0];
  auto true_branch = args_abs_list[1];
  auto false_branch = args_abs_list[2];
  MS_EXCEPTION_IF_NULL(cond_abstract);

  ValuePtr cond_value = cond_abstract->GetValueTrack();
  MS_EXCEPTION_IF_NULL(cond_value);
  // If the value of the condition is ValueAny or the abstract of the condition is AbstractTensor,
  // keep both the true and false branches.
  if (cond_value->isa<ValueAny>() || cond_abstract->isa<AbstractTensor>()) {
    if (cond_abstract->isa<AbstractTensor>()) {
      CheckTensorCondValid(cond_abstract);
    }
    MS_EXCEPTION_IF_NULL(true_branch);
    MS_EXCEPTION_IF_NULL(false_branch);
    // Need to record both func_graphs.
    SetVariableFlag(true_branch);
    SetVariableFlag(false_branch);
    return true_branch->Join(false_branch);
  }

  if (cond_value->isa<Scalar>()) {
    if (cond_value->cast<ScalarPtr>()->IsOne()) {
      return true_branch;
    } else {
      return false_branch;
    }
  }
  MS_LOG(EXCEPTION) << "Unsupported condition value: " << cond_value->ToString();
}

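// Inference for SwitchLayer(index, (branch1, branch2, ...)). The index must be a scalar
// int32 tensor and the tuple may hold between 1 and 1000 branches; the result is the
// join of all branch abstracts.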
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_abs_list) {
  // Inputs: {index, MakeTuple{branch1, branch2, branch3, ...}}
  constexpr auto kSwitchLayerInputNum = 2;
  const std::string op_name = primitive->name();
  abstract::CheckArgsSize(op_name, args_abs_list, kSwitchLayerInputNum);
  auto index = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
  auto &input_shape = index->shape()->shape();
  if (!input_shape.empty() && (input_shape.size() != 1 || input_shape[0] != 1)) {
    MS_EXCEPTION(ValueError) << op_name << " index must be a 0-dimensional tensor, but got a " << input_shape.size()
                             << "-dimensional tensor.";
  }
  auto dtype = index->element()->BuildType();
  if (dtype->type_id() != kInt32->type_id()) {
    MS_EXCEPTION(ValueError) << op_name << " index must be an int32, but got " << dtype->ToString();
  }

  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_abs_list, 1);
  AbstractBasePtrList branches = branches_abs->elements();
  const size_t maximum_layer_num = 1000;
  if (branches.empty() || branches.size() > maximum_layer_num) {
    MS_EXCEPTION(ValueError) << op_name << " supports at least 1 and at most " << maximum_layer_num
                             << " branches, but got " << branches.size() << ".";
  }

  auto b = branches[0];
  SetVariableFlag(b);
  // Return an AbstractFuncUnion, otherwise the switch_layer will be replaced by branches[0],
  // which would cancel the out-of-bound check on the index.
  if (branches.size() == 1) {
    AbstractFuncAtomPtrList func_list{b->cast<AbstractFuncAtomPtr>()};
    return std::make_shared<AbstractFuncUnion>(func_list);
  }
  for (size_t i = 1; i < branches.size(); i++) {
    SetVariableFlag(branches[i]);
    b = b->Join(branches[i]);
  }
  return b;
}

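// Whether a type is allowed on the right-hand side of an 'is'/'is not' comparison:
// None, Int, Bool, String and Type are supported.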
bool SupportedIsTargetValue(const TypePtr t) {
  return t->isa<TypeNone>() || t->isa<Int>() || t->isa<Bool>() || t->isa<String>() || t->isa<TypeType>();
}

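// Compares the abstracts of 'a' and 'b' for the statement 'a <op_name> b'. Returns
// {result, is_variable}; is_variable is true when either value is unknown at compile
// time, meaning the comparison must be resolved at run time.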
std::pair<bool, bool> CheckIfDataIsTarget(const std::string &op_name, const AbstractBasePtr &data_abs,
                                          const AbstractBasePtr &target_abs) {
  MS_EXCEPTION_IF_NULL(data_abs);
  MS_EXCEPTION_IF_NULL(target_abs);
  // If either side is None, the comparison only holds when both are None.
  if (data_abs->isa<AbstractNone>() || target_abs->isa<AbstractNone>()) {
    return {data_abs->isa<AbstractNone>() && target_abs->isa<AbstractNone>(), false};
  }
  const auto &target_value = target_abs->BuildValue();
  const auto &target_type = target_abs->BuildType();
  MS_EXCEPTION_IF_NULL(target_value);
  MS_EXCEPTION_IF_NULL(target_type);
  const auto &data_value = data_abs->BuildValue();
  MS_EXCEPTION_IF_NULL(data_value);
  if (data_value != kValueAny && target_value != kValueAny && !SupportedIsTargetValue(target_type)) {
    MS_LOG(EXCEPTION) << "For syntax like 'a " << op_name << " b', b supports Int, Bool, String, None and Type, "
                      << "but got " << target_value->ToString();
  }
  const auto &data_type = data_abs->BuildType();
  MS_EXCEPTION_IF_NULL(data_type);
  if (*data_type != *target_type) {
    return {false, false};
  }

  if (data_value == kValueAny || target_value == kValueAny) {
    return {false, true};
  }
  return {*data_value == *target_value, false};
}

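// Inference for 'x is t': folds to a constant bool when both sides are statically known,
// otherwise returns a bool scalar with an unknown (ValueAny) value.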
AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const AbstractBasePtrList &args_abs_list) {
  // Statement: x is t
  // Inputs: x, t
  constexpr size_t kInputsNum = 2;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_abs_list, kInputsNum);
  constexpr size_t data_index = 0;
  constexpr size_t target_index = 1;
  auto res = CheckIfDataIsTarget("is", args_abs_list[data_index], args_abs_list[target_index]);
  if (res.second) {
    return std::make_shared<AbstractScalar>(kValueAny, kBool);
  }
  return std::make_shared<AbstractScalar>(res.first);
}

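// Inference for 'x is not t': the negation of the folded 'is' result, or an unknown bool
// scalar when the comparison cannot be decided at compile time.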
AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_abs_list) {
  // Statement: x is not t
  // Inputs: x, t
  constexpr size_t kInputsNum = 2;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_abs_list, kInputsNum);
  constexpr size_t data_index = 0;
  constexpr size_t target_index = 1;
  auto res = CheckIfDataIsTarget("is not", args_abs_list[data_index], args_abs_list[target_index]);
  if (res.second) {
    return std::make_shared<AbstractScalar>(kValueAny, kBool);
  }
  return std::make_shared<AbstractScalar>(!res.first);
}

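// Returns true if the key (first argument) matches any key of the AbstractDictionary
// (second argument), comparing the statically built key values.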
bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_abs_list) {
  constexpr size_t kInputsNum = 2;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_abs_list, kInputsNum);
  const auto &key = args_abs_list[0];
  auto dict = CheckArg<AbstractDictionary>(op_name, args_abs_list, 1);

  ValuePtr key_value = key->BuildValue();
  MS_EXCEPTION_IF_NULL(key_value);
  std::vector<AbstractElementPair> dict_elems = dict->elements();
  auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
    return *key_value == *item.first->BuildValue();
  });
  return it != dict_elems.cend();
}

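// Inference for 'x in t' where t is a dictionary: folds the membership test into a
// constant bool scalar.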
AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_abs_list) {
  // Statement: x in t
  // Inputs: x, t
  return std::make_shared<AbstractScalar>(IsInDict(primitive, args_abs_list));
}

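// Inference for 'x not in t' where t is a dictionary: folds the negated membership test
// into a constant bool scalar.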
AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_abs_list) {
  // Statement: x not in t
  // Inputs: x, t
  return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_abs_list));
}

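// Inference for isconstant(x): true when the abstract value of x is fully known at
// compile time (contains no ValueAny), false otherwise.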
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
                                    const AbstractBasePtrList &args_abs_list) {
  // Statement: isconstant(x)
  // Inputs: x
  if (args_abs_list.size() != 1) {
    MS_LOG(EXCEPTION) << "IsConstant requires exactly 1 input, but got " << args_abs_list.size() << ".";
  }
  ValuePtr v = args_abs_list[0]->BuildValue();
  MS_EXCEPTION_IF_NULL(v);
  return std::make_shared<AbstractScalar>(!v->ContainsValueAny());
}
}  // namespace abstract
}  // namespace mindspore