/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 
19 #include "abstract/ops/infer_functions.h"
20 #include "abstract/utils.h"
21 #include "abstract/param_validator.h"
22 #include "utils/check_convert_utils.h"
23 #include "include/common/utils/utils.h"
24 
25 namespace mindspore {
26 namespace abstract {
27 namespace {
CheckDictKey(const AbstractBasePtr & key,const std::string & op_name)28 void CheckDictKey(const AbstractBasePtr &key, const std::string &op_name) {
29   auto key_value = key->BuildValue();
30   MS_EXCEPTION_IF_NULL(key_value);
31   if (!(key_value->isa<StringImm>() || key_value->isa<Scalar>() || key_value->isa<Type>() || key_value->isa<None>() ||
32         (key->isa<AbstractTensor>() && !key_value->ContainsValueAny()) || key->isa<AbstractTuple>())) {
33     MS_LOG(EXCEPTION) << op_name << " evaluator key only supports string, number, type, none, "
34                       << "constant tensor and tuple, but got " << key->BuildValue()->ToString();
35   }
36   if (key->isa<AbstractTuple>() && key_value->isa<ValueAny>()) {
37     MS_LOG(EXCEPTION) << op_name << " evaluator key should not be tuple that contains variables, but got "
38                       << key->BuildValue()->ToString();
39   }
40 }
41 }  // namespace
42 
ProcessUnpackDict(const AbstractTuplePtr & key_tuple,const AbstractTuplePtr & value_tuple,std::unordered_map<std::string,AbstractBasePtr> * key_str_value_set,std::vector<AbstractBasePtr> * key_set)43 void ProcessUnpackDict(const AbstractTuplePtr &key_tuple, const AbstractTuplePtr &value_tuple,
44                        std::unordered_map<std::string, AbstractBasePtr> *key_str_value_set,
45                        std::vector<AbstractBasePtr> *key_set) {
46   // The size of need unpack tuple must be 1
47   const auto &key_elements = key_tuple->elements();
48   const auto &value_elements = value_tuple->elements();
49   if (key_elements.size() != 1) {
50     MS_LOG(EXCEPTION) << "The size of need unpack key tuple must be 1, but got " << key_elements.size();
51   }
52   if (value_elements.size() != 1) {
53     MS_LOG(EXCEPTION) << "The size of need unpack value tuple must be 1, but got " << value_elements.size();
54   }
55 
56   auto unpack_keys = key_elements[0];
57   auto unpack_keys_tuple = unpack_keys->cast<AbstractTuplePtr>();
58   const auto &unpack_keys_elements = unpack_keys_tuple->elements();
59 
60   auto unpack_values = value_elements[0];
61   auto unpack_values_tuple = unpack_values->cast<AbstractTuplePtr>();
62   const auto &unpack_values_elements = unpack_values_tuple->elements();
63 
64   if (unpack_keys_elements.size() != unpack_values_elements.size()) {
65     MS_LOG(EXCEPTION) << "The keys' size should be equal to values' size, but the keys' size is "
66                       << unpack_keys_elements.size() << ", the values' size is " << unpack_values_elements.size();
67   }
68 
69   for (size_t inner_index = 0; inner_index < unpack_keys_elements.size(); ++inner_index) {
70     auto inner_key = unpack_keys_elements[inner_index];
71     auto key_str = inner_key->BuildValue()->ToString();
72     (void)key_str_value_set->emplace(key_str, unpack_values_elements[inner_index]);
73     (void)key_set->emplace_back(inner_key);
74   }
75 }
76 
InferImplMakeDict(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)77 AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
78                                   const AbstractBasePtrList &args_abs_list) {
79   // Inputs: two tuples.
80   const std::string op_name = primitive->name();
81   constexpr int args_spec_size = 2;
82   CheckArgsSize(op_name, args_abs_list, args_spec_size);
83   AbstractSequencePtr keys = CheckArg<AbstractSequence>(op_name, args_abs_list, 0);
84   AbstractSequencePtr values = CheckArg<AbstractSequence>(op_name, args_abs_list, 1);
85 
86   size_t keys_size = keys->size();
87   if (values->size() != keys_size) {
88     MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
89   }
90 
91   AbstractBasePtrList key_list = keys->elements();
92   std::unordered_map<std::string, AbstractBasePtr> key_str_value_set;
93   std::vector<AbstractBasePtr> key_set;
94   std::vector<AbstractElementPair> key_value;
95   AbstractBasePtrList value_list = values->elements();
96   constexpr auto need_unpack = "need_unpack";
97   for (size_t index = 0; index < keys_size; index++) {
98     const auto &key = key_list[index];
99     bool is_need_unpack = false;
100     if (key->isa<AbstractTuple>()) {
101       auto key_tuple = key->cast<AbstractTuplePtr>();
102       if (key_tuple->HasData(need_unpack)) {
103         is_need_unpack = *key_tuple->GetData<bool>(need_unpack);
104         if (is_need_unpack) {
105           auto value_tuple = value_list[index]->cast<AbstractTuplePtr>();
106           MS_EXCEPTION_IF_NULL(value_tuple);
107           ProcessUnpackDict(key_tuple, value_tuple, &key_str_value_set, &key_set);
108         }
109       }
110     }
111     CheckDictKey(key, op_name);
112     auto key_val = key->BuildValue()->ToString();
113     auto iter = key_str_value_set.find(key_val);
114     // Remove duplicate keys.
115     // {Tensor[1]: x, Tensor[1}: y} the length of dict is 2, means the two keys are not duplicate.
116     if (iter != key_str_value_set.end() && !key->isa<AbstractTensor>()) {
117       iter->second = value_list[index];
118     } else if (!is_need_unpack) {
119       auto key_str = key->BuildValue()->ToString();
120       key_str_value_set.insert(std::pair<std::string, AbstractBasePtr>(key_str, value_list[index]));
121       key_set.push_back(key);
122     }
123   }
124   for (auto &key : key_set) {
125     auto key_str = key->BuildValue()->ToString();
126     (void)key_value.emplace_back(key, key_str_value_set[key_str]);
127   }
128   return std::make_shared<AbstractDictionary>(key_value);
129 }
130 
InferImplMakeKeywordArg(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)131 AbstractBasePtr InferImplMakeKeywordArg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
132                                         const AbstractBasePtrList &args_abs_list) {
133   // Inputs: a string and an object of a subclass of AbstractBase.
134   const std::string op_name = primitive->name();
135   constexpr int args_spec_size = 2;
136   CheckArgsSize(op_name, args_abs_list, args_spec_size);
137   AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_abs_list, 0);
138 
139   ValuePtr keyPtr = key->BuildValue();
140   MS_EXCEPTION_IF_NULL(keyPtr);
141   if (!keyPtr->isa<StringImm>()) {
142     MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
143   }
144   auto key_string = GetValue<std::string>(keyPtr);
145   return std::make_shared<AbstractKeywordArg>(key_string, args_abs_list[1]);
146 }
147 
InferImplExtractKeywordArg(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)148 AbstractBasePtr InferImplExtractKeywordArg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
149                                            const AbstractBasePtrList &args_abs_list) {
150   // Inputs: a key and a Keyword or only a Keyword.
151   const std::string op_name = primitive->name();
152   constexpr int only_kw_input_size = 1;
153   constexpr int check_key_input_size = 2;
154   AbstractKeywordArgPtr kwarg = nullptr;
155   if (args_abs_list.size() == check_key_input_size) {
156     AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_abs_list, 0);
157     kwarg = CheckArg<AbstractKeywordArg>(op_name, args_abs_list, 1);
158 
159     ValuePtr key_value = key->BuildValue();
160     MS_EXCEPTION_IF_NULL(key_value);
161     if (!key_value->isa<StringImm>()) {
162       MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
163     }
164     auto key_input = GetValue<std::string>(key_value);
165     std::string key_actual = kwarg->get_key();
166     if (key_actual != key_input) {
167       MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
168                         << key_input << ", AbstractKeywordArg' key is " << key_actual;
169     }
170   } else if (args_abs_list.size() == only_kw_input_size) {
171     kwarg = CheckArg<AbstractKeywordArg>(op_name, args_abs_list, 0);
172   } else {
173     MS_LOG(EXCEPTION) << "For '" << op_name << "', the number of inputs should be 1 or 2, but got "
174                       << args_abs_list.size();
175   }
176   return kwarg->get_arg();
177 }
178 
CheckDynamicLengthSequenceSetItem(const std::string & op_name,const AbstractSequencePtr & queue,const AbstractBasePtr & target)179 void CheckDynamicLengthSequenceSetItem(const std::string &op_name, const AbstractSequencePtr &queue,
180                                        const AbstractBasePtr &target) {
181   auto element_abs = queue->dynamic_len_element_abs();
182   if (element_abs == nullptr) {
183     MS_LOG(EXCEPTION) << "Empty variable len sequence can not setitem.";
184   }
185   const auto precondition_log = "For " + op_name + ", when the queue is dynamic length";
186   const auto standard_abs_description = "element within dynamic length sequence";
187   const auto differ_abs_description = "target element";
188   CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(std::vector<AbstractBasePtr>{element_abs, target},
189                                                       precondition_log, standard_abs_description,
190                                                       differ_abs_description);
191 }
192 
193 template <typename T>
InferTupleOrListSetItem(const std::string & op_name,const AbstractBasePtrList & args_abs_list)194 AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_abs_list) {
195   // Inputs: a tuple or list, a scalar whose value is an int64 number and an object of a subclass of AbstractBase.
196   constexpr int args_spec_size = 3;
197   CheckArgsSize(op_name, args_abs_list, args_spec_size);
198   auto queue = CheckArg<T>(op_name, args_abs_list, 0);
199   AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_abs_list, 1);
200 
201   auto index_type = index->BuildType();
202   MS_EXCEPTION_IF_NULL(index_type);
203   if (index_type->type_id() != kInt64->type_id()) {
204     MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got a "
205                              << index_type->ToString() << " number.";
206   }
207   ValuePtr index_value = index->BuildValue();
208   MS_EXCEPTION_IF_NULL(index_value);
209   auto target = args_abs_list[kIndex2];
210   MS_EXCEPTION_IF_NULL(target);
211   if (queue->dynamic_len()) {
212     CheckDynamicLengthSequenceSetItem(op_name, queue, target);
213     return queue->Clone();
214   }
215   if (index_value->ContainsValueAny()) {
216     // If the index is variable and the sequence is constant length, then all of the element within the sequence
217     // should have the same type and shape with the target input. The element within the return sequence should
218     // be all broadened.
219     const auto &elements = queue->elements();
220     if (elements.size() == 0) {
221       MS_LOG(EXCEPTION) << "Empty sequence can not setitem.";
222     }
223     const auto precondition_log = "For " + op_name + ", when the index is variable and the queue is constant length";
224     CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(elements, precondition_log);
225     auto first_element = elements[kIndex0];
226     const auto standard_abs_description = "element within constant length sequence";
227     const auto differ_abs_description = "target element";
228     CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(std::vector<AbstractBasePtr>{first_element, target},
229                                                         precondition_log, standard_abs_description,
230                                                         differ_abs_description);
231     return CheckAndConvertUtils::BroadenAllSequenceElements(queue);
232   }
233   auto index_int64_value = GetValue<int64_t>(index_value);
234   AbstractBasePtrList elements = queue->elements();
235   std::size_t nelems = elements.size();
236   if (nelems == 0) {
237     MS_EXCEPTION(IndexError) << "Can not setitem for an empty sequence.";
238   }
239   int64_t index_positive_value = index_int64_value >= 0 ? index_int64_value : index_int64_value + SizeToLong(nelems);
240   if (index_positive_value < 0 || index_positive_value >= SizeToLong(nelems)) {
241     MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << index_int64_value << " to set out of range: [-"
242                              << nelems << "," << (nelems - 1) << "].";
243   }
244   size_t index_unsigned_value = LongToSize(index_positive_value);
245   elements[index_unsigned_value] = args_abs_list[kIndex2];
246   MS_LOG(DEBUG) << "SetItem use flags, index: " << index_unsigned_value << ", for " << queue->ToString();
247   return std::make_shared<T>(elements, queue->sequence_nodes());
248 }
249 
InferImplTupleSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)250 AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
251                                       const AbstractBasePtrList &args_abs_list) {
252   return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_abs_list);
253 }
254 
InferImplListSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)255 AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
256                                      const AbstractBasePtrList &args_abs_list) {
257   return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_abs_list);
258 }
259 
InferImplDictGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)260 AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
261                                      const AbstractBasePtrList &args_abs_list) {
262   const std::string op_name = primitive->name();
263   // dict[key] mean the size of args_abs_list is 2.
264   // dict.get('key', default_value=None) mean the size of args_abs_list is 2 too, the key will check in dict_get.
265   constexpr int subscript_args_size = 2;
266   if (args_abs_list.size() != subscript_args_size) {
267     MS_LOG(EXCEPTION) << "For '" << op_name << "', the number of input should be " << subscript_args_size
268                       << ", but got " << args_abs_list.size();
269   }
270   AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_abs_list, 0);
271   const auto &key = args_abs_list[1];
272   CheckDictKey(key, op_name);
273 
274   ValuePtr key_value = key->BuildValue();
275   MS_EXCEPTION_IF_NULL(key_value);
276   std::vector<AbstractElementPair> dict_elems = dict->elements();
277   auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
278     return *key_value == *item.first->BuildValue();
279   });
280   if (it == dict_elems.end()) {
281     // For dict[key], if key is not exist, will raise a ValueError exception.
282     // For dict.get('key', default=None), if key is not exist, will return the default value during dict_get.
283     // Python KeyError will print escape character. So use ValueError instead of KeyError here.
284     MS_EXCEPTION(ValueError) << "The key " << key_value->ToString()
285                              << " does not exist in the dict:" << args_abs_list[0]->BuildValue()->ToString();
286   }
287   return it->second;
288 }
289 
InferImplDictSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)290 AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
291                                      const AbstractBasePtrList &args_abs_list) {
292   // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
293   const std::string op_name = primitive->name();
294   constexpr int args_spec_size = 3;
295   CheckArgsSize(op_name, args_abs_list, args_spec_size);
296   AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_abs_list, 0);
297   const auto &key = args_abs_list[1];
298   CheckDictKey(key, op_name);
299 
300   ValuePtr key_value = key->BuildValue();
301   MS_EXCEPTION_IF_NULL(key_value);
302   std::vector<AbstractElementPair> dict_elems = dict->elements();
303   auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
304     return *key_value == *item.first->BuildValue();
305   });
306 
307   MS_EXCEPTION_IF_NULL(args_abs_list[2]);
308   auto new_ele = std::make_pair(args_abs_list[1], args_abs_list[2]);
309   if (it != dict_elems.end()) {
310     int64_t index = it - dict_elems.begin();
311     dict_elems[LongToSize(index)] = new_ele;
312   } else {
313     dict_elems.push_back(new_ele);
314   }
315   return std::make_shared<AbstractDictionary>(dict_elems);
316 }
317 
InferImplDictGetKeys(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)318 AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
319                                      const AbstractBasePtrList &args_abs_list) {
320   // Inputs: a dict.
321   const std::string op_name = primitive->name();
322   constexpr int args_spec_size = 1;
323   CheckArgsSize(op_name, args_abs_list, args_spec_size);
324   AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_abs_list, 0);
325   std::vector<AbstractElementPair> dict_elems = dict->elements();
326   AbstractBasePtrList keys;
327   (void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(keys),
328                        [](const AbstractElementPair &item) { return item.first; });
329   return std::make_shared<AbstractTuple>(keys);
330 }
331 
InferImplDictGetValues(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)332 AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
333                                        const AbstractBasePtrList &args_abs_list) {
334   // Inputs: a dict.
335   const std::string op_name = primitive->name();
336   constexpr int args_spec_size = 1;
337   CheckArgsSize(op_name, args_abs_list, args_spec_size);
338   AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_abs_list, 0);
339   std::vector<AbstractElementPair> dict_elems = dict->elements();
340   AbstractBasePtrList values;
341   (void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(values),
342                        [](const AbstractElementPair &item) { return item.second; });
343   return std::make_shared<AbstractTuple>(values);
344 }
345 
InferImplDictItems(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)346 AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
347                                    const AbstractBasePtrList &args_abs_list) {
348   // Inputs: a dict.
349   const std::string op_name = primitive->name();
350   constexpr int args_spec_size = 1;
351   CheckArgsSize(op_name, args_abs_list, args_spec_size);
352   AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_abs_list, 0);
353   std::vector<AbstractElementPair> dict_elems = dict->elements();
354   AbstractBasePtrList items;
355   (void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(items),
356                        [](const AbstractElementPair &item) {
357                          return std::make_shared<AbstractTuple>(AbstractBasePtrList{item.first, item.second});
358                        });
359   return std::make_shared<AbstractList>(items);
360 }
361 
362 namespace {
CheckMutableArgAbstract(const AbstractBasePtr & abs)363 void CheckMutableArgAbstract(const AbstractBasePtr &abs) {
364   if (abs->isa<AbstractSequence>()) {
365     auto abs_seq = abs->cast_ptr<AbstractSequence>();
366     for (const auto &ele : abs_seq->elements()) {
367       CheckMutableArgAbstract(ele);
368     }
369     return;
370   }
371   if (abs->isa<AbstractDictionary>()) {
372     auto abs_dic = abs->cast_ptr<AbstractDictionary>();
373     for (const auto &ele : abs_dic->elements()) {
374       CheckMutableArgAbstract(ele.second);
375     }
376     return;
377   }
378   if (abs->isa<AbstractTensor>()) {
379     return;
380   }
381   if (abs->isa<AbstractScalar>()) {
382     auto type_ptr = abs->GetType();
383     if (type_ptr->isa<Number>()) {
384       return;
385     }
386   }
387   MS_EXCEPTION(TypeError) << "For 'mutable', the 'input_data' should be one of (bool, int, float, Tensor, "
388                              "tuple, list, dict) or their nested structures, but got "
389                           << abs->ToString();
390 }
391 }  // namespace
392 
InferImplMutable(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_abs_list)393 AbstractBasePtr InferImplMutable(const AnalysisEnginePtr &, const PrimitivePtr &,
394                                  const AbstractBasePtrList &args_abs_list) {
395   constexpr int min_args_abs_size = 1;
396   constexpr int max_args_abs_size = 2;
397   auto arg_size = args_abs_list.size();
398   if (arg_size != min_args_abs_size && arg_size != max_args_abs_size) {
399     MS_LOG(EXCEPTION) << "For 'mutable', the number of inputs should be 1 or 2, but got " << args_abs_list.size();
400   }
401   bool variable_len = false;
402   if (arg_size == max_args_abs_size) {
403     auto arg_value = args_abs_list[1]->GetValue();
404     MS_EXCEPTION_IF_NULL(arg_value);
405     if (!arg_value->isa<BoolImm>()) {
406       MS_EXCEPTION(TypeError) << "For 'mutable', the second input should be bool, but got: "
407                               << args_abs_list[1]->ToString();
408     }
409     variable_len = arg_value->cast<BoolImmPtr>()->value();
410   }
411   auto data = args_abs_list[0];
412   MS_EXCEPTION_IF_NULL(data);
413   if (!variable_len) {
414     if (data->isa<AbstractSequence>() && data->cast<AbstractSequencePtr>()->dynamic_len()) {
415       MS_LOG(EXCEPTION) << "For 'mutable', can not convert a dynamic length sequence to constant length.";
416     }
417     CheckMutableArgAbstract(data);
418     return AbstractBroaden(data);
419   }
420   auto ret = data->Clone();
421   if (ret->isa<AbstractAny>()) {
422     return ret;
423   }
424   if (!ret->isa<AbstractSequence>()) {
425     MS_EXCEPTION(TypeError) << "For 'mutable', when the variable_len is True, the first input should be"
426                             << " list or tuple, but got: " << ret->ToString();
427   }
428   auto ret_seq = ret->cast<AbstractSequencePtr>();
429   if (!ret_seq->dynamic_len()) {
430     ret_seq->CheckAndConvertToDynamicLenSequence();
431   }
432   if (ret->isa<AbstractList>()) {
433     // Dynamic length list should not attach python object.
434     auto ret_list = ret->cast<AbstractListPtr>();
435     ret_list->ClearExtraInfo();
436   }
437   return ret;
438 }
439 
440 namespace {
GetRefKey(const AbstractRefPtr & ref_tensor)441 std::string GetRefKey(const AbstractRefPtr &ref_tensor) {
442   const auto &ref_key_value = ref_tensor->ref_key_value();
443   MS_EXCEPTION_IF_NULL(ref_key_value);
444   auto ref_key = ref_key_value->cast_ptr<RefKey>();
445   MS_EXCEPTION_IF_NULL(ref_key);
446   return ref_key->value();
447 }
448 
GetGradAbstract(const AbstractBasePtr & grads_abs,const std::string & para_name,int64_t position,AbstractBasePtr * ret)449 void GetGradAbstract(const AbstractBasePtr &grads_abs, const std::string &para_name, int64_t position,
450                      AbstractBasePtr *ret) {
451   auto grad_abs_tuple = grads_abs->cast_ptr<AbstractTuple>();
452   if (grad_abs_tuple == nullptr || grad_abs_tuple->elements().size() == 0) {
453     return;
454   }
455   auto abs0 = grad_abs_tuple->elements()[0];
456   if (abs0->isa<AbstractScalar>()) {
457     auto buildptr = abs0->cast_ptr<AbstractScalar>();
458     MS_EXCEPTION_IF_NULL(buildptr);
459     auto build_value = buildptr->BuildValue();
460     size_t expect_size = 2;
461     if (grad_abs_tuple->elements().size() >= expect_size) {
462       if (build_value->isa<Int64Imm>()) {
463         if (GetValue<int64_t>(build_value) == position) {
464           *ret = grad_abs_tuple->elements()[1];
465         }
466       } else if (build_value->isa<StringImm>()) {
467         if (GetValue<std::string>(build_value) == para_name) {
468           *ret = grad_abs_tuple->elements()[1];
469         }
470       }
471     }
472     return;
473   } else {
474     for (const auto &abs : grad_abs_tuple->elements()) {
475       GetGradAbstract(abs, para_name, position, ret);
476     }
477     return;
478   }
479 }
480 }  // namespace
481 
InferImplGetGrad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)482 AbstractBasePtr InferImplGetGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
483                                  const AbstractBasePtrList &args_abs_list) {
484   const std::string &op_name = primitive->name();
485   constexpr int expected_args_spec_size = 2;
486   CheckArgsSize(op_name, args_abs_list, expected_args_spec_size);
487   auto &hash_id_abs = args_abs_list[1];
488 
489   int64_t position = -1;
490   std::string para_name;
491   if (hash_id_abs->isa<AbstractScalar>()) {
492     auto buildptr = hash_id_abs->cast_ptr<AbstractScalar>();
493     if (buildptr == nullptr) {
494       MS_EXCEPTION(TypeError) << "For " << op_name << ", the `x` should be an integer or a Parameter, but got nullptr";
495     }
496     auto build_value = buildptr->BuildValue();
497     if (!build_value->isa<Int64Imm>()) {
498       MS_EXCEPTION(TypeError) << "For " << op_name << ", the `x` should be an int64 number, but got "
499                               << build_value->ToString();
500     }
501     position = GetValue<int64_t>(build_value);
502   } else if (hash_id_abs->isa<AbstractRefTensor>()) {
503     para_name = GetRefKey(hash_id_abs->cast<AbstractRefPtr>());
504   } else {
505     MS_EXCEPTION(TypeError) << "For " << op_name << ", the `x` should be an integer or a Parameter, but got "
506                             << hash_id_abs->ToString();
507   }
508   AbstractBasePtr ret = nullptr;
509   GetGradAbstract(args_abs_list[0], para_name, position, &ret);
510   if (ret == nullptr) {
511     MS_LOG(EXCEPTION) << "Can not find the gradient for position or Parameter " << args_abs_list[1]->ToString();
512   }
513   return ret;
514 }
515 }  // namespace abstract
516 }  // namespace mindspore
517