1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/operator/composite/vmap.h"
18
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "pybind11/pybind11.h"
24 #include "ir/anf.h"
25 #include "ir/func_graph.h"
26 #include "abstract/abstract_value.h"
27 #include "abstract/abstract_function.h"
28 #include "pipeline/jit/ps/parse/parse_base.h"
29 #include "pipeline/jit/ps/parse/parse.h"
30 #include "pipeline/jit/ps/parse/resolve.h"
31 #include "pipeline/jit/ps/pipeline.h"
32 #include "include/common/utils/python_adapter.h"
33 #include "include/common/pybind_api/api_register.h"
34
35 namespace mindspore {
36 // namespace to support composite operators definition
37 namespace prim {
GenerateFuncGraphAllNone(const FuncGraphPtr & fg,const AnfNodePtr & prim,int64_t args_size,int64_t tuple_elements_num,bool bind)38 void GenerateFuncGraphAllNone(const FuncGraphPtr &fg, const AnfNodePtr &prim, int64_t args_size,
39 int64_t tuple_elements_num, bool bind) {
40 std::vector<AnfNodePtr> prim_output_cnode_inputs;
41 (void)prim_output_cnode_inputs.emplace_back(prim);
42 if (tuple_elements_num != 0) {
43 auto val_in_param = fg->add_parameter();
44 std::vector<AnfNodePtr> prim_inputs_cnode_inputs;
45 (void)prim_inputs_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
46 for (int64_t i = 0; i < tuple_elements_num; ++i) {
47 auto val_in_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_param, NewValueNode(i)});
48 auto val_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_cnode, NewValueNode(kValIndex)});
49 (void)prim_inputs_cnode_inputs.emplace_back(val_cnode);
50 }
51 auto prim_inputs_cnode = fg->NewCNode(prim_inputs_cnode_inputs);
52 (void)prim_output_cnode_inputs.emplace_back(prim_inputs_cnode);
53 args_size = args_size - tuple_elements_num;
54 }
55
56 for (int64_t i = 0; i < args_size; ++i) {
57 auto val_in_param = fg->add_parameter();
58 auto val_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_param, NewValueNode(kValIndex)});
59 (void)prim_output_cnode_inputs.emplace_back(val_cnode);
60 }
61
62 auto prim_output_cnode = fg->NewCNode(prim_output_cnode_inputs);
63 const py::function bind_all_none_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_bind_all_none");
64 auto bind_all_none_fg = parse::ParsePythonCode(bind_all_none_fn);
65 MS_EXCEPTION_IF_NULL(bind_all_none_fg);
66 auto bind_all_none_cnode = fg->NewCNode({NewValueNode(bind_all_none_fg), prim_output_cnode});
67 if (bind) {
68 auto output_cnode = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(true), bind_all_none_cnode});
69 fg->set_output(output_cnode);
70 return;
71 }
72 fg->set_output(bind_all_none_cnode);
73 return;
74 }
75
GenerateFuncGraphInnerBroadcastAxis(const AnfNodePtr & inputs,const AnfNodePtr & out_axis,const AnfNodePtr & axis_size,const AbstractBasePtr & inputs_abstract_elements_begin) const76 CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerBroadcastAxis(
77 const AnfNodePtr &inputs, const AnfNodePtr &out_axis, const AnfNodePtr &axis_size,
78 const AbstractBasePtr &inputs_abstract_elements_begin) const {
79 std::vector<AnfNodePtr> value_cnode_inputs;
80 (void)value_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
81 (void)value_cnode_inputs.emplace_back(inputs);
82 (void)value_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(0)));
83 auto value_cnode = fg_->NewCNode(value_cnode_inputs);
84 std::vector<AnfNodePtr> dim_cnode_inputs;
85 (void)dim_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
86 (void)dim_cnode_inputs.emplace_back(inputs);
87 (void)dim_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(1)));
88 auto dim_cnode = fg_->NewCNode(dim_cnode_inputs);
89
90 std::vector<AnfNodePtr> sub_inputs_cnode_inputs;
91 (void)sub_inputs_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
92 auto inputs_abstract_elements_begin_tuple = dyn_cast<abstract::AbstractTuple>(inputs_abstract_elements_begin);
93 auto inputs_abstract_elements_begin_tuple_elements = inputs_abstract_elements_begin_tuple->elements();
94 // inputs: ((x, y), None) -> ((x, None), (y, None)).
95 int64_t begin_tuple_size = static_cast<int64_t>(inputs_abstract_elements_begin_tuple_elements.size());
96 for (int64_t i = 0; i < begin_tuple_size; ++i) {
97 std::vector<AnfNodePtr> cur_tuple_getitem_inputs;
98 (void)cur_tuple_getitem_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
99 (void)cur_tuple_getitem_inputs.emplace_back(value_cnode);
100 (void)cur_tuple_getitem_inputs.emplace_back(NewValueNode(i));
101 auto cur_value_cnode = fg_->NewCNode(cur_tuple_getitem_inputs);
102 std::vector<AnfNodePtr> cur_make_tuple_cnode_inputs;
103 (void)cur_make_tuple_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
104 (void)cur_make_tuple_cnode_inputs.emplace_back(cur_value_cnode);
105 (void)cur_make_tuple_cnode_inputs.emplace_back(dim_cnode);
106 auto cur_make_tuple_cnode = fg_->NewCNode(cur_make_tuple_cnode_inputs);
107 (void)sub_inputs_cnode_inputs.emplace_back(cur_make_tuple_cnode);
108 }
109 auto sub_inputs_cnode = fg_->NewCNode(sub_inputs_cnode_inputs);
110 std::vector<AnfNodePtr> out_cnode_inputs;
111 (void)out_cnode_inputs.emplace_back(NewValueNode(std::make_shared<VmapMatchOutAxis>("VmapMatchOutAxis")));
112 (void)out_cnode_inputs.emplace_back(sub_inputs_cnode);
113 (void)out_cnode_inputs.emplace_back(out_axis);
114 (void)out_cnode_inputs.emplace_back(axis_size);
115 return fg_->NewCNode(out_cnode_inputs);
116 }
117
GenerateFuncGraphInnerSingleElement(const AnfNodePtr & inputs,const AnfNodePtr & out_axis,const AnfNodePtr & axis_size,const AbstractBasePtr & inputs_abstract_elements_end) const118 CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement(
119 const AnfNodePtr &inputs, const AnfNodePtr &out_axis, const AnfNodePtr &axis_size,
120 const AbstractBasePtr &inputs_abstract_elements_end) const {
121 std::vector<AnfNodePtr> value_cnode_inputs;
122 (void)value_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
123 (void)value_cnode_inputs.emplace_back(inputs);
124 (void)value_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(0)));
125 auto value_cnode = fg_->NewCNode(value_cnode_inputs);
126 std::vector<AnfNodePtr> out_cnode_inputs;
127 if (inputs_abstract_elements_end->isa<abstract::AbstractNone>()) {
128 const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis");
129 auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis);
130 MS_EXCEPTION_IF_NULL(broadcast_by_axis_fg);
131 (void)out_cnode_inputs.emplace_back(NewValueNode(broadcast_by_axis_fg));
132 (void)out_cnode_inputs.emplace_back(value_cnode);
133 (void)out_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(0)));
134 (void)out_cnode_inputs.emplace_back(axis_size);
135 } else {
136 std::vector<AnfNodePtr> dim_cnode_inputs;
137 (void)dim_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
138 (void)dim_cnode_inputs.emplace_back(inputs);
139 (void)dim_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(1)));
140 auto dim_cnode = fg_->NewCNode(dim_cnode_inputs);
141 const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis");
142 auto move_axis_fg = parse::ParsePythonCode(move_axis);
143 MS_EXCEPTION_IF_NULL(move_axis_fg);
144 (void)out_cnode_inputs.emplace_back(NewValueNode(move_axis_fg));
145 (void)out_cnode_inputs.emplace_back(value_cnode);
146 (void)out_cnode_inputs.emplace_back(dim_cnode);
147 (void)out_cnode_inputs.emplace_back(out_axis);
148 }
149 return fg_->NewCNode(out_cnode_inputs);
150 }
151
152 namespace {
GetOutAxesAbstractElements(const AbstractBasePtr & out_axes_abstract,size_t inputs_abstract_elements_size,bool is_out_axes_tuple)153 AbstractBasePtrList GetOutAxesAbstractElements(const AbstractBasePtr &out_axes_abstract,
154 size_t inputs_abstract_elements_size, bool is_out_axes_tuple) {
155 AbstractBasePtrList out_axes_abstract_elements;
156 if (!is_out_axes_tuple) {
157 return out_axes_abstract_elements;
158 }
159 abstract::AbstractTuplePtr out_axes_abstract_tuple = dyn_cast<abstract::AbstractTuple>(out_axes_abstract);
160 out_axes_abstract_elements = out_axes_abstract_tuple->elements();
161 if (out_axes_abstract_elements.size() != inputs_abstract_elements_size) {
162 MS_LOG(EXCEPTION) << "The length of out_axes and inputs do not match. ";
163 }
164 return out_axes_abstract_elements;
165 }
166 } // namespace
167
GenerateFuncGraphInnerAllTuple(const AnfNodePtr & inputs,const AnfNodePtr & out_axis,const AnfNodePtr & axis_size,const AbstractBasePtrList & inputs_abstract_elements,const AbstractBasePtr & out_axes_abstract) const168 CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inputs, const AnfNodePtr &out_axis,
169 const AnfNodePtr &axis_size,
170 const AbstractBasePtrList &inputs_abstract_elements,
171 const AbstractBasePtr &out_axes_abstract) const {
172 bool is_out_axes_tuple = out_axes_abstract->isa<abstract::AbstractTuple>();
173 auto inputs_abstract_elements_size = inputs_abstract_elements.size();
174 AbstractBasePtrList out_axes_abstract_elements =
175 GetOutAxesAbstractElements(out_axes_abstract, inputs_abstract_elements_size, is_out_axes_tuple);
176
177 std::vector<AnfNodePtr> vals_out_tuple_cnode_inputs;
178 (void)vals_out_tuple_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
179 constexpr size_t kEachInputsSize = 2;
180 // inputs: (((x1, x1_axis), (x2, x2_axis)), ((y1, y2), y_axis), (z, z_axis))
181 for (int64_t i = 0; i < static_cast<int64_t>(inputs_abstract_elements_size); ++i) {
182 std::vector<AnfNodePtr> each_input_cnode_inputs;
183 (void)each_input_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
184 (void)each_input_cnode_inputs.emplace_back(inputs);
185 (void)each_input_cnode_inputs.emplace_back(NewValueNode(i));
186 auto each_input_cnode = fg_->NewCNode(each_input_cnode_inputs);
187 AnfNodePtr dst_cnode = nullptr;
188 if (is_out_axes_tuple) {
189 dst_cnode = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), out_axis, NewValueNode(i)});
190 } else {
191 dst_cnode = out_axis;
192 }
193 auto each_input_abstract = inputs_abstract_elements[i];
194 AbstractBasePtr dst_abstract = is_out_axes_tuple ? out_axes_abstract_elements[i] : out_axes_abstract;
195 auto each_input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(each_input_abstract);
196 MS_EXCEPTION_IF_NULL(each_input_abstract_tuple);
197 auto each_inputs_abstract_elements = each_input_abstract_tuple->elements();
198 auto each_inputs_abstract_elements_size = each_inputs_abstract_elements.size();
199 if (each_inputs_abstract_elements_size == 0) {
200 MS_LOG(INTERNAL_EXCEPTION) << "Each_inputs_abstract_elements_size is empty";
201 }
202 auto each_inputs_abstract_elements_begin = each_inputs_abstract_elements[0];
203 if (each_inputs_abstract_elements_begin->isa<abstract::AbstractTuple>()) {
204 auto each_inputs_abstract_elements_end = each_inputs_abstract_elements.back();
205 if (each_inputs_abstract_elements_end->isa<abstract::AbstractTuple>()) {
206 // current each input: ((x1, x1_axis), (x2, x2_axis)).
207 std::vector<AnfNodePtr> out_cnode_inputs;
208 (void)out_cnode_inputs.emplace_back(NewValueNode(std::make_shared<VmapMatchOutAxis>("VmapMatchOutAxis")));
209 (void)out_cnode_inputs.emplace_back(each_input_cnode);
210 (void)out_cnode_inputs.emplace_back(dst_cnode);
211 (void)out_cnode_inputs.emplace_back(axis_size);
212 (void)vals_out_tuple_cnode_inputs.emplace_back(fg_->NewCNode(out_cnode_inputs));
213 } else {
214 // current each input: ((y1, y2), y_axis).
215 auto out_cnode = GenerateFuncGraphInnerBroadcastAxis(each_input_cnode, dst_cnode, axis_size,
216 each_inputs_abstract_elements_begin);
217 (void)vals_out_tuple_cnode_inputs.emplace_back(out_cnode);
218 }
219 } else {
220 // current each input: (z, z_axis).
221 if (each_inputs_abstract_elements_size != kEachInputsSize) {
222 MS_LOG(EXCEPTION) << "Each input with no tuple should have only two elements.";
223 }
224 auto val_cnode =
225 fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), each_input_cnode, NewValueNode(static_cast<int64_t>(0))});
226 auto src_cnode =
227 fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), each_input_cnode, NewValueNode(static_cast<int64_t>(1))});
228 auto src_abstract = each_inputs_abstract_elements[1];
229 CNodePtr out_cnode = nullptr;
230 if (src_abstract->isa<abstract::AbstractNone>() && !dst_abstract->isa<abstract::AbstractNone>()) {
231 const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis");
232 auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis);
233 MS_EXCEPTION_IF_NULL(broadcast_by_axis_fg);
234 out_cnode = fg_->NewCNode({NewValueNode(broadcast_by_axis_fg), val_cnode, dst_cnode, axis_size});
235 } else if (!src_abstract->isa<abstract::AbstractNone>() && dst_abstract->isa<abstract::AbstractNone>()) {
236 MS_LOG(EXCEPTION) << "It is invalid that source is not None and dst is None.";
237 } else if (src_abstract->isa<abstract::AbstractNone>() && dst_abstract->isa<abstract::AbstractNone>()) {
238 out_cnode = val_cnode;
239 } else {
240 const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis");
241 auto move_axis_fg = parse::ParsePythonCode(move_axis);
242 MS_EXCEPTION_IF_NULL(move_axis_fg);
243 out_cnode = fg_->NewCNode({NewValueNode(move_axis_fg), val_cnode, src_cnode, dst_cnode});
244 }
245 (void)vals_out_tuple_cnode_inputs.emplace_back(out_cnode);
246 }
247 }
248 return fg_->NewCNode(vals_out_tuple_cnode_inputs);
249 }
250
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)251 FuncGraphPtr VmapMatchOutAxis::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
252 auto args_abs_list_size = args_abs_list.size();
253 constexpr size_t kMetaFGInputSize = 3;
254 if (args_abs_list_size != kMetaFGInputSize) {
255 MS_LOG(EXCEPTION) << "The number of inputs to VmapMatchOutAxis should be 3, but got " << args_abs_list_size << ".";
256 }
257 auto inputs_abstract = args_abs_list[kIndex0];
258 auto out_axes_abstract = args_abs_list[kIndex1];
259 auto axis_size_abstract = args_abs_list[kIndex2];
260 MS_EXCEPTION_IF_NULL(inputs_abstract);
261 MS_EXCEPTION_IF_NULL(out_axes_abstract);
262 MS_EXCEPTION_IF_NULL(axis_size_abstract);
263
264 if (!inputs_abstract->isa<abstract::AbstractTuple>()) {
265 MS_LOG(EXCEPTION) << "The first input to VmapMatchOutAxis is vmap_inputs and should be a tuple but got "
266 << inputs_abstract->ToString() << ".";
267 }
268 auto out_axes_abstract_value = out_axes_abstract->BuildValue();
269 if (out_axes_abstract_value == nullptr || out_axes_abstract_value->ContainsValueAny()) {
270 MS_LOG(EXCEPTION) << "The second input to VmapMatchOutAxis is out_axes and should be a constant value.";
271 }
272 auto axis_size_value = axis_size_abstract->BuildValue();
273 if (axis_size_value == nullptr || !axis_size_value->isa<Int64Imm>()) {
274 MS_LOG(EXCEPTION) << "The third input to VmapMatchOutAxis is axis size and should be a constant unsigned int64 "
275 << " value.";
276 }
277 auto inputs = fg_->add_parameter();
278 auto out_axis = fg_->add_parameter();
279 auto axis_size = fg_->add_parameter();
280
281 auto inputs_abstract_tuple = dyn_cast<abstract::AbstractTuple>(inputs_abstract);
282 auto inputs_abstract_elements = inputs_abstract_tuple->elements();
283 auto inputs_abstract_elements_size = inputs_abstract_elements.size();
284 if (inputs_abstract_elements_size == 0) {
285 MS_LOG(EXCEPTION) << "The input to VmapMatchOutAxis is empty";
286 }
287 auto inputs_abstract_elements_begin = inputs_abstract_elements[0];
288 auto inputs_abstract_elements_end = inputs_abstract_elements[inputs_abstract_elements_size - 1];
289 CNodePtr out_cnode = nullptr;
290 constexpr size_t kInputAbstractElementsSize = 2;
291 if (inputs_abstract_elements_begin->isa<abstract::AbstractTuple>() &&
292 inputs_abstract_elements_end->isa<abstract::AbstractTuple>()) {
293 // All elements in inputs are tuple. The format of input is ((x, x_axis), (y, y_axis), (z, z_axis)).
294 out_cnode =
295 GenerateFuncGraphInnerAllTuple(inputs, out_axis, axis_size, inputs_abstract_elements, out_axes_abstract);
296 } else if (inputs_abstract_elements_begin->isa<abstract::AbstractTuple>() &&
297 !inputs_abstract_elements_end->isa<abstract::AbstractTuple>()) {
298 // The last element of input is axis. The format is ((x, y), None).
299 if (inputs_abstract_elements_size != kInputAbstractElementsSize) {
300 MS_LOG(EXCEPTION) << "The length of elements should be 2 but got: " << inputs_abstract_elements_size << ".";
301 }
302 out_cnode = GenerateFuncGraphInnerBroadcastAxis(inputs, out_axis, axis_size, inputs_abstract_elements_begin);
303 } else {
304 // Single tuple element. (x, None)
305 if (inputs_abstract_elements_size != kInputAbstractElementsSize) {
306 MS_LOG(EXCEPTION) << "The length of elements should be 2 but got: " << inputs_abstract_elements_size << ".";
307 }
308 out_cnode = GenerateFuncGraphInnerSingleElement(inputs, out_axis, axis_size, inputs_abstract_elements_end);
309 }
310 fg_->set_output(out_cnode);
311 return fg_;
312 }
313
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)314 FuncGraphPtr VmapGeneralPreprocess::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
315 FuncGraphPtr fg = std::make_shared<FuncGraph>();
316 auto prim = fg->add_parameter();
317 auto args_size = args_abs_list.size();
318 if (args_size <= 1) {
319 MS_LOG(EXCEPTION) << "The length of input to VmapGeneralPreprocess must be greater than 1";
320 }
321 int64_t inputs_size = SizeToLong(args_size - 1);
322 int64_t tuple_elements_num = 0;
323 uint32_t offset = 1;
324 auto get_tuple_elements = [args_size, &tuple_elements_num, &inputs_size,
325 &offset](const AbstractBasePtrList &args_abs_list) -> AbstractBasePtrList {
326 auto arg = args_abs_list[1];
327 if (!arg->isa<abstract::AbstractSequence>()) {
328 MS_LOG(EXCEPTION) << "The second input to VmapGeneralPreprocess should be AbstractSequence but got: "
329 << arg->ToString() << ".";
330 }
331 auto arg_seq = arg->cast<abstract::AbstractSequencePtr>();
332 const auto &arg_tuple_elements = arg_seq->elements();
333 if (arg_tuple_elements.back()->isa<abstract::AbstractTuple>()) {
334 // Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped
335 // into a tuple. We need to process the internal elements separately and then re-wrap
336 // them into tuple. Handle case such as args:(((A, 0), (B, 1), (C, None)), ...). Which
337 // different from the case with single input parameter ((A, 0),).
338 //
339 // Tuple case:
340 // 1. Only one tuple input: (((A, 0), (B, 1), (C, None)),)
341 // 2. A tuple input and some normal inputs: (((A, 0), (B, 1), (C, None)), (a, 2), (b, 3))
342 tuple_elements_num = arg_tuple_elements.size();
343 inputs_size = tuple_elements_num + inputs_size - 1;
344 offset = 0;
345 AbstractBasePtrList unfold_args_abs_list(arg_tuple_elements.begin(), arg_tuple_elements.end());
346 unfold_args_abs_list.insert(unfold_args_abs_list.end(), args_abs_list.begin() + 2,
347 args_abs_list.end()); // the maybe left inputs.
348 return unfold_args_abs_list;
349 }
350 return args_abs_list;
351 };
352 auto unfold_elements = get_tuple_elements(args_abs_list);
353 bool is_all_none = true;
354 constexpr size_t kCurTupleSize = 2;
355 for (int64_t i = 0; i < inputs_size; ++i) {
356 auto cur_arg = unfold_elements[i + offset];
357 if (!cur_arg->isa<abstract::AbstractTuple>()) {
358 MS_LOG(EXCEPTION) << "The " << i + offset
359 << "th input to VmapGeneralPreprocess should be AbstractTuple but got: " << cur_arg->ToString()
360 << ".";
361 }
362 auto cur_arg_tuple = cur_arg->cast<abstract::AbstractTuplePtr>();
363 auto cur_arg_tuple_elements = cur_arg_tuple->elements();
364 if (cur_arg_tuple_elements.size() != kCurTupleSize) {
365 MS_LOG(EXCEPTION) << "The " << i + offset << "th input to VmapGeneralPreprocess should be a tuple with two "
366 << "elements but got " << cur_arg_tuple_elements.size() << " elements.";
367 }
368 if (!cur_arg_tuple_elements[kDimIndex]->isa<abstract::AbstractNone>()) {
369 MS_LOG(INFO) << "The " << i + offset << "th input to VmapGeneralPreprocess has not None dim value.";
370 is_all_none = false;
371 break;
372 }
373 }
374
375 std::vector<AnfNodePtr> output_cnode_inputs;
376 (void)output_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
377 if (!is_all_none) {
378 for (size_t i = 1; i < args_size; ++i) {
379 (void)fg->add_parameter();
380 }
381 auto output_cnode = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(false), NewValueNode(kNone)});
382 fg->set_output(output_cnode);
383 } else {
384 GenerateFuncGraphAllNone(fg, prim, inputs_size, tuple_elements_num, true);
385 }
386 return fg;
387 }
388
389 /// \brief ConstructMapInput.
390 ///
391 /// \param[in] unfold_elements_abstract Unfold elements abstract, such as ((A, 0), (B, 0), (C, None)).
392 /// \param[in] args_size The size of elements.
393 /// \param[in] tuple_elements_num The elements-size for first tuple input.
394 /// \return A vector of AnfNodePtrList, the size is equal to vmap dim size.
ConstructMapInput(const InputsAbstractList & unfold_elements_abstract,int64_t args_size,int64_t tuple_elements_num)395 CNodeInpusList VmapGeneralRule::ConstructMapInput(const InputsAbstractList &unfold_elements_abstract, int64_t args_size,
396 int64_t tuple_elements_num) {
397 AnfNodePtr single_input = nullptr;
398 if (tuple_elements_num != 0) {
399 single_input = fg_->add_parameter();
400 }
401
402 CNodeInpusList map_inputs(axis_size_);
403 for (int64_t i = 0; i < args_size; ++i) {
404 AnfNodePtr cur_arg_node = nullptr;
405 if (i < tuple_elements_num) {
406 cur_arg_node = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), single_input, NewValueNode(i)});
407 } else {
408 cur_arg_node = fg_->add_parameter();
409 }
410 auto unfold_element_abstract = unfold_elements_abstract[i];
411 auto val_abstract = unfold_element_abstract[kValIndex];
412 auto dim_abstract = unfold_element_abstract[kDimIndex];
413 AnfNodePtr val_cnode =
414 fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cur_arg_node, NewValueNode(kValIndex)});
415
416 if (dim_abstract->isa<abstract::AbstractNone>()) {
417 for (int64_t m = 0; m < axis_size_; ++m) {
418 map_inputs[m].push_back(val_cnode);
419 }
420 } else {
421 if (!val_abstract->isa<abstract::AbstractTensor>()) {
422 MS_LOG(EXCEPTION) << "A variable of type other than `Tensor` is accepted, but the source axis is not `None`";
423 }
424 AnfNodePtr dim_cnode =
425 fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cur_arg_node, NewValueNode(kDimIndex)});
426 const py::function unstack_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_unstack");
427 auto unstack_fg_ = parse::ParsePythonCode(unstack_fn);
428 MS_EXCEPTION_IF_NULL(unstack_fg_);
429 auto out_cnode = fg_->NewCNode({NewValueNode(unstack_fg_), dim_cnode, val_cnode});
430 for (int64_t m = 0; m < axis_size_; ++m) {
431 auto out_element_cnode = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), out_cnode, NewValueNode(m)});
432 map_inputs[m].push_back(out_element_cnode);
433 }
434 }
435 }
436 return map_inputs;
437 }
438
439 // When the primitive does not registered the relevant specific VmapRule, it attempts to get
440 // this the general rule. The general rule is combining loop and stack operators to simulate
441 // the behavior of Vmap. Noted that, general rules does not guarantee the correctness of
442 // execution results.
443 // Currently, only the following types of primitives are supported:
444 // 1、 Most calculation operations, whose inputs are tensors, scalars or both of them.
445 // (If all elements in a tuple are scalars, it is also considered scalar.)
446 // 2、 Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped into a tuple.
447 // 3、 Operators with indefinite inputs length, whose first inputs is wrapped into a tuple.
448 // In other words, we do not support any tuple wrapped variables except for the special cases
449 // listed above.
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)450 FuncGraphPtr VmapGeneralRule::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
451 fg_ = std::make_shared<FuncGraph>();
452 int64_t args_size = static_cast<int64_t>(args_abs_list.size());
453 int64_t tuple_elements_num = 0;
454 auto get_tuple_elements = [&args_size,
455 &tuple_elements_num](const AbstractBasePtrList &args_abs_list) -> AbstractBasePtrList {
456 auto arg = args_abs_list[0];
457 if (!arg->isa<abstract::AbstractTuple>()) {
458 MS_LOG(EXCEPTION) << "The first input to VmapGeneralPreprocess should be AbstractTuple but got: "
459 << arg->ToString() << ".";
460 }
461 auto arg_tuple = arg->cast<abstract::AbstractTuplePtr>();
462 const auto &arg_tuple_elements = arg_tuple->elements();
463 if (arg_tuple_elements.back()->isa<abstract::AbstractTuple>()) {
464 // Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped
465 // into a tuple. We need to process the internal elements separately and then re-wrap
466 // them into tuple. Handle case such as args:(((A, 0), (B, 1), (C, None)), ...). Which
467 // different from the case with single input parameter ((A, 0),).
468 //
469 // Tuple case:
470 // 1. Only one tuple input: (((A, 0), (B, 1), (C, None)),)
471 // 2. A tuple input and some normal inputs: (((A, 0), (B, 1), (C, None)), (a, 2), (b, 3))
472 tuple_elements_num = arg_tuple_elements.size();
473 args_size = tuple_elements_num + args_size - 1;
474 AbstractBasePtrList unfold_args_abs_list(arg_tuple_elements.begin(), arg_tuple_elements.end());
475 unfold_args_abs_list.insert(unfold_args_abs_list.end(), args_abs_list.begin() + 1,
476 args_abs_list.end()); // the maybe left inputs.
477 return unfold_args_abs_list;
478 }
479
480 return args_abs_list;
481 };
482 auto unfold_elements = get_tuple_elements(
483 args_abs_list); // ((A, 0), (B, 1), ...), if tuple is the first input, its elements will be unfold.
484
485 bool is_all_none = true;
486 constexpr size_t kCurTupleSize = 2;
487 InputsAbstractList unfold_elements_abstract(args_size);
488 for (int64_t i = 0; i < args_size; ++i) {
489 auto cur_arg = unfold_elements[i];
490 if (!cur_arg->isa<abstract::AbstractTuple>()) {
491 MS_LOG(EXCEPTION) << "The " << i
492 << "th input to VmapGeneralPreprocess should be AbstractTuple but got: " << cur_arg->ToString()
493 << ".";
494 }
495 auto cur_arg_tuple = cur_arg->cast<abstract::AbstractTuplePtr>();
496 auto cur_arg_tuple_elements = cur_arg_tuple->elements();
497 if (cur_arg_tuple_elements.size() != kCurTupleSize) {
498 MS_LOG(EXCEPTION) << "The " << i << "th input to VmapGeneralPreprocess should be a tuple with two "
499 << "elements but got " << cur_arg_tuple_elements.size() << " elements.";
500 }
501 auto dim_abstract = cur_arg_tuple_elements[kDimIndex];
502 if (is_all_none && !dim_abstract->isa<abstract::AbstractNone>()) {
503 MS_LOG(INFO) << "The " << i << "th input to VmapGeneralPreprocess has not None dim value.";
504 is_all_none = false;
505 }
506 auto val_abstract = cur_arg_tuple_elements[kValIndex];
507 std::vector<abstract::AbstractBasePtr> element_abstract = {val_abstract, dim_abstract};
508 unfold_elements_abstract[i] = element_abstract;
509 }
510
511 if (is_all_none) {
512 GenerateFuncGraphAllNone(fg_, NewValueNode(prim_), args_size, tuple_elements_num, false);
513 return fg_;
514 }
515
516 CNodeInpusList map_inputs = ConstructMapInput(unfold_elements_abstract, args_size, tuple_elements_num); //
517
518 std::vector<AnfNodePtr> output_cnode_inputs;
519 (void)output_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
520 for (auto map_input : map_inputs) {
521 std::vector<AnfNodePtr> output_element_cnode_inputs;
522 if (tuple_elements_num != 0) {
523 std::vector<AnfNodePtr> tuple_cnode_inputs{NewValueNode(prim::kPrimMakeTuple)};
524 (void)tuple_cnode_inputs.insert(tuple_cnode_inputs.cend(), map_input.cbegin(),
525 map_input.cbegin() + tuple_elements_num);
526 auto tuple_cnode = fg_->NewCNode(tuple_cnode_inputs);
527 output_element_cnode_inputs.push_back(NewValueNode(prim_));
528 output_element_cnode_inputs.push_back(tuple_cnode);
529 output_element_cnode_inputs.insert(output_element_cnode_inputs.end(), map_input.cbegin() + tuple_elements_num,
530 map_input.cend());
531 } else {
532 output_element_cnode_inputs.push_back(NewValueNode(prim_));
533 (void)output_element_cnode_inputs.insert(output_element_cnode_inputs.cend(), map_input.cbegin(),
534 map_input.cend());
535 }
536 auto output_element_cnode = fg_->NewCNode(output_element_cnode_inputs);
537 (void)output_cnode_inputs.emplace_back(output_element_cnode);
538 }
539 auto output_cnode = fg_->NewCNode(output_cnode_inputs);
540 const py::function vmap_general_output_process_fn =
541 python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_general_output_process");
542 auto vmap_general_output_process_fg_ = parse::ParsePythonCode(vmap_general_output_process_fn);
543 MS_EXCEPTION_IF_NULL(vmap_general_output_process_fg_);
544 auto vmap_general_output = fg_->NewCNode({NewValueNode(vmap_general_output_process_fg_), output_cnode});
545 fg_->set_output(vmap_general_output);
546 return fg_;
547 }
548 } // namespace prim
549 } // namespace mindspore
550