• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_VMAP_H
18 #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_VMAP_H
19 
20 #include <memory>
21 #include <string>
22 #include <vector>
23 #include "ir/meta_func_graph.h"
24 #include "ccsrc/pybind_api/ir/primitive_py.h"
25 
26 namespace mindspore {
27 // namespace to support composite operators definition
28 namespace prim {
29 using CNodeInpusList = std::vector<std::vector<AnfNodePtr>>;
30 using InputsAbstractList = std::vector<std::vector<abstract::AbstractBasePtr>>;
31 constexpr int64_t kValIndex = 0;
32 constexpr int64_t kDimIndex = 1;
33 constexpr char kVmapFunctionModelName[] = "mindspore.ops._vmap";
34 constexpr char kNumpyModelName[] = "mindspore.numpy";
35 class VmapMatchOutAxis : public MetaFuncGraph {
36  public:
VmapMatchOutAxis(const std::string & name)37   explicit VmapMatchOutAxis(const std::string &name) : MetaFuncGraph(name), fg_(std::make_shared<FuncGraph>()) {}
38   ~VmapMatchOutAxis() override = default;
39   MS_DECLARE_PARENT(VmapMatchOutAxis, MetaFuncGraph)
40 
41   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
42 
43  private:
44   CNodePtr GenerateFuncGraphInnerBroadcastAxis(const AnfNodePtr &inputs, const AnfNodePtr &out_axis,
45                                                const AnfNodePtr &axis_size,
46                                                const AbstractBasePtr &inputs_abstract_elements_begin) const;
47   CNodePtr GenerateFuncGraphInnerSingleElement(const AnfNodePtr &inputs, const AnfNodePtr &out_axis,
48                                                const AnfNodePtr &axis_size,
49                                                const AbstractBasePtr &inputs_abstract_elements_end) const;
50   CNodePtr GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inputs, const AnfNodePtr &out_axis,
51                                           const AnfNodePtr &axis_size,
52                                           const AbstractBasePtrList &inputs_abstract_elements,
53                                           const AbstractBasePtr &out_axes_abstract) const;
54   FuncGraphPtr fg_{nullptr};
55 };
56 
57 class VmapGeneralPreprocess : public MetaFuncGraph {
58  public:
VmapGeneralPreprocess(const std::string & name)59   explicit VmapGeneralPreprocess(const std::string &name) : MetaFuncGraph(name) {}
60   ~VmapGeneralPreprocess() override = default;
61   MS_DECLARE_PARENT(VmapGeneralPreprocess, MetaFuncGraph);
62 
63   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
64 };
65 
66 class VmapGeneralRule : public MetaFuncGraph {
67  public:
VmapGeneralRule(const std::string & name,const PrimitivePtr & prim,int64_t axis_size)68   explicit VmapGeneralRule(const std::string &name, const PrimitivePtr &prim, int64_t axis_size)
69       : MetaFuncGraph(name), prim_(prim), axis_size_(axis_size) {}
70   ~VmapGeneralRule() override = default;
71   MS_DECLARE_PARENT(VmapGeneralRule, MetaFuncGraph);
72 
73   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
74 
prim_name()75   std::string prim_name() const {
76     if (!prim_) {
77       return "";
78     }
79     return prim_->name();
80   }
81 
axis_size()82   int64_t axis_size() const { return axis_size_; }
83 
84  private:
85   CNodeInpusList ConstructMapInput(const InputsAbstractList &unfold_elements_abstract, int64_t args_size,
86                                    int64_t tuple_elements_num);
87   PrimitivePtr prim_{nullptr};
88   int64_t axis_size_ = 0;
89   FuncGraphPtr fg_{nullptr};
90 };
91 using VmapGeneralRulePtr = std::shared_ptr<VmapGeneralRule>;
92 
93 class VmapGeneralRulePyAdapter : public VmapGeneralRule {
94  public:
VmapGeneralRulePyAdapter(const std::string & name,const PrimitivePyAdapterPtr & prim,int64_t axis_size)95   explicit VmapGeneralRulePyAdapter(const std::string &name, const PrimitivePyAdapterPtr &prim, int64_t axis_size)
96       : VmapGeneralRule(name, prim->attached_primitive(), axis_size) {}
97   ~VmapGeneralRulePyAdapter() override = default;
98 };
99 }  // namespace prim
100 }  // namespace mindspore
101 
102 #endif  // MINDSPORE_CCSRC_FRONTEND_OPERATOR_VMAP_H
103