• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/operator/composite/vmap.h"
18 
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "pybind11/pybind11.h"
24 #include "ir/anf.h"
25 #include "ir/func_graph.h"
26 #include "abstract/abstract_value.h"
27 #include "abstract/abstract_function.h"
28 #include "pipeline/jit/ps/parse/parse_base.h"
29 #include "pipeline/jit/ps/parse/parse.h"
30 #include "pipeline/jit/ps/parse/resolve.h"
31 #include "pipeline/jit/ps/pipeline.h"
32 #include "include/common/utils/python_adapter.h"
33 #include "include/common/pybind_api/api_register.h"
34 
35 namespace mindspore {
36 // namespace to support composite operators definition
37 namespace prim {
GenerateFuncGraphAllNone(const FuncGraphPtr & fg,const AnfNodePtr & prim,int64_t args_size,int64_t tuple_elements_num,bool bind)38 void GenerateFuncGraphAllNone(const FuncGraphPtr &fg, const AnfNodePtr &prim, int64_t args_size,
39                               int64_t tuple_elements_num, bool bind) {
40   std::vector<AnfNodePtr> prim_output_cnode_inputs;
41   (void)prim_output_cnode_inputs.emplace_back(prim);
42   if (tuple_elements_num != 0) {
43     auto val_in_param = fg->add_parameter();
44     std::vector<AnfNodePtr> prim_inputs_cnode_inputs;
45     (void)prim_inputs_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
46     for (int64_t i = 0; i < tuple_elements_num; ++i) {
47       auto val_in_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_param, NewValueNode(i)});
48       auto val_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_cnode, NewValueNode(kValIndex)});
49       (void)prim_inputs_cnode_inputs.emplace_back(val_cnode);
50     }
51     auto prim_inputs_cnode = fg->NewCNode(prim_inputs_cnode_inputs);
52     (void)prim_output_cnode_inputs.emplace_back(prim_inputs_cnode);
53     args_size = args_size - tuple_elements_num;
54   }
55 
56   for (int64_t i = 0; i < args_size; ++i) {
57     auto val_in_param = fg->add_parameter();
58     auto val_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_param, NewValueNode(kValIndex)});
59     (void)prim_output_cnode_inputs.emplace_back(val_cnode);
60   }
61 
62   auto prim_output_cnode = fg->NewCNode(prim_output_cnode_inputs);
63   const py::function bind_all_none_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_bind_all_none");
64   auto bind_all_none_fg = parse::ParsePythonCode(bind_all_none_fn);
65   MS_EXCEPTION_IF_NULL(bind_all_none_fg);
66   auto bind_all_none_cnode = fg->NewCNode({NewValueNode(bind_all_none_fg), prim_output_cnode});
67   if (bind) {
68     auto output_cnode = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(true), bind_all_none_cnode});
69     fg->set_output(output_cnode);
70     return;
71   }
72   fg->set_output(bind_all_none_cnode);
73   return;
74 }
75 
GenerateFuncGraphInnerBroadcastAxis(const AnfNodePtr & inputs,const AnfNodePtr & out_axis,const AnfNodePtr & axis_size,const AbstractBasePtr & inputs_abstract_elements_begin) const76 CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerBroadcastAxis(
77   const AnfNodePtr &inputs, const AnfNodePtr &out_axis, const AnfNodePtr &axis_size,
78   const AbstractBasePtr &inputs_abstract_elements_begin) const {
79   std::vector<AnfNodePtr> value_cnode_inputs;
80   (void)value_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
81   (void)value_cnode_inputs.emplace_back(inputs);
82   (void)value_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(0)));
83   auto value_cnode = fg_->NewCNode(value_cnode_inputs);
84   std::vector<AnfNodePtr> dim_cnode_inputs;
85   (void)dim_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
86   (void)dim_cnode_inputs.emplace_back(inputs);
87   (void)dim_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(1)));
88   auto dim_cnode = fg_->NewCNode(dim_cnode_inputs);
89 
90   std::vector<AnfNodePtr> sub_inputs_cnode_inputs;
91   (void)sub_inputs_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
92   auto inputs_abstract_elements_begin_tuple = dyn_cast<abstract::AbstractTuple>(inputs_abstract_elements_begin);
93   auto inputs_abstract_elements_begin_tuple_elements = inputs_abstract_elements_begin_tuple->elements();
94   // inputs: ((x, y), None) -> ((x, None), (y, None)).
95   int64_t begin_tuple_size = static_cast<int64_t>(inputs_abstract_elements_begin_tuple_elements.size());
96   for (int64_t i = 0; i < begin_tuple_size; ++i) {
97     std::vector<AnfNodePtr> cur_tuple_getitem_inputs;
98     (void)cur_tuple_getitem_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
99     (void)cur_tuple_getitem_inputs.emplace_back(value_cnode);
100     (void)cur_tuple_getitem_inputs.emplace_back(NewValueNode(i));
101     auto cur_value_cnode = fg_->NewCNode(cur_tuple_getitem_inputs);
102     std::vector<AnfNodePtr> cur_make_tuple_cnode_inputs;
103     (void)cur_make_tuple_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
104     (void)cur_make_tuple_cnode_inputs.emplace_back(cur_value_cnode);
105     (void)cur_make_tuple_cnode_inputs.emplace_back(dim_cnode);
106     auto cur_make_tuple_cnode = fg_->NewCNode(cur_make_tuple_cnode_inputs);
107     (void)sub_inputs_cnode_inputs.emplace_back(cur_make_tuple_cnode);
108   }
109   auto sub_inputs_cnode = fg_->NewCNode(sub_inputs_cnode_inputs);
110   std::vector<AnfNodePtr> out_cnode_inputs;
111   (void)out_cnode_inputs.emplace_back(NewValueNode(std::make_shared<VmapMatchOutAxis>("VmapMatchOutAxis")));
112   (void)out_cnode_inputs.emplace_back(sub_inputs_cnode);
113   (void)out_cnode_inputs.emplace_back(out_axis);
114   (void)out_cnode_inputs.emplace_back(axis_size);
115   return fg_->NewCNode(out_cnode_inputs);
116 }
117 
GenerateFuncGraphInnerSingleElement(const AnfNodePtr & inputs,const AnfNodePtr & out_axis,const AnfNodePtr & axis_size,const AbstractBasePtr & inputs_abstract_elements_end) const118 CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement(
119   const AnfNodePtr &inputs, const AnfNodePtr &out_axis, const AnfNodePtr &axis_size,
120   const AbstractBasePtr &inputs_abstract_elements_end) const {
121   std::vector<AnfNodePtr> value_cnode_inputs;
122   (void)value_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
123   (void)value_cnode_inputs.emplace_back(inputs);
124   (void)value_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(0)));
125   auto value_cnode = fg_->NewCNode(value_cnode_inputs);
126   std::vector<AnfNodePtr> out_cnode_inputs;
127   if (inputs_abstract_elements_end->isa<abstract::AbstractNone>()) {
128     const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis");
129     auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis);
130     MS_EXCEPTION_IF_NULL(broadcast_by_axis_fg);
131     (void)out_cnode_inputs.emplace_back(NewValueNode(broadcast_by_axis_fg));
132     (void)out_cnode_inputs.emplace_back(value_cnode);
133     (void)out_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(0)));
134     (void)out_cnode_inputs.emplace_back(axis_size);
135   } else {
136     std::vector<AnfNodePtr> dim_cnode_inputs;
137     (void)dim_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
138     (void)dim_cnode_inputs.emplace_back(inputs);
139     (void)dim_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(1)));
140     auto dim_cnode = fg_->NewCNode(dim_cnode_inputs);
141     const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis");
142     auto move_axis_fg = parse::ParsePythonCode(move_axis);
143     MS_EXCEPTION_IF_NULL(move_axis_fg);
144     (void)out_cnode_inputs.emplace_back(NewValueNode(move_axis_fg));
145     (void)out_cnode_inputs.emplace_back(value_cnode);
146     (void)out_cnode_inputs.emplace_back(dim_cnode);
147     (void)out_cnode_inputs.emplace_back(out_axis);
148   }
149   return fg_->NewCNode(out_cnode_inputs);
150 }
151 
152 namespace {
GetOutAxesAbstractElements(const AbstractBasePtr & out_axes_abstract,size_t inputs_abstract_elements_size,bool is_out_axes_tuple)153 AbstractBasePtrList GetOutAxesAbstractElements(const AbstractBasePtr &out_axes_abstract,
154                                                size_t inputs_abstract_elements_size, bool is_out_axes_tuple) {
155   AbstractBasePtrList out_axes_abstract_elements;
156   if (!is_out_axes_tuple) {
157     return out_axes_abstract_elements;
158   }
159   abstract::AbstractTuplePtr out_axes_abstract_tuple = dyn_cast<abstract::AbstractTuple>(out_axes_abstract);
160   out_axes_abstract_elements = out_axes_abstract_tuple->elements();
161   if (out_axes_abstract_elements.size() != inputs_abstract_elements_size) {
162     MS_LOG(EXCEPTION) << "The length of out_axes and inputs do not match. ";
163   }
164   return out_axes_abstract_elements;
165 }
166 }  // namespace
167 
GenerateFuncGraphInnerAllTuple(const AnfNodePtr & inputs,const AnfNodePtr & out_axis,const AnfNodePtr & axis_size,const AbstractBasePtrList & inputs_abstract_elements,const AbstractBasePtr & out_axes_abstract) const168 CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inputs, const AnfNodePtr &out_axis,
169                                                           const AnfNodePtr &axis_size,
170                                                           const AbstractBasePtrList &inputs_abstract_elements,
171                                                           const AbstractBasePtr &out_axes_abstract) const {
172   bool is_out_axes_tuple = out_axes_abstract->isa<abstract::AbstractTuple>();
173   auto inputs_abstract_elements_size = inputs_abstract_elements.size();
174   AbstractBasePtrList out_axes_abstract_elements =
175     GetOutAxesAbstractElements(out_axes_abstract, inputs_abstract_elements_size, is_out_axes_tuple);
176 
177   std::vector<AnfNodePtr> vals_out_tuple_cnode_inputs;
178   (void)vals_out_tuple_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
179   constexpr size_t kEachInputsSize = 2;
180   // inputs: (((x1, x1_axis), (x2, x2_axis)), ((y1, y2), y_axis), (z, z_axis))
181   for (int64_t i = 0; i < static_cast<int64_t>(inputs_abstract_elements_size); ++i) {
182     std::vector<AnfNodePtr> each_input_cnode_inputs;
183     (void)each_input_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
184     (void)each_input_cnode_inputs.emplace_back(inputs);
185     (void)each_input_cnode_inputs.emplace_back(NewValueNode(i));
186     auto each_input_cnode = fg_->NewCNode(each_input_cnode_inputs);
187     AnfNodePtr dst_cnode = nullptr;
188     if (is_out_axes_tuple) {
189       dst_cnode = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), out_axis, NewValueNode(i)});
190     } else {
191       dst_cnode = out_axis;
192     }
193     auto each_input_abstract = inputs_abstract_elements[i];
194     AbstractBasePtr dst_abstract = is_out_axes_tuple ? out_axes_abstract_elements[i] : out_axes_abstract;
195     auto each_input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(each_input_abstract);
196     MS_EXCEPTION_IF_NULL(each_input_abstract_tuple);
197     auto each_inputs_abstract_elements = each_input_abstract_tuple->elements();
198     auto each_inputs_abstract_elements_size = each_inputs_abstract_elements.size();
199     if (each_inputs_abstract_elements_size == 0) {
200       MS_LOG(INTERNAL_EXCEPTION) << "Each_inputs_abstract_elements_size is empty";
201     }
202     auto each_inputs_abstract_elements_begin = each_inputs_abstract_elements[0];
203     if (each_inputs_abstract_elements_begin->isa<abstract::AbstractTuple>()) {
204       auto each_inputs_abstract_elements_end = each_inputs_abstract_elements.back();
205       if (each_inputs_abstract_elements_end->isa<abstract::AbstractTuple>()) {
206         // current each input: ((x1, x1_axis), (x2, x2_axis)).
207         std::vector<AnfNodePtr> out_cnode_inputs;
208         (void)out_cnode_inputs.emplace_back(NewValueNode(std::make_shared<VmapMatchOutAxis>("VmapMatchOutAxis")));
209         (void)out_cnode_inputs.emplace_back(each_input_cnode);
210         (void)out_cnode_inputs.emplace_back(dst_cnode);
211         (void)out_cnode_inputs.emplace_back(axis_size);
212         (void)vals_out_tuple_cnode_inputs.emplace_back(fg_->NewCNode(out_cnode_inputs));
213       } else {
214         // current each input: ((y1, y2), y_axis).
215         auto out_cnode = GenerateFuncGraphInnerBroadcastAxis(each_input_cnode, dst_cnode, axis_size,
216                                                              each_inputs_abstract_elements_begin);
217         (void)vals_out_tuple_cnode_inputs.emplace_back(out_cnode);
218       }
219     } else {
220       // current each input: (z, z_axis).
221       if (each_inputs_abstract_elements_size != kEachInputsSize) {
222         MS_LOG(EXCEPTION) << "Each input with no tuple should have only two elements.";
223       }
224       auto val_cnode =
225         fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), each_input_cnode, NewValueNode(static_cast<int64_t>(0))});
226       auto src_cnode =
227         fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), each_input_cnode, NewValueNode(static_cast<int64_t>(1))});
228       auto src_abstract = each_inputs_abstract_elements[1];
229       CNodePtr out_cnode = nullptr;
230       if (src_abstract->isa<abstract::AbstractNone>() && !dst_abstract->isa<abstract::AbstractNone>()) {
231         const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis");
232         auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis);
233         MS_EXCEPTION_IF_NULL(broadcast_by_axis_fg);
234         out_cnode = fg_->NewCNode({NewValueNode(broadcast_by_axis_fg), val_cnode, dst_cnode, axis_size});
235       } else if (!src_abstract->isa<abstract::AbstractNone>() && dst_abstract->isa<abstract::AbstractNone>()) {
236         MS_LOG(EXCEPTION) << "It is invalid that source is not None and dst is None.";
237       } else if (src_abstract->isa<abstract::AbstractNone>() && dst_abstract->isa<abstract::AbstractNone>()) {
238         out_cnode = val_cnode;
239       } else {
240         const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis");
241         auto move_axis_fg = parse::ParsePythonCode(move_axis);
242         MS_EXCEPTION_IF_NULL(move_axis_fg);
243         out_cnode = fg_->NewCNode({NewValueNode(move_axis_fg), val_cnode, src_cnode, dst_cnode});
244       }
245       (void)vals_out_tuple_cnode_inputs.emplace_back(out_cnode);
246     }
247   }
248   return fg_->NewCNode(vals_out_tuple_cnode_inputs);
249 }
250 
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)251 FuncGraphPtr VmapMatchOutAxis::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
252   auto args_abs_list_size = args_abs_list.size();
253   constexpr size_t kMetaFGInputSize = 3;
254   if (args_abs_list_size != kMetaFGInputSize) {
255     MS_LOG(EXCEPTION) << "The number of inputs to VmapMatchOutAxis should be 3, but got " << args_abs_list_size << ".";
256   }
257   auto inputs_abstract = args_abs_list[kIndex0];
258   auto out_axes_abstract = args_abs_list[kIndex1];
259   auto axis_size_abstract = args_abs_list[kIndex2];
260   MS_EXCEPTION_IF_NULL(inputs_abstract);
261   MS_EXCEPTION_IF_NULL(out_axes_abstract);
262   MS_EXCEPTION_IF_NULL(axis_size_abstract);
263 
264   if (!inputs_abstract->isa<abstract::AbstractTuple>()) {
265     MS_LOG(EXCEPTION) << "The first input to VmapMatchOutAxis is vmap_inputs and should be a tuple but got "
266                       << inputs_abstract->ToString() << ".";
267   }
268   auto out_axes_abstract_value = out_axes_abstract->BuildValue();
269   if (out_axes_abstract_value == nullptr || out_axes_abstract_value->ContainsValueAny()) {
270     MS_LOG(EXCEPTION) << "The second input to VmapMatchOutAxis is out_axes and should be a constant value.";
271   }
272   auto axis_size_value = axis_size_abstract->BuildValue();
273   if (axis_size_value == nullptr || !axis_size_value->isa<Int64Imm>()) {
274     MS_LOG(EXCEPTION) << "The third input to VmapMatchOutAxis is axis size and should be a constant unsigned int64 "
275                       << " value.";
276   }
277   auto inputs = fg_->add_parameter();
278   auto out_axis = fg_->add_parameter();
279   auto axis_size = fg_->add_parameter();
280 
281   auto inputs_abstract_tuple = dyn_cast<abstract::AbstractTuple>(inputs_abstract);
282   auto inputs_abstract_elements = inputs_abstract_tuple->elements();
283   auto inputs_abstract_elements_size = inputs_abstract_elements.size();
284   if (inputs_abstract_elements_size == 0) {
285     MS_LOG(EXCEPTION) << "The input to VmapMatchOutAxis is empty";
286   }
287   auto inputs_abstract_elements_begin = inputs_abstract_elements[0];
288   auto inputs_abstract_elements_end = inputs_abstract_elements[inputs_abstract_elements_size - 1];
289   CNodePtr out_cnode = nullptr;
290   constexpr size_t kInputAbstractElementsSize = 2;
291   if (inputs_abstract_elements_begin->isa<abstract::AbstractTuple>() &&
292       inputs_abstract_elements_end->isa<abstract::AbstractTuple>()) {
293     // All elements in inputs are tuple. The format of input is ((x, x_axis), (y, y_axis), (z, z_axis)).
294     out_cnode =
295       GenerateFuncGraphInnerAllTuple(inputs, out_axis, axis_size, inputs_abstract_elements, out_axes_abstract);
296   } else if (inputs_abstract_elements_begin->isa<abstract::AbstractTuple>() &&
297              !inputs_abstract_elements_end->isa<abstract::AbstractTuple>()) {
298     // The last element of input is axis. The format is ((x, y), None).
299     if (inputs_abstract_elements_size != kInputAbstractElementsSize) {
300       MS_LOG(EXCEPTION) << "The length of elements should be 2 but got: " << inputs_abstract_elements_size << ".";
301     }
302     out_cnode = GenerateFuncGraphInnerBroadcastAxis(inputs, out_axis, axis_size, inputs_abstract_elements_begin);
303   } else {
304     // Single tuple element. (x, None)
305     if (inputs_abstract_elements_size != kInputAbstractElementsSize) {
306       MS_LOG(EXCEPTION) << "The length of elements should be 2 but got: " << inputs_abstract_elements_size << ".";
307     }
308     out_cnode = GenerateFuncGraphInnerSingleElement(inputs, out_axis, axis_size, inputs_abstract_elements_end);
309   }
310   fg_->set_output(out_cnode);
311   return fg_;
312 }
313 
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)314 FuncGraphPtr VmapGeneralPreprocess::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
315   FuncGraphPtr fg = std::make_shared<FuncGraph>();
316   auto prim = fg->add_parameter();
317   auto args_size = args_abs_list.size();
318   if (args_size <= 1) {
319     MS_LOG(EXCEPTION) << "The length of input to VmapGeneralPreprocess must be greater than 1";
320   }
321   int64_t inputs_size = SizeToLong(args_size - 1);
322   int64_t tuple_elements_num = 0;
323   uint32_t offset = 1;
324   auto get_tuple_elements = [args_size, &tuple_elements_num, &inputs_size,
325                              &offset](const AbstractBasePtrList &args_abs_list) -> AbstractBasePtrList {
326     auto arg = args_abs_list[1];
327     if (!arg->isa<abstract::AbstractSequence>()) {
328       MS_LOG(EXCEPTION) << "The second input to VmapGeneralPreprocess should be AbstractSequence but got: "
329                         << arg->ToString() << ".";
330     }
331     auto arg_seq = arg->cast<abstract::AbstractSequencePtr>();
332     const auto &arg_tuple_elements = arg_seq->elements();
333     if (arg_tuple_elements.back()->isa<abstract::AbstractTuple>()) {
334       // Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped
335       // into a tuple. We need to process the internal elements separately and then re-wrap
336       // them into tuple. Handle case such as args:(((A, 0), (B, 1), (C, None)), ...). Which
337       // different from the case with single input parameter ((A, 0),).
338       //
339       // Tuple case:
340       // 1. Only one tuple input: (((A, 0), (B, 1), (C, None)),)
341       // 2. A tuple input and some normal inputs: (((A, 0), (B, 1), (C, None)), (a, 2), (b, 3))
342       tuple_elements_num = arg_tuple_elements.size();
343       inputs_size = tuple_elements_num + inputs_size - 1;
344       offset = 0;
345       AbstractBasePtrList unfold_args_abs_list(arg_tuple_elements.begin(), arg_tuple_elements.end());
346       unfold_args_abs_list.insert(unfold_args_abs_list.end(), args_abs_list.begin() + 2,
347                                   args_abs_list.end());  // the maybe left inputs.
348       return unfold_args_abs_list;
349     }
350     return args_abs_list;
351   };
352   auto unfold_elements = get_tuple_elements(args_abs_list);
353   bool is_all_none = true;
354   constexpr size_t kCurTupleSize = 2;
355   for (int64_t i = 0; i < inputs_size; ++i) {
356     auto cur_arg = unfold_elements[i + offset];
357     if (!cur_arg->isa<abstract::AbstractTuple>()) {
358       MS_LOG(EXCEPTION) << "The " << i + offset
359                         << "th input to VmapGeneralPreprocess should be AbstractTuple but got: " << cur_arg->ToString()
360                         << ".";
361     }
362     auto cur_arg_tuple = cur_arg->cast<abstract::AbstractTuplePtr>();
363     auto cur_arg_tuple_elements = cur_arg_tuple->elements();
364     if (cur_arg_tuple_elements.size() != kCurTupleSize) {
365       MS_LOG(EXCEPTION) << "The " << i + offset << "th input to VmapGeneralPreprocess should be a tuple with two "
366                         << "elements but got " << cur_arg_tuple_elements.size() << " elements.";
367     }
368     if (!cur_arg_tuple_elements[kDimIndex]->isa<abstract::AbstractNone>()) {
369       MS_LOG(INFO) << "The " << i + offset << "th input to VmapGeneralPreprocess has not None dim value.";
370       is_all_none = false;
371       break;
372     }
373   }
374 
375   std::vector<AnfNodePtr> output_cnode_inputs;
376   (void)output_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
377   if (!is_all_none) {
378     for (size_t i = 1; i < args_size; ++i) {
379       (void)fg->add_parameter();
380     }
381     auto output_cnode = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(false), NewValueNode(kNone)});
382     fg->set_output(output_cnode);
383   } else {
384     GenerateFuncGraphAllNone(fg, prim, inputs_size, tuple_elements_num, true);
385   }
386   return fg;
387 }
388 
389 /// \brief ConstructMapInput.
390 ///
391 /// \param[in] unfold_elements_abstract Unfold elements abstract, such as ((A, 0), (B, 0), (C, None)).
392 /// \param[in] args_size The size of elements.
393 /// \param[in] tuple_elements_num The elements-size for first tuple input.
394 /// \return A vector of AnfNodePtrList, the size is equal to vmap dim size.
ConstructMapInput(const InputsAbstractList & unfold_elements_abstract,int64_t args_size,int64_t tuple_elements_num)395 CNodeInpusList VmapGeneralRule::ConstructMapInput(const InputsAbstractList &unfold_elements_abstract, int64_t args_size,
396                                                   int64_t tuple_elements_num) {
397   AnfNodePtr single_input = nullptr;
398   if (tuple_elements_num != 0) {
399     single_input = fg_->add_parameter();
400   }
401 
402   CNodeInpusList map_inputs(axis_size_);
403   for (int64_t i = 0; i < args_size; ++i) {
404     AnfNodePtr cur_arg_node = nullptr;
405     if (i < tuple_elements_num) {
406       cur_arg_node = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), single_input, NewValueNode(i)});
407     } else {
408       cur_arg_node = fg_->add_parameter();
409     }
410     auto unfold_element_abstract = unfold_elements_abstract[i];
411     auto val_abstract = unfold_element_abstract[kValIndex];
412     auto dim_abstract = unfold_element_abstract[kDimIndex];
413     AnfNodePtr val_cnode =
414       fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cur_arg_node, NewValueNode(kValIndex)});
415 
416     if (dim_abstract->isa<abstract::AbstractNone>()) {
417       for (int64_t m = 0; m < axis_size_; ++m) {
418         map_inputs[m].push_back(val_cnode);
419       }
420     } else {
421       if (!val_abstract->isa<abstract::AbstractTensor>()) {
422         MS_LOG(EXCEPTION) << "A variable of type other than `Tensor` is accepted, but the source axis is not `None`";
423       }
424       AnfNodePtr dim_cnode =
425         fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cur_arg_node, NewValueNode(kDimIndex)});
426       const py::function unstack_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_unstack");
427       auto unstack_fg_ = parse::ParsePythonCode(unstack_fn);
428       MS_EXCEPTION_IF_NULL(unstack_fg_);
429       auto out_cnode = fg_->NewCNode({NewValueNode(unstack_fg_), dim_cnode, val_cnode});
430       for (int64_t m = 0; m < axis_size_; ++m) {
431         auto out_element_cnode = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), out_cnode, NewValueNode(m)});
432         map_inputs[m].push_back(out_element_cnode);
433       }
434     }
435   }
436   return map_inputs;
437 }
438 
439 // When the primitive does not registered the relevant specific VmapRule, it attempts to get
440 // this the general rule. The general rule is combining loop and stack operators to simulate
441 // the behavior of Vmap. Noted that, general rules does not guarantee the correctness of
442 // execution results.
443 // Currently, only the following types of primitives are supported:
444 // 1、 Most calculation operations, whose inputs are tensors, scalars or both of them.
445 // (If all elements in a tuple are scalars, it is also considered scalar.)
446 // 2、 Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped into a tuple.
447 // 3、 Operators with indefinite inputs length, whose first inputs is wrapped into a tuple.
448 // In other words, we do not support any tuple wrapped variables except for the special cases
449 //   listed above.
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)450 FuncGraphPtr VmapGeneralRule::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
451   fg_ = std::make_shared<FuncGraph>();
452   int64_t args_size = static_cast<int64_t>(args_abs_list.size());
453   int64_t tuple_elements_num = 0;
454   auto get_tuple_elements = [&args_size,
455                              &tuple_elements_num](const AbstractBasePtrList &args_abs_list) -> AbstractBasePtrList {
456     auto arg = args_abs_list[0];
457     if (!arg->isa<abstract::AbstractTuple>()) {
458       MS_LOG(EXCEPTION) << "The first input to VmapGeneralPreprocess should be AbstractTuple but got: "
459                         << arg->ToString() << ".";
460     }
461     auto arg_tuple = arg->cast<abstract::AbstractTuplePtr>();
462     const auto &arg_tuple_elements = arg_tuple->elements();
463     if (arg_tuple_elements.back()->isa<abstract::AbstractTuple>()) {
464       // Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped
465       // into a tuple. We need to process the internal elements separately and then re-wrap
466       // them into tuple. Handle case such as args:(((A, 0), (B, 1), (C, None)), ...). Which
467       // different from the case with single input parameter ((A, 0),).
468       //
469       // Tuple case:
470       // 1. Only one tuple input: (((A, 0), (B, 1), (C, None)),)
471       // 2. A tuple input and some normal inputs: (((A, 0), (B, 1), (C, None)), (a, 2), (b, 3))
472       tuple_elements_num = arg_tuple_elements.size();
473       args_size = tuple_elements_num + args_size - 1;
474       AbstractBasePtrList unfold_args_abs_list(arg_tuple_elements.begin(), arg_tuple_elements.end());
475       unfold_args_abs_list.insert(unfold_args_abs_list.end(), args_abs_list.begin() + 1,
476                                   args_abs_list.end());  // the maybe left inputs.
477       return unfold_args_abs_list;
478     }
479 
480     return args_abs_list;
481   };
482   auto unfold_elements = get_tuple_elements(
483     args_abs_list);  // ((A, 0), (B, 1), ...), if tuple is the first input, its elements will be unfold.
484 
485   bool is_all_none = true;
486   constexpr size_t kCurTupleSize = 2;
487   InputsAbstractList unfold_elements_abstract(args_size);
488   for (int64_t i = 0; i < args_size; ++i) {
489     auto cur_arg = unfold_elements[i];
490     if (!cur_arg->isa<abstract::AbstractTuple>()) {
491       MS_LOG(EXCEPTION) << "The " << i
492                         << "th input to VmapGeneralPreprocess should be AbstractTuple but got: " << cur_arg->ToString()
493                         << ".";
494     }
495     auto cur_arg_tuple = cur_arg->cast<abstract::AbstractTuplePtr>();
496     auto cur_arg_tuple_elements = cur_arg_tuple->elements();
497     if (cur_arg_tuple_elements.size() != kCurTupleSize) {
498       MS_LOG(EXCEPTION) << "The " << i << "th input to VmapGeneralPreprocess should be a tuple with two "
499                         << "elements but got " << cur_arg_tuple_elements.size() << " elements.";
500     }
501     auto dim_abstract = cur_arg_tuple_elements[kDimIndex];
502     if (is_all_none && !dim_abstract->isa<abstract::AbstractNone>()) {
503       MS_LOG(INFO) << "The " << i << "th input to VmapGeneralPreprocess has not None dim value.";
504       is_all_none = false;
505     }
506     auto val_abstract = cur_arg_tuple_elements[kValIndex];
507     std::vector<abstract::AbstractBasePtr> element_abstract = {val_abstract, dim_abstract};
508     unfold_elements_abstract[i] = element_abstract;
509   }
510 
511   if (is_all_none) {
512     GenerateFuncGraphAllNone(fg_, NewValueNode(prim_), args_size, tuple_elements_num, false);
513     return fg_;
514   }
515 
516   CNodeInpusList map_inputs = ConstructMapInput(unfold_elements_abstract, args_size, tuple_elements_num);  //
517 
518   std::vector<AnfNodePtr> output_cnode_inputs;
519   (void)output_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
520   for (auto map_input : map_inputs) {
521     std::vector<AnfNodePtr> output_element_cnode_inputs;
522     if (tuple_elements_num != 0) {
523       std::vector<AnfNodePtr> tuple_cnode_inputs{NewValueNode(prim::kPrimMakeTuple)};
524       (void)tuple_cnode_inputs.insert(tuple_cnode_inputs.cend(), map_input.cbegin(),
525                                       map_input.cbegin() + tuple_elements_num);
526       auto tuple_cnode = fg_->NewCNode(tuple_cnode_inputs);
527       output_element_cnode_inputs.push_back(NewValueNode(prim_));
528       output_element_cnode_inputs.push_back(tuple_cnode);
529       output_element_cnode_inputs.insert(output_element_cnode_inputs.end(), map_input.cbegin() + tuple_elements_num,
530                                          map_input.cend());
531     } else {
532       output_element_cnode_inputs.push_back(NewValueNode(prim_));
533       (void)output_element_cnode_inputs.insert(output_element_cnode_inputs.cend(), map_input.cbegin(),
534                                                map_input.cend());
535     }
536     auto output_element_cnode = fg_->NewCNode(output_element_cnode_inputs);
537     (void)output_cnode_inputs.emplace_back(output_element_cnode);
538   }
539   auto output_cnode = fg_->NewCNode(output_cnode_inputs);
540   const py::function vmap_general_output_process_fn =
541     python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_general_output_process");
542   auto vmap_general_output_process_fg_ = parse::ParsePythonCode(vmap_general_output_process_fn);
543   MS_EXCEPTION_IF_NULL(vmap_general_output_process_fg_);
544   auto vmap_general_output = fg_->NewCNode({NewValueNode(vmap_general_output_process_fg_), output_cnode});
545   fg_->set_output(vmap_general_output);
546   return fg_;
547 }
548 }  // namespace prim
549 }  // namespace mindspore
550