• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "abstract/infer_functions.h"
20 #include "abstract/utils.h"
21 #include "abstract/param_validator.h"
22 
23 namespace mindspore {
24 namespace abstract {
InferImplMakeTuple(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)25 AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
26                                    const AbstractBasePtrList &args_spec_list) {
27   return std::make_shared<AbstractTuple>(args_spec_list);
28 }
29 
InferImplMakeList(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)30 AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
31                                   const AbstractBasePtrList &args_spec_list) {
32   return std::make_shared<AbstractList>(args_spec_list);
33 }
34 
InferImplMakeDict(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)35 AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
36                                   const AbstractBasePtrList &args_spec_list) {
37   // Inputs: two tuples.
38   const std::string op_name = primitive->name();
39   CheckArgsSize(op_name, args_spec_list, 2);
40   AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
41   AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
42 
43   size_t keys_size = keys->size();
44   if (values->size() != keys_size) {
45     MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
46   }
47 
48   std::vector<AbstractAttribute> key_value;
49   AbstractScalarPtr key;
50   AbstractBasePtrList key_list = keys->elements();
51   AbstractBasePtrList value_list = values->elements();
52   for (size_t index = 0; index < keys_size; index++) {
53     key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
54     ValuePtr keyPtr = key->BuildValue();
55     MS_EXCEPTION_IF_NULL(keyPtr);
56     if (!keyPtr->isa<StringImm>()) {
57       MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
58     }
59     auto key_string = GetValue<std::string>(keyPtr);
60     (void)key_value.emplace_back(key_string, value_list[index]);
61   }
62   return std::make_shared<AbstractDictionary>(key_value);
63 }
64 
InferImplMakeKwarg(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)65 AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
66                                    const AbstractBasePtrList &args_spec_list) {
67   // Inputs: a string and an object of a subclass of AbstractBase.
68   const std::string op_name = primitive->name();
69   CheckArgsSize(op_name, args_spec_list, 2);
70   AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
71 
72   ValuePtr keyPtr = key->BuildValue();
73   MS_EXCEPTION_IF_NULL(keyPtr);
74   if (!keyPtr->isa<StringImm>()) {
75     MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
76   }
77   auto key_string = GetValue<std::string>(keyPtr);
78   return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
79 }
80 
InferImplExtractKwarg(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)81 AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
82                                       const AbstractBasePtrList &args_spec_list) {
83   // Inputs: a string and a keyword.
84   const std::string op_name = primitive->name();
85   CheckArgsSize(op_name, args_spec_list, 2);
86   AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
87   AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
88 
89   ValuePtr key_value = key->BuildValue();
90   MS_EXCEPTION_IF_NULL(key_value);
91   if (!key_value->isa<StringImm>()) {
92     MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
93   }
94   auto key_input = GetValue<std::string>(key_value);
95   std::string key_actual = kwarg->get_key();
96   if (key_actual != key_input) {
97     MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
98                       << key_input << ", AbstractKeywordArg' key is " << key_actual;
99   }
100   return kwarg->get_arg();
101 }
102 
InferImplMakeSlice(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)103 AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
104                                    const AbstractBasePtrList &args_spec_list) {
105   // Inputs: three scalars whose value is an int32 number.
106   CheckArgsSize(primitive->name(), args_spec_list, 3);
107   size_t args_size = args_spec_list.size();
108   AbstractBasePtrList slice_args;
109   for (size_t index = 0; index < args_size; index++) {
110     MS_EXCEPTION_IF_NULL(args_spec_list[index]);
111     if (args_spec_list[index]->isa<AbstractNone>()) {
112       slice_args.push_back(args_spec_list[index]);
113     } else if (args_spec_list[index]->isa<AbstractScalar>()) {
114       ValuePtr scalar_value = args_spec_list[index]->cast<AbstractScalarPtr>()->BuildValue();
115       MS_EXCEPTION_IF_NULL(scalar_value);
116       if (scalar_value->isa<IntergerImm>()) {
117         slice_args.push_back(args_spec_list[index]);
118       } else if (scalar_value->isa<BoolImm>()) {
119         ValuePtr scalar_index = MakeValue(static_cast<int64_t>(scalar_value->cast<BoolImmPtr>()->value()));
120         slice_args.push_back(scalar_index->ToAbstract());
121       } else {
122         MS_EXCEPTION(TypeError) << "MakeSlice eval " << index
123                                 << " the input scalar type should be int or bool, but got " << scalar_value->ToString();
124       }
125     } else if (args_spec_list[index]->isa<AbstractTensor>()) {
126       auto arg = args_spec_list[index]->cast<AbstractTensorPtr>();
127       TypePtr tensor_dtype = arg->element()->BuildType();
128       auto build_value = arg->BuildValue();
129       MS_EXCEPTION_IF_NULL(build_value);
130       auto value = build_value->cast<tensor::TensorPtr>();
131       if (value == nullptr) {
132         MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor must be a const tensor.";
133       }
134       if (value->DataSize() != 1) {
135         MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor must contain only one element, but got "
136                                 << value->ToString() << " has " << value->DataSize() << " elements.";
137       }
138 
139       if (tensor_dtype->isa<Bool>()) {
140         auto *bool_value = static_cast<bool *>(value->data_c());
141         slice_args.push_back(MakeValue((static_cast<int64_t>(*bool_value)))->ToAbstract());
142       } else if (tensor_dtype->isa<Int>()) {
143         auto *int_value = static_cast<int64_t *>(value->data_c());
144         slice_args.push_back(MakeValue((*int_value))->ToAbstract());
145       } else {
146         MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor type must be int or bool, but got "
147                                 << tensor_dtype->ToString();
148       }
149     } else {
150       MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " inputs should scalar, None or Tensor, but got"
151                               << args_spec_list[index]->ToString();
152     }
153   }
154   // Slice: start, end, step
155   return std::make_shared<AbstractSlice>(slice_args[0], slice_args[1], slice_args[2]);
156 }
157 
158 template <typename T>
InferTupleOrListGetItem(const std::string & op_name,const AbstractBasePtrList & args_spec_list)159 AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
160   // Inputs: a tuple or list and a scalar whose value is an int32 number.
161   CheckArgsSize(op_name, args_spec_list, 2);
162   auto queue = CheckArg<T>(op_name, args_spec_list, 0);
163   AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
164 
165   ValuePtr index_value = index->BuildValue();
166   MS_EXCEPTION_IF_NULL(index_value);
167   if (!index_value->isa<Int64Imm>()) {
168     // when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element
169     //  and continue
170     if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) {
171       return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType());
172     }
173     MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " << index->ToString();
174   }
175   auto idx_v = GetValue<int64_t>(index_value);
176   std::size_t nelems = queue->elements().size();
177   if (idx_v >= SizeToLong(nelems) || idx_v < -SizeToLong(nelems)) {
178     MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToLong(nelems) << ", "
179                              << SizeToLong(nelems) << "), but got " << idx_v << ".";
180   }
181 
182   std::size_t uidx_v = 0;
183   if (idx_v >= 0) {
184     uidx_v = LongToSize(idx_v);
185   } else {
186     uidx_v = LongToSize(idx_v + SizeToLong(nelems));
187   }
188   return queue->elements()[uidx_v];
189 }
190 
191 template <typename T>
InferTupleOrListSetItem(const std::string & op_name,const AbstractBasePtrList & args_spec_list)192 AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
193   // Inputs: a tuple or list, a scalar whose value is an int64 number and an object of a subclass of AbstractBase.
194   CheckArgsSize(op_name, args_spec_list, 3);
195   auto queue = CheckArg<T>(op_name, args_spec_list, 0);
196   AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
197 
198   ValuePtr index_value = index->BuildValue();
199   MS_EXCEPTION_IF_NULL(index_value);
200   if (!index_value->isa<Int64Imm>()) {
201     MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got "
202                              << index_value->ToString();
203   }
204   auto idx_v = GetValue<int64_t>(index_value);
205   AbstractBasePtrList elements = queue->elements();
206   std::size_t nelems = elements.size();
207   int64_t idx_t = idx_v >= 0 ? idx_v : idx_v + SizeToLong(nelems);
208   if (idx_t < 0 || idx_t >= SizeToLong(nelems)) {
209     MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << idx_v << " to set out of range: [-" << nelems
210                              << "," << (nelems - 1) << "].";
211   }
212   size_t uidx_v = LongToSize(idx_t);
213   elements[uidx_v] = args_spec_list[2];
214   return std::make_shared<T>(elements);
215 }
216 
InferImplTupleGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)217 AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
218                                       const AbstractBasePtrList &args_spec_list) {
219   return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list);
220 }
221 
InferImplListGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)222 AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
223                                      const AbstractBasePtrList &args_spec_list) {
224   return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list);
225 }
226 
InferImplTupleSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)227 AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
228                                       const AbstractBasePtrList &args_spec_list) {
229   return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list);
230 }
231 
InferImplListSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)232 AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
233                                      const AbstractBasePtrList &args_spec_list) {
234   return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list);
235 }
236 
InferImplDictGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)237 AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
238                                      const AbstractBasePtrList &args_spec_list) {
239   // Inputs: a dict and a scalar whose value is a string.
240   const std::string op_name = primitive->name();
241   CheckArgsSize(op_name, args_spec_list, 2);
242   AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
243   AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
244 
245   ValuePtr key_value = key->BuildValue();
246   MS_EXCEPTION_IF_NULL(key_value);
247   if (!key_value->isa<StringImm>()) {
248     MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
249   }
250   auto key_str = GetValue<std::string>(key_value);
251   std::vector<AbstractAttribute> dict_elems = dict->elements();
252   auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
253                          [key_str](const AbstractAttribute &item) { return item.first == key_str; });
254   if (it == dict_elems.end()) {
255     MS_EXCEPTION(KeyError) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
256   }
257   return it->second;
258 }
259 
InferImplDictSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)260 AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
261                                      const AbstractBasePtrList &args_spec_list) {
262   // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
263   const std::string op_name = primitive->name();
264   CheckArgsSize(op_name, args_spec_list, 3);
265   AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
266   AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
267 
268   ValuePtr key_value = key->BuildValue();
269   MS_EXCEPTION_IF_NULL(key_value);
270   if (!key_value->isa<StringImm>()) {
271     MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
272   }
273   auto key_str = GetValue<std::string>(key_value);
274   std::vector<AbstractAttribute> dict_elems = dict->elements();
275   auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
276                          [key_str](const AbstractAttribute &item) { return item.first == key_str; });
277 
278   MS_EXCEPTION_IF_NULL(args_spec_list[2]);
279   auto new_ele = std::make_pair(key_str, args_spec_list[2]);
280   if (it != dict_elems.end()) {
281     int64_t index = it - dict_elems.begin();
282     dict_elems[LongToSize(index)] = new_ele;
283   } else {
284     dict_elems.push_back(new_ele);
285   }
286   return std::make_shared<AbstractDictionary>(dict_elems);
287 }
288 
InferImplDictGetKeys(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)289 AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
290                                      const AbstractBasePtrList &args_spec_list) {
291   // Inputs: a dict.
292   const std::string op_name = primitive->name();
293   CheckArgsSize(op_name, args_spec_list, 1);
294   AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
295   std::vector<AbstractAttribute> dict_elems = dict->elements();
296   AbstractBasePtrList keys;
297   std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(keys),
298                  [](const AbstractAttribute &item) { return std::make_shared<AbstractScalar>(item.first); });
299   return std::make_shared<AbstractTuple>(keys);
300 }
301 
InferImplDictGetValues(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)302 AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
303                                        const AbstractBasePtrList &args_spec_list) {
304   // Inputs: a dict.
305   const std::string op_name = primitive->name();
306   CheckArgsSize(op_name, args_spec_list, 1);
307   AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
308   std::vector<AbstractAttribute> dict_elems = dict->elements();
309   AbstractBasePtrList values;
310   std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(values),
311                  [](const AbstractAttribute &item) { return item.second; });
312   return std::make_shared<AbstractTuple>(values);
313 }
314 
InferImplListAppend(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)315 AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
316                                     const AbstractBasePtrList &args_spec_list) {
317   // Inputs: a list and an object of a subclass of AbstractBase.
318   const std::string op_name = primitive->name();
319   CheckArgsSize(op_name, args_spec_list, 2);
320   AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
321   AbstractBasePtr item = dyn_cast<AbstractBase>(args_spec_list[1]);
322   MS_EXCEPTION_IF_NULL(item);
323   auto new_list = AbstractBasePtrList(list->elements());
324   new_list.emplace_back(item);
325   return std::make_shared<AbstractList>(new_list);
326 }
327 
InferImplTupleLen(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)328 AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
329                                   const AbstractBasePtrList &args_spec_list) {
330   return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
331 }
332 
InferImplListLen(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)333 AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
334                                  const AbstractBasePtrList &args_spec_list) {
335   return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
336 }
337 
InferImplArrayLen(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)338 AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
339                                   const AbstractBasePtrList &args_spec_list) {
340   const std::string op_name = primitive->name();
341   CheckArgsSize(op_name, args_spec_list, 1);
342   auto arg_abs = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
343   auto shape = arg_abs->BuildShape()->cast<ShapePtr>();
344   MS_EXCEPTION_IF_NULL(shape);
345   if (shape->shape().empty()) {
346     MS_EXCEPTION(TypeError) << "Not support len of a 0-D tensor.";
347   }
348   return std::make_shared<AbstractScalar>(shape->shape()[0]);
349 }
350 }  // namespace abstract
351 }  // namespace mindspore
352