1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "abstract/infer_functions.h"
20 #include "abstract/utils.h"
21 #include "abstract/param_validator.h"
22
23 namespace mindspore {
24 namespace abstract {
InferImplMakeTuple(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)25 AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
26 const AbstractBasePtrList &args_spec_list) {
27 return std::make_shared<AbstractTuple>(args_spec_list);
28 }
29
InferImplMakeList(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)30 AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
31 const AbstractBasePtrList &args_spec_list) {
32 return std::make_shared<AbstractList>(args_spec_list);
33 }
34
InferImplMakeDict(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)35 AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
36 const AbstractBasePtrList &args_spec_list) {
37 // Inputs: two tuples.
38 const std::string op_name = primitive->name();
39 CheckArgsSize(op_name, args_spec_list, 2);
40 AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
41 AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
42
43 size_t keys_size = keys->size();
44 if (values->size() != keys_size) {
45 MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
46 }
47
48 std::vector<AbstractAttribute> key_value;
49 AbstractScalarPtr key;
50 AbstractBasePtrList key_list = keys->elements();
51 AbstractBasePtrList value_list = values->elements();
52 for (size_t index = 0; index < keys_size; index++) {
53 key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
54 ValuePtr keyPtr = key->BuildValue();
55 MS_EXCEPTION_IF_NULL(keyPtr);
56 if (!keyPtr->isa<StringImm>()) {
57 MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
58 }
59 auto key_string = GetValue<std::string>(keyPtr);
60 (void)key_value.emplace_back(key_string, value_list[index]);
61 }
62 return std::make_shared<AbstractDictionary>(key_value);
63 }
64
InferImplMakeKwarg(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)65 AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
66 const AbstractBasePtrList &args_spec_list) {
67 // Inputs: a string and an object of a subclass of AbstractBase.
68 const std::string op_name = primitive->name();
69 CheckArgsSize(op_name, args_spec_list, 2);
70 AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
71
72 ValuePtr keyPtr = key->BuildValue();
73 MS_EXCEPTION_IF_NULL(keyPtr);
74 if (!keyPtr->isa<StringImm>()) {
75 MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
76 }
77 auto key_string = GetValue<std::string>(keyPtr);
78 return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
79 }
80
InferImplExtractKwarg(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)81 AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
82 const AbstractBasePtrList &args_spec_list) {
83 // Inputs: a string and a keyword.
84 const std::string op_name = primitive->name();
85 CheckArgsSize(op_name, args_spec_list, 2);
86 AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
87 AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
88
89 ValuePtr key_value = key->BuildValue();
90 MS_EXCEPTION_IF_NULL(key_value);
91 if (!key_value->isa<StringImm>()) {
92 MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
93 }
94 auto key_input = GetValue<std::string>(key_value);
95 std::string key_actual = kwarg->get_key();
96 if (key_actual != key_input) {
97 MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
98 << key_input << ", AbstractKeywordArg' key is " << key_actual;
99 }
100 return kwarg->get_arg();
101 }
102
InferImplMakeSlice(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)103 AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
104 const AbstractBasePtrList &args_spec_list) {
105 // Inputs: three scalars whose value is an int32 number.
106 CheckArgsSize(primitive->name(), args_spec_list, 3);
107 size_t args_size = args_spec_list.size();
108 AbstractBasePtrList slice_args;
109 for (size_t index = 0; index < args_size; index++) {
110 MS_EXCEPTION_IF_NULL(args_spec_list[index]);
111 if (args_spec_list[index]->isa<AbstractNone>()) {
112 slice_args.push_back(args_spec_list[index]);
113 } else if (args_spec_list[index]->isa<AbstractScalar>()) {
114 ValuePtr scalar_value = args_spec_list[index]->cast<AbstractScalarPtr>()->BuildValue();
115 MS_EXCEPTION_IF_NULL(scalar_value);
116 if (scalar_value->isa<IntergerImm>()) {
117 slice_args.push_back(args_spec_list[index]);
118 } else if (scalar_value->isa<BoolImm>()) {
119 ValuePtr scalar_index = MakeValue(static_cast<int64_t>(scalar_value->cast<BoolImmPtr>()->value()));
120 slice_args.push_back(scalar_index->ToAbstract());
121 } else {
122 MS_EXCEPTION(TypeError) << "MakeSlice eval " << index
123 << " the input scalar type should be int or bool, but got " << scalar_value->ToString();
124 }
125 } else if (args_spec_list[index]->isa<AbstractTensor>()) {
126 auto arg = args_spec_list[index]->cast<AbstractTensorPtr>();
127 TypePtr tensor_dtype = arg->element()->BuildType();
128 auto build_value = arg->BuildValue();
129 MS_EXCEPTION_IF_NULL(build_value);
130 auto value = build_value->cast<tensor::TensorPtr>();
131 if (value == nullptr) {
132 MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor must be a const tensor.";
133 }
134 if (value->DataSize() != 1) {
135 MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor must contain only one element, but got "
136 << value->ToString() << " has " << value->DataSize() << " elements.";
137 }
138
139 if (tensor_dtype->isa<Bool>()) {
140 auto *bool_value = static_cast<bool *>(value->data_c());
141 slice_args.push_back(MakeValue((static_cast<int64_t>(*bool_value)))->ToAbstract());
142 } else if (tensor_dtype->isa<Int>()) {
143 auto *int_value = static_cast<int64_t *>(value->data_c());
144 slice_args.push_back(MakeValue((*int_value))->ToAbstract());
145 } else {
146 MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor type must be int or bool, but got "
147 << tensor_dtype->ToString();
148 }
149 } else {
150 MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " inputs should scalar, None or Tensor, but got"
151 << args_spec_list[index]->ToString();
152 }
153 }
154 // Slice: start, end, step
155 return std::make_shared<AbstractSlice>(slice_args[0], slice_args[1], slice_args[2]);
156 }
157
158 template <typename T>
InferTupleOrListGetItem(const std::string & op_name,const AbstractBasePtrList & args_spec_list)159 AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
160 // Inputs: a tuple or list and a scalar whose value is an int32 number.
161 CheckArgsSize(op_name, args_spec_list, 2);
162 auto queue = CheckArg<T>(op_name, args_spec_list, 0);
163 AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
164
165 ValuePtr index_value = index->BuildValue();
166 MS_EXCEPTION_IF_NULL(index_value);
167 if (!index_value->isa<Int64Imm>()) {
168 // when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element
169 // and continue
170 if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) {
171 return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType());
172 }
173 MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " << index->ToString();
174 }
175 auto idx_v = GetValue<int64_t>(index_value);
176 std::size_t nelems = queue->elements().size();
177 if (idx_v >= SizeToLong(nelems) || idx_v < -SizeToLong(nelems)) {
178 MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToLong(nelems) << ", "
179 << SizeToLong(nelems) << "), but got " << idx_v << ".";
180 }
181
182 std::size_t uidx_v = 0;
183 if (idx_v >= 0) {
184 uidx_v = LongToSize(idx_v);
185 } else {
186 uidx_v = LongToSize(idx_v + SizeToLong(nelems));
187 }
188 return queue->elements()[uidx_v];
189 }
190
191 template <typename T>
InferTupleOrListSetItem(const std::string & op_name,const AbstractBasePtrList & args_spec_list)192 AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
193 // Inputs: a tuple or list, a scalar whose value is an int64 number and an object of a subclass of AbstractBase.
194 CheckArgsSize(op_name, args_spec_list, 3);
195 auto queue = CheckArg<T>(op_name, args_spec_list, 0);
196 AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
197
198 ValuePtr index_value = index->BuildValue();
199 MS_EXCEPTION_IF_NULL(index_value);
200 if (!index_value->isa<Int64Imm>()) {
201 MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got "
202 << index_value->ToString();
203 }
204 auto idx_v = GetValue<int64_t>(index_value);
205 AbstractBasePtrList elements = queue->elements();
206 std::size_t nelems = elements.size();
207 int64_t idx_t = idx_v >= 0 ? idx_v : idx_v + SizeToLong(nelems);
208 if (idx_t < 0 || idx_t >= SizeToLong(nelems)) {
209 MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << idx_v << " to set out of range: [-" << nelems
210 << "," << (nelems - 1) << "].";
211 }
212 size_t uidx_v = LongToSize(idx_t);
213 elements[uidx_v] = args_spec_list[2];
214 return std::make_shared<T>(elements);
215 }
216
InferImplTupleGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)217 AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
218 const AbstractBasePtrList &args_spec_list) {
219 return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list);
220 }
221
InferImplListGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)222 AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
223 const AbstractBasePtrList &args_spec_list) {
224 return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list);
225 }
226
InferImplTupleSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)227 AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
228 const AbstractBasePtrList &args_spec_list) {
229 return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list);
230 }
231
InferImplListSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)232 AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
233 const AbstractBasePtrList &args_spec_list) {
234 return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list);
235 }
236
InferImplDictGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)237 AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
238 const AbstractBasePtrList &args_spec_list) {
239 // Inputs: a dict and a scalar whose value is a string.
240 const std::string op_name = primitive->name();
241 CheckArgsSize(op_name, args_spec_list, 2);
242 AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
243 AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
244
245 ValuePtr key_value = key->BuildValue();
246 MS_EXCEPTION_IF_NULL(key_value);
247 if (!key_value->isa<StringImm>()) {
248 MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
249 }
250 auto key_str = GetValue<std::string>(key_value);
251 std::vector<AbstractAttribute> dict_elems = dict->elements();
252 auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
253 [key_str](const AbstractAttribute &item) { return item.first == key_str; });
254 if (it == dict_elems.end()) {
255 MS_EXCEPTION(KeyError) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
256 }
257 return it->second;
258 }
259
InferImplDictSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)260 AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
261 const AbstractBasePtrList &args_spec_list) {
262 // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
263 const std::string op_name = primitive->name();
264 CheckArgsSize(op_name, args_spec_list, 3);
265 AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
266 AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
267
268 ValuePtr key_value = key->BuildValue();
269 MS_EXCEPTION_IF_NULL(key_value);
270 if (!key_value->isa<StringImm>()) {
271 MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
272 }
273 auto key_str = GetValue<std::string>(key_value);
274 std::vector<AbstractAttribute> dict_elems = dict->elements();
275 auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
276 [key_str](const AbstractAttribute &item) { return item.first == key_str; });
277
278 MS_EXCEPTION_IF_NULL(args_spec_list[2]);
279 auto new_ele = std::make_pair(key_str, args_spec_list[2]);
280 if (it != dict_elems.end()) {
281 int64_t index = it - dict_elems.begin();
282 dict_elems[LongToSize(index)] = new_ele;
283 } else {
284 dict_elems.push_back(new_ele);
285 }
286 return std::make_shared<AbstractDictionary>(dict_elems);
287 }
288
InferImplDictGetKeys(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)289 AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
290 const AbstractBasePtrList &args_spec_list) {
291 // Inputs: a dict.
292 const std::string op_name = primitive->name();
293 CheckArgsSize(op_name, args_spec_list, 1);
294 AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
295 std::vector<AbstractAttribute> dict_elems = dict->elements();
296 AbstractBasePtrList keys;
297 std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(keys),
298 [](const AbstractAttribute &item) { return std::make_shared<AbstractScalar>(item.first); });
299 return std::make_shared<AbstractTuple>(keys);
300 }
301
InferImplDictGetValues(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)302 AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
303 const AbstractBasePtrList &args_spec_list) {
304 // Inputs: a dict.
305 const std::string op_name = primitive->name();
306 CheckArgsSize(op_name, args_spec_list, 1);
307 AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
308 std::vector<AbstractAttribute> dict_elems = dict->elements();
309 AbstractBasePtrList values;
310 std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(values),
311 [](const AbstractAttribute &item) { return item.second; });
312 return std::make_shared<AbstractTuple>(values);
313 }
314
InferImplListAppend(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)315 AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
316 const AbstractBasePtrList &args_spec_list) {
317 // Inputs: a list and an object of a subclass of AbstractBase.
318 const std::string op_name = primitive->name();
319 CheckArgsSize(op_name, args_spec_list, 2);
320 AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
321 AbstractBasePtr item = dyn_cast<AbstractBase>(args_spec_list[1]);
322 MS_EXCEPTION_IF_NULL(item);
323 auto new_list = AbstractBasePtrList(list->elements());
324 new_list.emplace_back(item);
325 return std::make_shared<AbstractList>(new_list);
326 }
327
InferImplTupleLen(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)328 AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
329 const AbstractBasePtrList &args_spec_list) {
330 return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
331 }
332
InferImplListLen(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)333 AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
334 const AbstractBasePtrList &args_spec_list) {
335 return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
336 }
337
InferImplArrayLen(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)338 AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
339 const AbstractBasePtrList &args_spec_list) {
340 const std::string op_name = primitive->name();
341 CheckArgsSize(op_name, args_spec_list, 1);
342 auto arg_abs = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
343 auto shape = arg_abs->BuildShape()->cast<ShapePtr>();
344 MS_EXCEPTION_IF_NULL(shape);
345 if (shape->shape().empty()) {
346 MS_EXCEPTION(TypeError) << "Not support len of a 0-D tensor.";
347 }
348 return std::make_shared<AbstractScalar>(shape->shape()[0]);
349 }
350 } // namespace abstract
351 } // namespace mindspore
352