• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/operator/composite/list_operation.h"
18 
19 #include <string>
20 
21 #include "abstract/param_validator.h"
22 #include "frontend/optimizer/opt.h"
23 #include "include/common/pybind_api/api_register.h"
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "pipeline/jit/ps/fallback.h"
26 #include "utils/ms_context.h"
27 
28 namespace mindspore {
29 // namespace to support composite operators definition
30 namespace prim {
GenerateFuncGraph(const abstract::AbstractBasePtrList & args_list)31 FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
32   constexpr size_t list_append_size_expect = 2;
33   abstract::CheckArgsSize("ListAppend", args_list, list_append_size_expect);
34 
35   AbstractBasePtr obj_arg = args_list[0];
36   abstract::AbstractListPtr arg0_list = dyn_cast<abstract::AbstractList>(obj_arg);
37   MS_EXCEPTION_IF_NULL(arg0_list);
38 
39   FuncGraphPtr ret = std::make_shared<FuncGraph>();
40   ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
41   ret->debug_info()->set_name("append");
42   AnfNodePtr arg0_node = ret->add_parameter();
43   AnfNodePtr arg1_node = ret->add_parameter();
44 
45   std::vector<AnfNodePtr> elems;
46   elems.push_back(NewValueNode(prim::kPrimMakeList));
47   size_t arg0_length = arg0_list->size();
48   for (size_t i = 0; i < arg0_length; ++i) {
49     elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(SizeToLong(i))}));
50   }
51   elems.push_back(arg1_node);
52 
53   ret->set_output(ret->NewCNode(elems));
54   return ret;
55 }
56 
GenerateFuncGraph(const abstract::AbstractBasePtrList & args_list)57 FuncGraphPtr ListInsert::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
58   const size_t list_insert_args_size = 3;
59   abstract::CheckArgsSize("ListInsert", args_list, list_insert_args_size);
60   AbstractBasePtr index_arg = args_list[0];
61   AbstractBasePtr obj_arg = args_list[1];
62 
63   abstract::AbstractListPtr arg0_list = dyn_cast<abstract::AbstractList>(index_arg);
64   MS_EXCEPTION_IF_NULL(arg0_list);
65   size_t list_len = arg0_list->size();
66   int64_t len = SizeToLong(list_len);
67   FuncGraphPtr ret = std::make_shared<FuncGraph>();
68   ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
69   ret->debug_info()->set_name("insert");
70   AnfNodePtr arg0_node = ret->add_parameter();
71   AnfNodePtr insert_index_node = ret->add_parameter();
72   AnfNodePtr insert_obj_node = ret->add_parameter();
73   // List inplace operation do not support:
74   // 1. The python object of list is not found.
75   // 2. List operation is generated by Vmap, will be opened after Vmap function of list inplace operation is provided.
76   if (fallback::EnableFallbackListDictInplace() && fallback::HasObjInExtraInfoHolder(arg0_list) &&
77       scope_name().find("VmapRule") == std::string::npos) {
78     MS_LOG(DEBUG) << "Enable inplace operation, convert list insert to InplaceListInsert ops.";
79     AnfNodePtrList list_inplace_inputs = {NewValueNode(prim::kPrimListInplaceInsert), arg0_node, insert_index_node,
80                                           insert_obj_node};
81     auto list_inplace_node = ret->NewCNodeInOrder(list_inplace_inputs);
82     list_inplace_node->set_has_side_effect_node(true);
83     ret->set_output(list_inplace_node);
84     ret->set_has_side_effect_node(true);
85     return ret;
86   }
87 
88   std::vector<AnfNodePtr> elems;
89   elems.push_back(NewValueNode(prim::kPrimMakeList));
90   auto obj_arg_value = obj_arg->BuildValue();
91   MS_EXCEPTION_IF_NULL(obj_arg_value);
92   if (!utils::isa<int64_t>(obj_arg_value)) {
93     MS_EXCEPTION(TypeError) << "Integer argument expected, but got " << obj_arg_value->type_name()
94                             << " type value: " << obj_arg_value->ToString();
95   }
96   int64_t index_value = GetValue<int64_t>(obj_arg_value);
97   int64_t insert_position = 0;
98   if (index_value >= len) {
99     insert_position = len;
100   } else if (index_value > 0 && index_value < len) {
101     insert_position = index_value;
102   } else if (index_value < 0 && index_value > -len) {
103     insert_position = len + index_value;
104   }
105   for (int64_t i = 0; i < insert_position; ++i) {
106     auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(i)});
107     elems.push_back(value);
108   }
109   elems.push_back(insert_obj_node);
110   for (int64_t i = insert_position; i < len; ++i) {
111     auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(i)});
112     elems.push_back(value);
113   }
114   auto out = ret->NewCNode(elems);
115   ret->set_output(out);
116   return ret;
117 }
118 
GenerateFuncGraph(const abstract::AbstractBasePtrList & args_list)119 FuncGraphPtr ListPop::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
120   constexpr size_t list_pop_args_size = 2;
121   abstract::CheckArgsSize("ListPop", args_list, list_pop_args_size);
122   abstract::AbstractListPtr list_input = dyn_cast<abstract::AbstractList>(args_list[0]);
123   AbstractBasePtr pop_index = args_list[1];
124   MS_EXCEPTION_IF_NULL(list_input);
125   size_t list_len = list_input->size();
126   int64_t len = SizeToLong(list_len);
127   FuncGraphPtr ret = std::make_shared<FuncGraph>();
128   ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
129   ret->debug_info()->set_name("pop");
130   AnfNodePtr arg0_node = ret->add_parameter();
131   AnfNodePtr pop_index_node = ret->add_parameter();
132 
133   std::vector<AnfNodePtr> elems;
134   elems.push_back(NewValueNode(prim::kPrimMakeList));
135   auto pop_index_value = pop_index->BuildValue();
136   if (!utils::isa<int64_t>(pop_index_value)) {
137     MS_EXCEPTION(TypeError) << "Integer argument expected, but got " << pop_index_value->type_name()
138                             << " type value: " << pop_index_value->ToString();
139   }
140   int64_t index_value = GetValue<int64_t>(pop_index_value);
141   if (index_value >= len || index_value < -1 * len) {
142     MS_EXCEPTION(IndexError) << "The pop index out of range.";
143   }
144   int64_t pop_position = (index_value >= 0) ? index_value : (len + index_value);
145 
146   // List inplace operation do not support:
147   // 1. The python object of list is not found.
148   // 2. List operation is generated by Vmap, will be opened after Vmap function of list inplace operation is provided.
149   if (fallback::EnableFallbackListDictInplace() && fallback::HasObjInExtraInfoHolder(list_input) &&
150       scope_name().find("VmapRule") == std::string::npos) {
151     MS_LOG(DEBUG) << "Enable inplace operation, convert list pop to InplaceListPop ops.";
152     AnfNodePtrList list_inplace_inputs = {NewValueNode(prim::kPrimListInplacePop), arg0_node, pop_index_node};
153     auto list_inplace_node = ret->NewCNodeInOrder(list_inplace_inputs);
154     list_inplace_node->set_has_side_effect_node(true);
155     ret->set_output(list_inplace_node);
156     ret->set_has_side_effect_node(true);
157     return ret;
158   }
159 
160   for (int64_t i = 0; i < pop_position; ++i) {
161     auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(i)});
162     elems.push_back(value);
163   }
164   auto pop_node = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(pop_position)});
165   for (int64_t i = pop_position + 1; i < len; ++i) {
166     auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(i)});
167     elems.push_back(value);
168   }
169 
170   auto new_list = ret->NewCNode(elems);
171   auto out = ret->NewCNode({NewValueNode(prim::kPrimMakeTuple), new_list, pop_node});
172   ret->set_output(out);
173   return ret;
174 }
175 
GenerateFuncGraph(const abstract::AbstractBasePtrList & args_list)176 FuncGraphPtr ListClear::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
177   abstract::CheckArgsSize("ListClear", args_list, 1);
178 
179   FuncGraphPtr ret = std::make_shared<FuncGraph>();
180   ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
181   ret->debug_info()->set_name("clear");
182   (void)ret->add_parameter();
183 
184   auto empty_list = std::vector<ValuePtr>();
185   ret->set_output(NewValueNode(std::make_shared<ValueList>(empty_list)));
186   return ret;
187 }
188 
GenerateFuncGraph(const abstract::AbstractBasePtrList & args_list)189 FuncGraphPtr ListExtend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
190   constexpr size_t list_extend_args_size = 2;
191   abstract::CheckArgsSize("ListExtend", args_list, list_extend_args_size);
192 
193   FuncGraphPtr ret = std::make_shared<FuncGraph>();
194   ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
195   ret->debug_info()->set_name("extend");
196 
197   constexpr size_t current_index = 0;
198   constexpr size_t extend_index = 1;
199   auto abs_current = args_list[current_index];
200   auto abs_extend = args_list[extend_index];
201 
202   std::vector<AnfNodePtr> elems;
203   elems.push_back(NewValueNode(prim::kPrimMakeList));
204   auto abs_current_list = dyn_cast<abstract::AbstractList>(abs_current);
205   MS_EXCEPTION_IF_NULL(abs_current_list);
206 
207   // List inplace operation do not support:
208   // 1. The python object of list is not found.
209   // 2. List operation is generated by Vmap, will be opened after Vmap function of list inplace operation is provided.
210   if (fallback::EnableFallbackListDictInplace() && fallback::HasObjInExtraInfoHolder(abs_current_list) &&
211       scope_name().find("VmapRule") == std::string::npos) {
212     MS_LOG(DEBUG) << "Enable inplace operation, convert list extend to InplaceListExtend ops.";
213     AnfNodePtr arg0_node = ret->add_parameter();
214     AnfNodePtr arg1_node = ret->add_parameter();
215     AnfNodePtrList list_inplace_inputs = {NewValueNode(prim::kPrimListInplaceExtend), arg0_node, arg1_node};
216     auto list_inplace_node = ret->NewCNodeInOrder(list_inplace_inputs);
217     list_inplace_node->set_has_side_effect_node(true);
218     ret->set_output(list_inplace_node);
219     ret->set_has_side_effect_node(true);
220     return ret;
221   }
222 
223   AddNodeToElems(abs_current_list, ret, &elems);
224   AddNodeToElems(abs_extend, ret, &elems);
225 
226   auto out = ret->NewCNode(elems);
227   ret->set_output(out);
228   return ret;
229 }
230 
AddNodeToElems(const AbstractBasePtr & arg,const FuncGraphPtr & ret,std::vector<AnfNodePtr> * elems)231 void ListExtend::AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems) {
232   AnfNodePtr arg_node = ret->add_parameter();
233   if (arg->isa<abstract::AbstractList>()) {
234     auto arg_list = dyn_cast<abstract::AbstractList>(arg);
235     if (arg_list->dynamic_len()) {
236       MS_LOG(EXCEPTION) << "ListExtend does not support dynamic length list.";
237     }
238     int64_t len = SizeToLong(arg_list->size());
239     for (int64_t i = 0; i < len; ++i) {
240       auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)});
241       elems->push_back(value);
242     }
243     return;
244   }
245   if (arg->isa<abstract::AbstractTuple>()) {
246     auto arg_tuple = dyn_cast<abstract::AbstractTuple>(arg);
247     if (arg_tuple->dynamic_len()) {
248       MS_LOG(EXCEPTION) << "ListExtend does not support dynamic length tuple.";
249     }
250     int64_t len = SizeToLong(arg_tuple->size());
251     for (int64_t i = 0; i < len; ++i) {
252       auto value = ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), arg_node, NewValueNode(i)});
253       elems->push_back(value);
254     }
255     return;
256   }
257   if (arg->isa<abstract::AbstractTensor>()) {
258     auto abs_tensor = dyn_cast<abstract::AbstractTensor>(arg);
259     auto shape_ptr = abs_tensor->BuildShape();
260     MS_EXCEPTION_IF_NULL(shape_ptr);
261     auto tensor_shape = shape_ptr->cast<abstract::ShapePtr>();
262     MS_EXCEPTION_IF_NULL(tensor_shape);
263     auto shape = tensor_shape->shape();
264     if (shape.empty()) {
265       MS_LOG(EXCEPTION) << "ListExtend does not support scalar tensor.";
266     }
267     if (shape[0] < 0) {
268       MS_LOG(EXCEPTION) << "ListExtend does not support the tensor whose shapes has an uncertain 0th dimension.";
269     }
270     int64_t len = shape[0];
271 
272     std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
273     ValuePtr op = prim::GetPythonOps("getitem", module_name);
274     for (int64_t i = 0; i < len; ++i) {
275       auto value = ret->NewCNode({NewValueNode(op), arg_node, NewValueNode(i)});
276       elems->push_back(value);
277     }
278     return;
279   }
280   MS_LOG(EXCEPTION) << "ListExtend supports list, tuple and Tensor, but got " << arg->ToString();
281 }
282 
GenerateFuncGraph(const abstract::AbstractBasePtrList & args_list)283 FuncGraphPtr ListReverse::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
284   abstract::CheckArgsSize("ListReverse", args_list, 1);
285   abstract::AbstractListPtr arg_list = dyn_cast<abstract::AbstractList>(args_list[0]);
286   MS_EXCEPTION_IF_NULL(arg_list);
287   int64_t arg_length = SizeToLong(arg_list->size());
288 
289   FuncGraphPtr ret = std::make_shared<FuncGraph>();
290   ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
291   ret->debug_info()->set_name("reverse");
292   AnfNodePtr arg_node = ret->add_parameter();
293   // List inplace operation do not support:
294   // 1. The python object of list is not found.
295   // 2. List operation is generated by Vmap, will be opened after Vmap function of list inplace operation is provided.
296   if (fallback::EnableFallbackListDictInplace() && fallback::HasObjInExtraInfoHolder(arg_list) &&
297       scope_name().find("VmapRule") == std::string::npos) {
298     MS_LOG(DEBUG) << "Enable inplace operation, convert list reverse to InplaceListReverse ops.";
299     AnfNodePtrList list_inplace_inputs = {NewValueNode(prim::kPrimListInplaceReverse), arg_node};
300     auto list_inplace_node = ret->NewCNodeInOrder(list_inplace_inputs);
301     list_inplace_node->set_has_side_effect_node(true);
302     ret->set_output(list_inplace_node);
303     ret->set_has_side_effect_node(true);
304     return ret;
305   }
306 
307   std::vector<AnfNodePtr> elems;
308   elems.push_back(NewValueNode(prim::kPrimMakeList));
309   for (int64_t i = arg_length - 1; i >= 0; --i) {
310     elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)}));
311   }
312 
313   ret->set_output(ret->NewCNode(elems));
314   return ret;
315 }
316 }  // namespace prim
317 }  // namespace mindspore
318