• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "abstract/param_validator.h"
18 #include "abstract/infer_functions.h"
19 #include "abstract/abstract_function.h"
20 #include "abstract/utils.h"
21 #include "utils/symbolic.h"
22 
23 namespace mindspore {
24 namespace abstract {
InferImplReturn(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)25 AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
26                                 const AbstractBasePtrList &args_spec_list) {
27   // Inputs: a pointer to an AbstractBase object
28   if (args_spec_list.size() != 1) {
29     MS_LOG(INFO) << "Return evaluator requires 1 parameter, is this the default value attached? "
30                     "while the input size is "
31                  << args_spec_list.size() << ".";
32   }
33   AbstractBasePtr abs_base = args_spec_list[0];
34   return abs_base;
35 }
36 
InferImplSwitch(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)37 AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
38                                 const AbstractBasePtrList &args_spec_list) {
39   // Inputs: condition, true branch, false branch
40   if (args_spec_list.size() != 3) {
41     MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
42                       << ".";
43   }
44 
45   auto cond = args_spec_list[0];
46   auto tb = args_spec_list[1];
47   auto fb = args_spec_list[2];
48   MS_EXCEPTION_IF_NULL(cond);
49 
50   ValuePtr v = cond->GetValueTrack();
51   MS_EXCEPTION_IF_NULL(v);
52   // for tensor as condition, keeps both true and false branch.
53   if (v->isa<AnyValue>() || cond->isa<AbstractTensor>()) {
54     MS_EXCEPTION_IF_NULL(tb);
55     return tb->Join(fb);
56   }
57 
58   if (v->isa<Scalar>()) {
59     if (v->cast<ScalarPtr>()->IsOne()) {
60       return tb;
61     } else {
62       return fb;
63     }
64   }
65 
66   MS_LOG(EXCEPTION) << "Not support this condition value: " << cond->GetValueTrack()->ToString();
67 }
68 
InferImplSwitchLayer(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)69 AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
70                                      const AbstractBasePtrList &args_spec_list) {
71   // Inputs: {index, MakeTuple{branch1,branch2,branch3....}}
72   constexpr auto kSwitchLayerInputNum = 2;
73   const std::string op_name = primitive->name();
74   abstract::CheckArgsSize(op_name, args_spec_list, kSwitchLayerInputNum);
75   auto index = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
76   auto &input_shape = index->shape()->shape();
77   if (!input_shape.empty()) {
78     MS_EXCEPTION(ValueError) << op_name << " index must be a 0 dimension tensor, but got a " << input_shape.size()
79                              << " dimension tensor";
80   }
81   auto dtype = index->element()->BuildType();
82   if (dtype->type_id() != kInt32->type_id()) {
83     MS_EXCEPTION(ValueError) << op_name << " index must be a int32, but got " << dtype->ToString();
84   }
85 
86   AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
87   AbstractBasePtrList branches = branches_abs->elements();
88   const size_t maximum_layer_num = 1000;
89   if (branches.empty() || branches.size() > maximum_layer_num) {
90     MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got "
91                              << branches.size() << " branches.";
92   }
93 
94   for (size_t i = 0; i < branches.size(); i++) {
95     MS_EXCEPTION_IF_NULL(branches[i]);
96     if (!branches[i]->isa<FuncGraphAbstractClosure>() && !branches[i]->isa<PartialAbstractClosure>()) {
97       MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
98                                << branches[i]->ToString() << " as the " << i << "th element.";
99     }
100   }
101 
102   auto b = branches[0];
103   // Return AbstractFuncUnion, otherwise the switch_layer will be replaced by branches[0]
104   // which will cancel the out of bound checking for index
105   if (branches.size() == 1) {
106     AbstractFuncAtomPtrList func_list{b->cast<AbstractFuncAtomPtr>()};
107     return std::make_shared<AbstractFuncUnion>(func_list);
108   }
109   for (size_t i = 1; i < branches.size(); i++) {
110     b = b->Join(branches[i]);
111   }
112   return b;
113 }
114 
GetSupportedTargetValue()115 std::vector<ValuePtr> GetSupportedTargetValue() {
116   std::vector<ValuePtr> list = {kNone, MakeValue(false), MakeValue(true)};
117   return list;
118 }
119 
SupportedIsTargetValue(const ValuePtr t)120 bool SupportedIsTargetValue(const ValuePtr t) {
121   auto list = GetSupportedTargetValue();
122   auto match = std::any_of(list.begin(), list.end(), [&t](const ValuePtr &v) { return *v == *t; });
123   return match;
124 }
125 
InferImplIs_(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)126 AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
127                              const AbstractBasePtrList &args_spec_list) {
128   // statement: x is t
129   // Inputs: x, t
130   const std::string op_name = primitive->name();
131   CheckArgsSize(op_name, args_spec_list, 2);
132   ValuePtr t = args_spec_list[1]->BuildValue();
133   if (!SupportedIsTargetValue(t)) {
134     MS_LOG(EXCEPTION) << "This comparator '" << t->ToString()
135                       << "' is not supported. For statement 'is', only support compare with 'None', 'False' or 'True'";
136   }
137   ValuePtr x = args_spec_list[0]->BuildValue();
138 
139   return std::make_shared<AbstractScalar>(*t == *x);
140 }
141 
InferImplIsNot(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)142 AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
143                                const AbstractBasePtrList &args_spec_list) {
144   // statement: x is not t
145   // Inputs: x, t
146   const std::string op_name = primitive->name();
147   CheckArgsSize(op_name, args_spec_list, 2);
148   ValuePtr t = args_spec_list[1]->BuildValue();
149   if (!SupportedIsTargetValue(t)) {
150     MS_LOG(EXCEPTION)
151       << "This comparator '" << t->ToString()
152       << "' is not supported. For statement 'is not' , only support compare with 'None', 'False' or 'True'";
153   }
154   ValuePtr x = args_spec_list[0]->BuildValue();
155 
156   return std::make_shared<AbstractScalar>(!(*t == *x));
157 }
158 
IsInDict(const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)159 bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
160   const std::string op_name = primitive->name();
161   CheckArgsSize(op_name, args_spec_list, 2);
162   auto key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
163   auto dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 1);
164 
165   ValuePtr key_value = key->BuildValue();
166   if (!key_value->isa<StringImm>()) {
167     MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
168   }
169   auto key_str = GetValue<std::string>(key_value);
170   std::vector<AbstractAttribute> dict_elems = dict->elements();
171   auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
172                          [key_str](const AbstractAttribute &item) { return item.first == key_str; });
173   return it != dict_elems.end();
174 }
175 
InferImplInDict(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)176 AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
177                                 const AbstractBasePtrList &args_spec_list) {
178   // statement: x in t
179   // Inputs: x, t
180   return std::make_shared<AbstractScalar>(IsInDict(primitive, args_spec_list));
181 }
182 
InferImplNotInDict(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)183 AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
184                                    const AbstractBasePtrList &args_spec_list) {
185   // statement: x not in t
186   // Inputs: x, t
187   return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
188 }
189 
InferImplIsConstant(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)190 AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
191                                     const AbstractBasePtrList &args_spec_list) {
192   // statement: isconstant(x)
193   // Inputs: x
194   if (args_spec_list.size() != 1) {
195     MS_LOG(EXCEPTION) << "IsConstant requires args input size = 1";
196   }
197   ValuePtr v = args_spec_list[0]->BuildValue();
198   return std::make_shared<AbstractScalar>(!v->isa<AnyValue>());
199 }
200 }  // namespace abstract
201 }  // namespace mindspore
202