1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2024 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "abstract/ops/infer_functions.h"
20 #include "abstract/utils.h"
21 #include "abstract/param_validator.h"
22 #include "utils/check_convert_utils.h"
23 #include "include/common/utils/utils.h"
24
25 namespace mindspore {
26 namespace abstract {
27 namespace {
CheckDictKey(const AbstractBasePtr & key,const std::string & op_name)28 void CheckDictKey(const AbstractBasePtr &key, const std::string &op_name) {
29 auto key_value = key->BuildValue();
30 MS_EXCEPTION_IF_NULL(key_value);
31 if (!(key_value->isa<StringImm>() || key_value->isa<Scalar>() || key_value->isa<Type>() || key_value->isa<None>() ||
32 (key->isa<AbstractTensor>() && !key_value->ContainsValueAny()) || key->isa<AbstractTuple>())) {
33 MS_LOG(EXCEPTION) << op_name << " evaluator key only supports string, number, type, none, "
34 << "constant tensor and tuple, but got " << key->BuildValue()->ToString();
35 }
36 if (key->isa<AbstractTuple>() && key_value->isa<ValueAny>()) {
37 MS_LOG(EXCEPTION) << op_name << " evaluator key should not be tuple that contains variables, but got "
38 << key->BuildValue()->ToString();
39 }
40 }
41 } // namespace
42
ProcessUnpackDict(const AbstractTuplePtr & key_tuple,const AbstractTuplePtr & value_tuple,std::unordered_map<std::string,AbstractBasePtr> * key_str_value_set,std::vector<AbstractBasePtr> * key_set)43 void ProcessUnpackDict(const AbstractTuplePtr &key_tuple, const AbstractTuplePtr &value_tuple,
44 std::unordered_map<std::string, AbstractBasePtr> *key_str_value_set,
45 std::vector<AbstractBasePtr> *key_set) {
46 // The size of need unpack tuple must be 1
47 const auto &key_elements = key_tuple->elements();
48 const auto &value_elements = value_tuple->elements();
49 if (key_elements.size() != 1) {
50 MS_LOG(EXCEPTION) << "The size of need unpack key tuple must be 1, but got " << key_elements.size();
51 }
52 if (value_elements.size() != 1) {
53 MS_LOG(EXCEPTION) << "The size of need unpack value tuple must be 1, but got " << value_elements.size();
54 }
55
56 auto unpack_keys = key_elements[0];
57 auto unpack_keys_tuple = unpack_keys->cast<AbstractTuplePtr>();
58 const auto &unpack_keys_elements = unpack_keys_tuple->elements();
59
60 auto unpack_values = value_elements[0];
61 auto unpack_values_tuple = unpack_values->cast<AbstractTuplePtr>();
62 const auto &unpack_values_elements = unpack_values_tuple->elements();
63
64 if (unpack_keys_elements.size() != unpack_values_elements.size()) {
65 MS_LOG(EXCEPTION) << "The keys' size should be equal to values' size, but the keys' size is "
66 << unpack_keys_elements.size() << ", the values' size is " << unpack_values_elements.size();
67 }
68
69 for (size_t inner_index = 0; inner_index < unpack_keys_elements.size(); ++inner_index) {
70 auto inner_key = unpack_keys_elements[inner_index];
71 auto key_str = inner_key->BuildValue()->ToString();
72 (void)key_str_value_set->emplace(key_str, unpack_values_elements[inner_index]);
73 (void)key_set->emplace_back(inner_key);
74 }
75 }
76
InferImplMakeDict(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)77 AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
78 const AbstractBasePtrList &args_abs_list) {
79 // Inputs: two tuples.
80 const std::string op_name = primitive->name();
81 constexpr int args_spec_size = 2;
82 CheckArgsSize(op_name, args_abs_list, args_spec_size);
83 AbstractSequencePtr keys = CheckArg<AbstractSequence>(op_name, args_abs_list, 0);
84 AbstractSequencePtr values = CheckArg<AbstractSequence>(op_name, args_abs_list, 1);
85
86 size_t keys_size = keys->size();
87 if (values->size() != keys_size) {
88 MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
89 }
90
91 AbstractBasePtrList key_list = keys->elements();
92 std::unordered_map<std::string, AbstractBasePtr> key_str_value_set;
93 std::vector<AbstractBasePtr> key_set;
94 std::vector<AbstractElementPair> key_value;
95 AbstractBasePtrList value_list = values->elements();
96 constexpr auto need_unpack = "need_unpack";
97 for (size_t index = 0; index < keys_size; index++) {
98 const auto &key = key_list[index];
99 bool is_need_unpack = false;
100 if (key->isa<AbstractTuple>()) {
101 auto key_tuple = key->cast<AbstractTuplePtr>();
102 if (key_tuple->HasData(need_unpack)) {
103 is_need_unpack = *key_tuple->GetData<bool>(need_unpack);
104 if (is_need_unpack) {
105 auto value_tuple = value_list[index]->cast<AbstractTuplePtr>();
106 MS_EXCEPTION_IF_NULL(value_tuple);
107 ProcessUnpackDict(key_tuple, value_tuple, &key_str_value_set, &key_set);
108 }
109 }
110 }
111 CheckDictKey(key, op_name);
112 auto key_val = key->BuildValue()->ToString();
113 auto iter = key_str_value_set.find(key_val);
114 // Remove duplicate keys.
115 // {Tensor[1]: x, Tensor[1}: y} the length of dict is 2, means the two keys are not duplicate.
116 if (iter != key_str_value_set.end() && !key->isa<AbstractTensor>()) {
117 iter->second = value_list[index];
118 } else if (!is_need_unpack) {
119 auto key_str = key->BuildValue()->ToString();
120 key_str_value_set.insert(std::pair<std::string, AbstractBasePtr>(key_str, value_list[index]));
121 key_set.push_back(key);
122 }
123 }
124 for (auto &key : key_set) {
125 auto key_str = key->BuildValue()->ToString();
126 (void)key_value.emplace_back(key, key_str_value_set[key_str]);
127 }
128 return std::make_shared<AbstractDictionary>(key_value);
129 }
130
InferImplMakeKeywordArg(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)131 AbstractBasePtr InferImplMakeKeywordArg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
132 const AbstractBasePtrList &args_abs_list) {
133 // Inputs: a string and an object of a subclass of AbstractBase.
134 const std::string op_name = primitive->name();
135 constexpr int args_spec_size = 2;
136 CheckArgsSize(op_name, args_abs_list, args_spec_size);
137 AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_abs_list, 0);
138
139 ValuePtr keyPtr = key->BuildValue();
140 MS_EXCEPTION_IF_NULL(keyPtr);
141 if (!keyPtr->isa<StringImm>()) {
142 MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
143 }
144 auto key_string = GetValue<std::string>(keyPtr);
145 return std::make_shared<AbstractKeywordArg>(key_string, args_abs_list[1]);
146 }
147
InferImplExtractKeywordArg(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)148 AbstractBasePtr InferImplExtractKeywordArg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
149 const AbstractBasePtrList &args_abs_list) {
150 // Inputs: a key and a Keyword or only a Keyword.
151 const std::string op_name = primitive->name();
152 constexpr int only_kw_input_size = 1;
153 constexpr int check_key_input_size = 2;
154 AbstractKeywordArgPtr kwarg = nullptr;
155 if (args_abs_list.size() == check_key_input_size) {
156 AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_abs_list, 0);
157 kwarg = CheckArg<AbstractKeywordArg>(op_name, args_abs_list, 1);
158
159 ValuePtr key_value = key->BuildValue();
160 MS_EXCEPTION_IF_NULL(key_value);
161 if (!key_value->isa<StringImm>()) {
162 MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
163 }
164 auto key_input = GetValue<std::string>(key_value);
165 std::string key_actual = kwarg->get_key();
166 if (key_actual != key_input) {
167 MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
168 << key_input << ", AbstractKeywordArg' key is " << key_actual;
169 }
170 } else if (args_abs_list.size() == only_kw_input_size) {
171 kwarg = CheckArg<AbstractKeywordArg>(op_name, args_abs_list, 0);
172 } else {
173 MS_LOG(EXCEPTION) << "For '" << op_name << "', the number of inputs should be 1 or 2, but got "
174 << args_abs_list.size();
175 }
176 return kwarg->get_arg();
177 }
178
CheckDynamicLengthSequenceSetItem(const std::string & op_name,const AbstractSequencePtr & queue,const AbstractBasePtr & target)179 void CheckDynamicLengthSequenceSetItem(const std::string &op_name, const AbstractSequencePtr &queue,
180 const AbstractBasePtr &target) {
181 auto element_abs = queue->dynamic_len_element_abs();
182 if (element_abs == nullptr) {
183 MS_LOG(EXCEPTION) << "Empty variable len sequence can not setitem.";
184 }
185 const auto precondition_log = "For " + op_name + ", when the queue is dynamic length";
186 const auto standard_abs_description = "element within dynamic length sequence";
187 const auto differ_abs_description = "target element";
188 CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(std::vector<AbstractBasePtr>{element_abs, target},
189 precondition_log, standard_abs_description,
190 differ_abs_description);
191 }
192
193 template <typename T>
InferTupleOrListSetItem(const std::string & op_name,const AbstractBasePtrList & args_abs_list)194 AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_abs_list) {
195 // Inputs: a tuple or list, a scalar whose value is an int64 number and an object of a subclass of AbstractBase.
196 constexpr int args_spec_size = 3;
197 CheckArgsSize(op_name, args_abs_list, args_spec_size);
198 auto queue = CheckArg<T>(op_name, args_abs_list, 0);
199 AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_abs_list, 1);
200
201 auto index_type = index->BuildType();
202 MS_EXCEPTION_IF_NULL(index_type);
203 if (index_type->type_id() != kInt64->type_id()) {
204 MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got a "
205 << index_type->ToString() << " number.";
206 }
207 ValuePtr index_value = index->BuildValue();
208 MS_EXCEPTION_IF_NULL(index_value);
209 auto target = args_abs_list[kIndex2];
210 MS_EXCEPTION_IF_NULL(target);
211 if (queue->dynamic_len()) {
212 CheckDynamicLengthSequenceSetItem(op_name, queue, target);
213 return queue->Clone();
214 }
215 if (index_value->ContainsValueAny()) {
216 // If the index is variable and the sequence is constant length, then all of the element within the sequence
217 // should have the same type and shape with the target input. The element within the return sequence should
218 // be all broadened.
219 const auto &elements = queue->elements();
220 if (elements.size() == 0) {
221 MS_LOG(EXCEPTION) << "Empty sequence can not setitem.";
222 }
223 const auto precondition_log = "For " + op_name + ", when the index is variable and the queue is constant length";
224 CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(elements, precondition_log);
225 auto first_element = elements[kIndex0];
226 const auto standard_abs_description = "element within constant length sequence";
227 const auto differ_abs_description = "target element";
228 CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(std::vector<AbstractBasePtr>{first_element, target},
229 precondition_log, standard_abs_description,
230 differ_abs_description);
231 return CheckAndConvertUtils::BroadenAllSequenceElements(queue);
232 }
233 auto index_int64_value = GetValue<int64_t>(index_value);
234 AbstractBasePtrList elements = queue->elements();
235 std::size_t nelems = elements.size();
236 if (nelems == 0) {
237 MS_EXCEPTION(IndexError) << "Can not setitem for an empty sequence.";
238 }
239 int64_t index_positive_value = index_int64_value >= 0 ? index_int64_value : index_int64_value + SizeToLong(nelems);
240 if (index_positive_value < 0 || index_positive_value >= SizeToLong(nelems)) {
241 MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << index_int64_value << " to set out of range: [-"
242 << nelems << "," << (nelems - 1) << "].";
243 }
244 size_t index_unsigned_value = LongToSize(index_positive_value);
245 elements[index_unsigned_value] = args_abs_list[kIndex2];
246 MS_LOG(DEBUG) << "SetItem use flags, index: " << index_unsigned_value << ", for " << queue->ToString();
247 return std::make_shared<T>(elements, queue->sequence_nodes());
248 }
249
InferImplTupleSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)250 AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
251 const AbstractBasePtrList &args_abs_list) {
252 return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_abs_list);
253 }
254
InferImplListSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)255 AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
256 const AbstractBasePtrList &args_abs_list) {
257 return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_abs_list);
258 }
259
InferImplDictGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)260 AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
261 const AbstractBasePtrList &args_abs_list) {
262 const std::string op_name = primitive->name();
263 // dict[key] mean the size of args_abs_list is 2.
264 // dict.get('key', default_value=None) mean the size of args_abs_list is 2 too, the key will check in dict_get.
265 constexpr int subscript_args_size = 2;
266 if (args_abs_list.size() != subscript_args_size) {
267 MS_LOG(EXCEPTION) << "For '" << op_name << "', the number of input should be " << subscript_args_size
268 << ", but got " << args_abs_list.size();
269 }
270 AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_abs_list, 0);
271 const auto &key = args_abs_list[1];
272 CheckDictKey(key, op_name);
273
274 ValuePtr key_value = key->BuildValue();
275 MS_EXCEPTION_IF_NULL(key_value);
276 std::vector<AbstractElementPair> dict_elems = dict->elements();
277 auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
278 return *key_value == *item.first->BuildValue();
279 });
280 if (it == dict_elems.end()) {
281 // For dict[key], if key is not exist, will raise a ValueError exception.
282 // For dict.get('key', default=None), if key is not exist, will return the default value during dict_get.
283 // Python KeyError will print escape character. So use ValueError instead of KeyError here.
284 MS_EXCEPTION(ValueError) << "The key " << key_value->ToString()
285 << " does not exist in the dict:" << args_abs_list[0]->BuildValue()->ToString();
286 }
287 return it->second;
288 }
289
InferImplDictSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)290 AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
291 const AbstractBasePtrList &args_abs_list) {
292 // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
293 const std::string op_name = primitive->name();
294 constexpr int args_spec_size = 3;
295 CheckArgsSize(op_name, args_abs_list, args_spec_size);
296 AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_abs_list, 0);
297 const auto &key = args_abs_list[1];
298 CheckDictKey(key, op_name);
299
300 ValuePtr key_value = key->BuildValue();
301 MS_EXCEPTION_IF_NULL(key_value);
302 std::vector<AbstractElementPair> dict_elems = dict->elements();
303 auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
304 return *key_value == *item.first->BuildValue();
305 });
306
307 MS_EXCEPTION_IF_NULL(args_abs_list[2]);
308 auto new_ele = std::make_pair(args_abs_list[1], args_abs_list[2]);
309 if (it != dict_elems.end()) {
310 int64_t index = it - dict_elems.begin();
311 dict_elems[LongToSize(index)] = new_ele;
312 } else {
313 dict_elems.push_back(new_ele);
314 }
315 return std::make_shared<AbstractDictionary>(dict_elems);
316 }
317
InferImplDictGetKeys(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)318 AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
319 const AbstractBasePtrList &args_abs_list) {
320 // Inputs: a dict.
321 const std::string op_name = primitive->name();
322 constexpr int args_spec_size = 1;
323 CheckArgsSize(op_name, args_abs_list, args_spec_size);
324 AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_abs_list, 0);
325 std::vector<AbstractElementPair> dict_elems = dict->elements();
326 AbstractBasePtrList keys;
327 (void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(keys),
328 [](const AbstractElementPair &item) { return item.first; });
329 return std::make_shared<AbstractTuple>(keys);
330 }
331
InferImplDictGetValues(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)332 AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
333 const AbstractBasePtrList &args_abs_list) {
334 // Inputs: a dict.
335 const std::string op_name = primitive->name();
336 constexpr int args_spec_size = 1;
337 CheckArgsSize(op_name, args_abs_list, args_spec_size);
338 AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_abs_list, 0);
339 std::vector<AbstractElementPair> dict_elems = dict->elements();
340 AbstractBasePtrList values;
341 (void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(values),
342 [](const AbstractElementPair &item) { return item.second; });
343 return std::make_shared<AbstractTuple>(values);
344 }
345
InferImplDictItems(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)346 AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
347 const AbstractBasePtrList &args_abs_list) {
348 // Inputs: a dict.
349 const std::string op_name = primitive->name();
350 constexpr int args_spec_size = 1;
351 CheckArgsSize(op_name, args_abs_list, args_spec_size);
352 AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_abs_list, 0);
353 std::vector<AbstractElementPair> dict_elems = dict->elements();
354 AbstractBasePtrList items;
355 (void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(items),
356 [](const AbstractElementPair &item) {
357 return std::make_shared<AbstractTuple>(AbstractBasePtrList{item.first, item.second});
358 });
359 return std::make_shared<AbstractList>(items);
360 }
361
362 namespace {
CheckMutableArgAbstract(const AbstractBasePtr & abs)363 void CheckMutableArgAbstract(const AbstractBasePtr &abs) {
364 if (abs->isa<AbstractSequence>()) {
365 auto abs_seq = abs->cast_ptr<AbstractSequence>();
366 for (const auto &ele : abs_seq->elements()) {
367 CheckMutableArgAbstract(ele);
368 }
369 return;
370 }
371 if (abs->isa<AbstractDictionary>()) {
372 auto abs_dic = abs->cast_ptr<AbstractDictionary>();
373 for (const auto &ele : abs_dic->elements()) {
374 CheckMutableArgAbstract(ele.second);
375 }
376 return;
377 }
378 if (abs->isa<AbstractTensor>()) {
379 return;
380 }
381 if (abs->isa<AbstractScalar>()) {
382 auto type_ptr = abs->GetType();
383 if (type_ptr->isa<Number>()) {
384 return;
385 }
386 }
387 MS_EXCEPTION(TypeError) << "For 'mutable', the 'input_data' should be one of (bool, int, float, Tensor, "
388 "tuple, list, dict) or their nested structures, but got "
389 << abs->ToString();
390 }
391 } // namespace
392
InferImplMutable(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_abs_list)393 AbstractBasePtr InferImplMutable(const AnalysisEnginePtr &, const PrimitivePtr &,
394 const AbstractBasePtrList &args_abs_list) {
395 constexpr int min_args_abs_size = 1;
396 constexpr int max_args_abs_size = 2;
397 auto arg_size = args_abs_list.size();
398 if (arg_size != min_args_abs_size && arg_size != max_args_abs_size) {
399 MS_LOG(EXCEPTION) << "For 'mutable', the number of inputs should be 1 or 2, but got " << args_abs_list.size();
400 }
401 bool variable_len = false;
402 if (arg_size == max_args_abs_size) {
403 auto arg_value = args_abs_list[1]->GetValue();
404 MS_EXCEPTION_IF_NULL(arg_value);
405 if (!arg_value->isa<BoolImm>()) {
406 MS_EXCEPTION(TypeError) << "For 'mutable', the second input should be bool, but got: "
407 << args_abs_list[1]->ToString();
408 }
409 variable_len = arg_value->cast<BoolImmPtr>()->value();
410 }
411 auto data = args_abs_list[0];
412 MS_EXCEPTION_IF_NULL(data);
413 if (!variable_len) {
414 if (data->isa<AbstractSequence>() && data->cast<AbstractSequencePtr>()->dynamic_len()) {
415 MS_LOG(EXCEPTION) << "For 'mutable', can not convert a dynamic length sequence to constant length.";
416 }
417 CheckMutableArgAbstract(data);
418 return AbstractBroaden(data);
419 }
420 auto ret = data->Clone();
421 if (ret->isa<AbstractAny>()) {
422 return ret;
423 }
424 if (!ret->isa<AbstractSequence>()) {
425 MS_EXCEPTION(TypeError) << "For 'mutable', when the variable_len is True, the first input should be"
426 << " list or tuple, but got: " << ret->ToString();
427 }
428 auto ret_seq = ret->cast<AbstractSequencePtr>();
429 if (!ret_seq->dynamic_len()) {
430 ret_seq->CheckAndConvertToDynamicLenSequence();
431 }
432 if (ret->isa<AbstractList>()) {
433 // Dynamic length list should not attach python object.
434 auto ret_list = ret->cast<AbstractListPtr>();
435 ret_list->ClearExtraInfo();
436 }
437 return ret;
438 }
439
440 namespace {
GetRefKey(const AbstractRefPtr & ref_tensor)441 std::string GetRefKey(const AbstractRefPtr &ref_tensor) {
442 const auto &ref_key_value = ref_tensor->ref_key_value();
443 MS_EXCEPTION_IF_NULL(ref_key_value);
444 auto ref_key = ref_key_value->cast_ptr<RefKey>();
445 MS_EXCEPTION_IF_NULL(ref_key);
446 return ref_key->value();
447 }
448
GetGradAbstract(const AbstractBasePtr & grads_abs,const std::string & para_name,int64_t position,AbstractBasePtr * ret)449 void GetGradAbstract(const AbstractBasePtr &grads_abs, const std::string ¶_name, int64_t position,
450 AbstractBasePtr *ret) {
451 auto grad_abs_tuple = grads_abs->cast_ptr<AbstractTuple>();
452 if (grad_abs_tuple == nullptr || grad_abs_tuple->elements().size() == 0) {
453 return;
454 }
455 auto abs0 = grad_abs_tuple->elements()[0];
456 if (abs0->isa<AbstractScalar>()) {
457 auto buildptr = abs0->cast_ptr<AbstractScalar>();
458 MS_EXCEPTION_IF_NULL(buildptr);
459 auto build_value = buildptr->BuildValue();
460 size_t expect_size = 2;
461 if (grad_abs_tuple->elements().size() >= expect_size) {
462 if (build_value->isa<Int64Imm>()) {
463 if (GetValue<int64_t>(build_value) == position) {
464 *ret = grad_abs_tuple->elements()[1];
465 }
466 } else if (build_value->isa<StringImm>()) {
467 if (GetValue<std::string>(build_value) == para_name) {
468 *ret = grad_abs_tuple->elements()[1];
469 }
470 }
471 }
472 return;
473 } else {
474 for (const auto &abs : grad_abs_tuple->elements()) {
475 GetGradAbstract(abs, para_name, position, ret);
476 }
477 return;
478 }
479 }
480 } // namespace
481
InferImplGetGrad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)482 AbstractBasePtr InferImplGetGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
483 const AbstractBasePtrList &args_abs_list) {
484 const std::string &op_name = primitive->name();
485 constexpr int expected_args_spec_size = 2;
486 CheckArgsSize(op_name, args_abs_list, expected_args_spec_size);
487 auto &hash_id_abs = args_abs_list[1];
488
489 int64_t position = -1;
490 std::string para_name;
491 if (hash_id_abs->isa<AbstractScalar>()) {
492 auto buildptr = hash_id_abs->cast_ptr<AbstractScalar>();
493 if (buildptr == nullptr) {
494 MS_EXCEPTION(TypeError) << "For " << op_name << ", the `x` should be an integer or a Parameter, but got nullptr";
495 }
496 auto build_value = buildptr->BuildValue();
497 if (!build_value->isa<Int64Imm>()) {
498 MS_EXCEPTION(TypeError) << "For " << op_name << ", the `x` should be an int64 number, but got "
499 << build_value->ToString();
500 }
501 position = GetValue<int64_t>(build_value);
502 } else if (hash_id_abs->isa<AbstractRefTensor>()) {
503 para_name = GetRefKey(hash_id_abs->cast<AbstractRefPtr>());
504 } else {
505 MS_EXCEPTION(TypeError) << "For " << op_name << ", the `x` should be an integer or a Parameter, but got "
506 << hash_id_abs->ToString();
507 }
508 AbstractBasePtr ret = nullptr;
509 GetGradAbstract(args_abs_list[0], para_name, position, &ret);
510 if (ret == nullptr) {
511 MS_LOG(EXCEPTION) << "Can not find the gradient for position or Parameter " << args_abs_list[1]->ToString();
512 }
513 return ret;
514 }
515 } // namespace abstract
516 } // namespace mindspore
517