1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/operator/composite/list_operation.h"
18
19 #include <string>
20
21 #include "abstract/param_validator.h"
22 #include "frontend/optimizer/opt.h"
23 #include "include/common/pybind_api/api_register.h"
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "pipeline/jit/ps/fallback.h"
26 #include "utils/ms_context.h"
27
28 namespace mindspore {
29 // namespace to support composite operators definition
30 namespace prim {
GenerateFuncGraph(const abstract::AbstractBasePtrList & args_list)31 FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
32 constexpr size_t list_append_size_expect = 2;
33 abstract::CheckArgsSize("ListAppend", args_list, list_append_size_expect);
34
35 AbstractBasePtr obj_arg = args_list[0];
36 abstract::AbstractListPtr arg0_list = dyn_cast<abstract::AbstractList>(obj_arg);
37 MS_EXCEPTION_IF_NULL(arg0_list);
38
39 FuncGraphPtr ret = std::make_shared<FuncGraph>();
40 ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
41 ret->debug_info()->set_name("append");
42 AnfNodePtr arg0_node = ret->add_parameter();
43 AnfNodePtr arg1_node = ret->add_parameter();
44
45 std::vector<AnfNodePtr> elems;
46 elems.push_back(NewValueNode(prim::kPrimMakeList));
47 size_t arg0_length = arg0_list->size();
48 for (size_t i = 0; i < arg0_length; ++i) {
49 elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(SizeToLong(i))}));
50 }
51 elems.push_back(arg1_node);
52
53 ret->set_output(ret->NewCNode(elems));
54 return ret;
55 }
56
GenerateFuncGraph(const abstract::AbstractBasePtrList & args_list)57 FuncGraphPtr ListInsert::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
58 const size_t list_insert_args_size = 3;
59 abstract::CheckArgsSize("ListInsert", args_list, list_insert_args_size);
60 AbstractBasePtr index_arg = args_list[0];
61 AbstractBasePtr obj_arg = args_list[1];
62
63 abstract::AbstractListPtr arg0_list = dyn_cast<abstract::AbstractList>(index_arg);
64 MS_EXCEPTION_IF_NULL(arg0_list);
65 size_t list_len = arg0_list->size();
66 int64_t len = SizeToLong(list_len);
67 FuncGraphPtr ret = std::make_shared<FuncGraph>();
68 ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
69 ret->debug_info()->set_name("insert");
70 AnfNodePtr arg0_node = ret->add_parameter();
71 AnfNodePtr insert_index_node = ret->add_parameter();
72 AnfNodePtr insert_obj_node = ret->add_parameter();
73 // List inplace operation do not support:
74 // 1. The python object of list is not found.
75 // 2. List operation is generated by Vmap, will be opened after Vmap function of list inplace operation is provided.
76 if (fallback::EnableFallbackListDictInplace() && fallback::HasObjInExtraInfoHolder(arg0_list) &&
77 scope_name().find("VmapRule") == std::string::npos) {
78 MS_LOG(DEBUG) << "Enable inplace operation, convert list insert to InplaceListInsert ops.";
79 AnfNodePtrList list_inplace_inputs = {NewValueNode(prim::kPrimListInplaceInsert), arg0_node, insert_index_node,
80 insert_obj_node};
81 auto list_inplace_node = ret->NewCNodeInOrder(list_inplace_inputs);
82 list_inplace_node->set_has_side_effect_node(true);
83 ret->set_output(list_inplace_node);
84 ret->set_has_side_effect_node(true);
85 return ret;
86 }
87
88 std::vector<AnfNodePtr> elems;
89 elems.push_back(NewValueNode(prim::kPrimMakeList));
90 auto obj_arg_value = obj_arg->BuildValue();
91 MS_EXCEPTION_IF_NULL(obj_arg_value);
92 if (!utils::isa<int64_t>(obj_arg_value)) {
93 MS_EXCEPTION(TypeError) << "Integer argument expected, but got " << obj_arg_value->type_name()
94 << " type value: " << obj_arg_value->ToString();
95 }
96 int64_t index_value = GetValue<int64_t>(obj_arg_value);
97 int64_t insert_position = 0;
98 if (index_value >= len) {
99 insert_position = len;
100 } else if (index_value > 0 && index_value < len) {
101 insert_position = index_value;
102 } else if (index_value < 0 && index_value > -len) {
103 insert_position = len + index_value;
104 }
105 for (int64_t i = 0; i < insert_position; ++i) {
106 auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(i)});
107 elems.push_back(value);
108 }
109 elems.push_back(insert_obj_node);
110 for (int64_t i = insert_position; i < len; ++i) {
111 auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(i)});
112 elems.push_back(value);
113 }
114 auto out = ret->NewCNode(elems);
115 ret->set_output(out);
116 return ret;
117 }
118
GenerateFuncGraph(const abstract::AbstractBasePtrList & args_list)119 FuncGraphPtr ListPop::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
120 constexpr size_t list_pop_args_size = 2;
121 abstract::CheckArgsSize("ListPop", args_list, list_pop_args_size);
122 abstract::AbstractListPtr list_input = dyn_cast<abstract::AbstractList>(args_list[0]);
123 AbstractBasePtr pop_index = args_list[1];
124 MS_EXCEPTION_IF_NULL(list_input);
125 size_t list_len = list_input->size();
126 int64_t len = SizeToLong(list_len);
127 FuncGraphPtr ret = std::make_shared<FuncGraph>();
128 ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
129 ret->debug_info()->set_name("pop");
130 AnfNodePtr arg0_node = ret->add_parameter();
131 AnfNodePtr pop_index_node = ret->add_parameter();
132
133 std::vector<AnfNodePtr> elems;
134 elems.push_back(NewValueNode(prim::kPrimMakeList));
135 auto pop_index_value = pop_index->BuildValue();
136 if (!utils::isa<int64_t>(pop_index_value)) {
137 MS_EXCEPTION(TypeError) << "Integer argument expected, but got " << pop_index_value->type_name()
138 << " type value: " << pop_index_value->ToString();
139 }
140 int64_t index_value = GetValue<int64_t>(pop_index_value);
141 if (index_value >= len || index_value < -1 * len) {
142 MS_EXCEPTION(IndexError) << "The pop index out of range.";
143 }
144 int64_t pop_position = (index_value >= 0) ? index_value : (len + index_value);
145
146 // List inplace operation do not support:
147 // 1. The python object of list is not found.
148 // 2. List operation is generated by Vmap, will be opened after Vmap function of list inplace operation is provided.
149 if (fallback::EnableFallbackListDictInplace() && fallback::HasObjInExtraInfoHolder(list_input) &&
150 scope_name().find("VmapRule") == std::string::npos) {
151 MS_LOG(DEBUG) << "Enable inplace operation, convert list pop to InplaceListPop ops.";
152 AnfNodePtrList list_inplace_inputs = {NewValueNode(prim::kPrimListInplacePop), arg0_node, pop_index_node};
153 auto list_inplace_node = ret->NewCNodeInOrder(list_inplace_inputs);
154 list_inplace_node->set_has_side_effect_node(true);
155 ret->set_output(list_inplace_node);
156 ret->set_has_side_effect_node(true);
157 return ret;
158 }
159
160 for (int64_t i = 0; i < pop_position; ++i) {
161 auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(i)});
162 elems.push_back(value);
163 }
164 auto pop_node = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(pop_position)});
165 for (int64_t i = pop_position + 1; i < len; ++i) {
166 auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(i)});
167 elems.push_back(value);
168 }
169
170 auto new_list = ret->NewCNode(elems);
171 auto out = ret->NewCNode({NewValueNode(prim::kPrimMakeTuple), new_list, pop_node});
172 ret->set_output(out);
173 return ret;
174 }
175
GenerateFuncGraph(const abstract::AbstractBasePtrList & args_list)176 FuncGraphPtr ListClear::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
177 abstract::CheckArgsSize("ListClear", args_list, 1);
178
179 FuncGraphPtr ret = std::make_shared<FuncGraph>();
180 ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
181 ret->debug_info()->set_name("clear");
182 (void)ret->add_parameter();
183
184 auto empty_list = std::vector<ValuePtr>();
185 ret->set_output(NewValueNode(std::make_shared<ValueList>(empty_list)));
186 return ret;
187 }
188
GenerateFuncGraph(const abstract::AbstractBasePtrList & args_list)189 FuncGraphPtr ListExtend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
190 constexpr size_t list_extend_args_size = 2;
191 abstract::CheckArgsSize("ListExtend", args_list, list_extend_args_size);
192
193 FuncGraphPtr ret = std::make_shared<FuncGraph>();
194 ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
195 ret->debug_info()->set_name("extend");
196
197 constexpr size_t current_index = 0;
198 constexpr size_t extend_index = 1;
199 auto abs_current = args_list[current_index];
200 auto abs_extend = args_list[extend_index];
201
202 std::vector<AnfNodePtr> elems;
203 elems.push_back(NewValueNode(prim::kPrimMakeList));
204 auto abs_current_list = dyn_cast<abstract::AbstractList>(abs_current);
205 MS_EXCEPTION_IF_NULL(abs_current_list);
206
207 // List inplace operation do not support:
208 // 1. The python object of list is not found.
209 // 2. List operation is generated by Vmap, will be opened after Vmap function of list inplace operation is provided.
210 if (fallback::EnableFallbackListDictInplace() && fallback::HasObjInExtraInfoHolder(abs_current_list) &&
211 scope_name().find("VmapRule") == std::string::npos) {
212 MS_LOG(DEBUG) << "Enable inplace operation, convert list extend to InplaceListExtend ops.";
213 AnfNodePtr arg0_node = ret->add_parameter();
214 AnfNodePtr arg1_node = ret->add_parameter();
215 AnfNodePtrList list_inplace_inputs = {NewValueNode(prim::kPrimListInplaceExtend), arg0_node, arg1_node};
216 auto list_inplace_node = ret->NewCNodeInOrder(list_inplace_inputs);
217 list_inplace_node->set_has_side_effect_node(true);
218 ret->set_output(list_inplace_node);
219 ret->set_has_side_effect_node(true);
220 return ret;
221 }
222
223 AddNodeToElems(abs_current_list, ret, &elems);
224 AddNodeToElems(abs_extend, ret, &elems);
225
226 auto out = ret->NewCNode(elems);
227 ret->set_output(out);
228 return ret;
229 }
230
AddNodeToElems(const AbstractBasePtr & arg,const FuncGraphPtr & ret,std::vector<AnfNodePtr> * elems)231 void ListExtend::AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems) {
232 AnfNodePtr arg_node = ret->add_parameter();
233 if (arg->isa<abstract::AbstractList>()) {
234 auto arg_list = dyn_cast<abstract::AbstractList>(arg);
235 if (arg_list->dynamic_len()) {
236 MS_LOG(EXCEPTION) << "ListExtend does not support dynamic length list.";
237 }
238 int64_t len = SizeToLong(arg_list->size());
239 for (int64_t i = 0; i < len; ++i) {
240 auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)});
241 elems->push_back(value);
242 }
243 return;
244 }
245 if (arg->isa<abstract::AbstractTuple>()) {
246 auto arg_tuple = dyn_cast<abstract::AbstractTuple>(arg);
247 if (arg_tuple->dynamic_len()) {
248 MS_LOG(EXCEPTION) << "ListExtend does not support dynamic length tuple.";
249 }
250 int64_t len = SizeToLong(arg_tuple->size());
251 for (int64_t i = 0; i < len; ++i) {
252 auto value = ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), arg_node, NewValueNode(i)});
253 elems->push_back(value);
254 }
255 return;
256 }
257 if (arg->isa<abstract::AbstractTensor>()) {
258 auto abs_tensor = dyn_cast<abstract::AbstractTensor>(arg);
259 auto shape_ptr = abs_tensor->BuildShape();
260 MS_EXCEPTION_IF_NULL(shape_ptr);
261 auto tensor_shape = shape_ptr->cast<abstract::ShapePtr>();
262 MS_EXCEPTION_IF_NULL(tensor_shape);
263 auto shape = tensor_shape->shape();
264 if (shape.empty()) {
265 MS_LOG(EXCEPTION) << "ListExtend does not support scalar tensor.";
266 }
267 if (shape[0] < 0) {
268 MS_LOG(EXCEPTION) << "ListExtend does not support the tensor whose shapes has an uncertain 0th dimension.";
269 }
270 int64_t len = shape[0];
271
272 std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
273 ValuePtr op = prim::GetPythonOps("getitem", module_name);
274 for (int64_t i = 0; i < len; ++i) {
275 auto value = ret->NewCNode({NewValueNode(op), arg_node, NewValueNode(i)});
276 elems->push_back(value);
277 }
278 return;
279 }
280 MS_LOG(EXCEPTION) << "ListExtend supports list, tuple and Tensor, but got " << arg->ToString();
281 }
282
GenerateFuncGraph(const abstract::AbstractBasePtrList & args_list)283 FuncGraphPtr ListReverse::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
284 abstract::CheckArgsSize("ListReverse", args_list, 1);
285 abstract::AbstractListPtr arg_list = dyn_cast<abstract::AbstractList>(args_list[0]);
286 MS_EXCEPTION_IF_NULL(arg_list);
287 int64_t arg_length = SizeToLong(arg_list->size());
288
289 FuncGraphPtr ret = std::make_shared<FuncGraph>();
290 ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
291 ret->debug_info()->set_name("reverse");
292 AnfNodePtr arg_node = ret->add_parameter();
293 // List inplace operation do not support:
294 // 1. The python object of list is not found.
295 // 2. List operation is generated by Vmap, will be opened after Vmap function of list inplace operation is provided.
296 if (fallback::EnableFallbackListDictInplace() && fallback::HasObjInExtraInfoHolder(arg_list) &&
297 scope_name().find("VmapRule") == std::string::npos) {
298 MS_LOG(DEBUG) << "Enable inplace operation, convert list reverse to InplaceListReverse ops.";
299 AnfNodePtrList list_inplace_inputs = {NewValueNode(prim::kPrimListInplaceReverse), arg_node};
300 auto list_inplace_node = ret->NewCNodeInOrder(list_inplace_inputs);
301 list_inplace_node->set_has_side_effect_node(true);
302 ret->set_output(list_inplace_node);
303 ret->set_has_side_effect_node(true);
304 return ret;
305 }
306
307 std::vector<AnfNodePtr> elems;
308 elems.push_back(NewValueNode(prim::kPrimMakeList));
309 for (int64_t i = arg_length - 1; i >= 0; --i) {
310 elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)}));
311 }
312
313 ret->set_output(ret->NewCNode(elems));
314 return ret;
315 }
316 } // namespace prim
317 } // namespace mindspore
318