• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_
20 #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_
21 
22 #include <vector>
23 #include <string>
24 #include <unordered_map>
25 #include <utility>
26 #include <map>
27 #include <set>
28 #include <memory>
29 #include "frontend/operator/composite/zip_operation.h"
30 #include "frontend/operator/composite/list_append_operation.h"
31 #include "frontend/operator/composite/do_signature.h"
32 #include "frontend/operator/composite/unpack_call.h"
33 #include "frontend/operator/composite/multitype_funcgraph.h"
34 #include "pipeline/jit/static_analysis/static_analysis.h"
35 #include "utils/misc.h"
36 #include "utils/any.h"
37 #include "ir/dtype.h"
38 #include "ir/meta_func_graph.h"
39 
40 namespace mindspore {
41 // namespace to support composite operators definition
42 namespace prim {
43 using AbstractSlicePtr = abstract::AbstractSlicePtr;
44 using AbstractScalarPtr = abstract::AbstractScalarPtr;
45 using AbstractTensorPtr = abstract::AbstractTensorPtr;
46 using ElemwiseMap = std::unordered_map<std::string, PrimitivePtr>;
47 using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
48 
49 class HyperMap : public MetaFuncGraph {
50  public:
51   explicit HyperMap(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
52   HyperMap(const HyperMap &h);
53   void Init();
54   HyperMap &operator=(const HyperMap &h) {
55     if (this != &h) {
56       fn_leaf_ = h.fn_leaf_;
57       reverse_ = h.reverse_;
58       broadcast_ = h.broadcast_;
59       nonleaf_ = h.nonleaf_;
60       if (fn_leaf_) {
61         name_ = "hyper_map[" + fn_leaf_->name() + "]";
62       }
63     }
64     return *this;
65   }
66   ~HyperMap() override = default;
67   MS_DECLARE_PARENT(HyperMap, MetaFuncGraph)
68 
69   abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
70   FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
GetFnLeaf()71   MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
72 
73  private:
74   AnfNodePtr FullMake(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
75   AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
76                       const ArgsPairList &arg_map);
77   AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
78                       const ArgsPairList &arg_map);
79   AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
80                       const ArgsPairList &arg_map);
81   AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
82   ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
83 
84   MultitypeFuncGraphPtr fn_leaf_;
85   bool reverse_;
86   bool broadcast_;
87   std::set<TypeId> nonleaf_;
88 };
89 using HyperMapPtr = std::shared_ptr<HyperMap>;
90 
91 class HyperMapPy : public HyperMap {
92  public:
93   explicit HyperMapPy(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
HyperMap(reverse,fn_leaf)94       : HyperMap(reverse, fn_leaf) {}
95   ~HyperMapPy() override = default;
96   MS_DECLARE_PARENT(HyperMapPy, HyperMap)
97 };
98 using HyperMapPyPtr = std::shared_ptr<HyperMapPy>;
99 
100 extern ValuePtr kCompositeHyperMap;
101 
102 enum TailType { kGradAll, kGradFirst, kNotGrad };
103 
104 class Tail : public MetaFuncGraph {
105  public:
106   explicit Tail(const std::string &name, TailType tail_type = kNotGrad)
MetaFuncGraph(name)107       : MetaFuncGraph(name), tail_type_(tail_type), enable_tuple_grad_(false) {}
108   ~Tail() override = default;
109   MS_DECLARE_PARENT(Tail, MetaFuncGraph)
110 
111   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
112   FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const;
113 
114   friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
set_enable_tuple_grad(bool enable_tuple_grad)115   void set_enable_tuple_grad(bool enable_tuple_grad) { enable_tuple_grad_ = enable_tuple_grad; }
116 
117  private:
118   TailType tail_type_;
119   bool enable_tuple_grad_;
120 };
121 using TailPtr = std::shared_ptr<Tail>;
122 
123 class MakeTupleGradient : public MetaFuncGraph {
124  public:
MakeTupleGradient(const std::string & name)125   explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {}
126   ~MakeTupleGradient() override = default;
127   MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph)
128   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
129   friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; }
130 };
131 using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
132 
133 class MakeListGradient : public MetaFuncGraph {
134  public:
MakeListGradient(const std::string & name)135   explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {}
136   ~MakeListGradient() override = default;
137   MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph)
138   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
139   friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; }
140 };
141 using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;
142 
143 class GradOperation : public MetaFuncGraph {
144  public:
145   explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
146                          bool sens_param = false);
147   ~GradOperation() override = default;
148   MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
149 
150   FuncGraphPtr GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
151                        const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
152                        const std::vector<AnfNodePtr> &weight_args = {});
153 
154   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
sens_param()155   bool sens_param() const { return sens_param_; }
156   bool get_all_;
157   bool get_by_list_;
158   bool sens_param_;
159 
160  private:
161   void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
162                        const AnfNodePtr &weights, bool enable_tuple_grad);
163 };
164 using GradOperationPtr = std::shared_ptr<GradOperation>;
165 
166 class ListMap {
167  public:
ListMap(const std::string & name)168   explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }
169   ~ListMap() = default;
170   void MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr);
171   void MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr);
172   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list);
173 
174  private:
175   std::string name_;
176   std::map<std::vector<AnyPtr>, FuncGraphPtr> cache_;
177 };
178 
179 class TupleAdd : public MetaFuncGraph {
180  public:
TupleAdd(const std::string & name)181   explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {}
182   ~TupleAdd() override = default;
183   MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph)
184   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
185   friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; }
186 };
187 using TupleAddPtr = std::shared_ptr<TupleAdd>;
188 
189 class TupleSlice : public MetaFuncGraph {
190  public:
TupleSlice(const std::string & name)191   explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {}
192   ~TupleSlice() override = default;
193   MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph)
194   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
195   friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; }
196 };
197 using TupleSlicePtr = std::shared_ptr<TupleSlice>;
198 
199 class TupleGetItemTensor : public MetaFuncGraph {
200  public:
TupleGetItemTensor(const std::string & name)201   explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
202   ~TupleGetItemTensor() override = default;
203   MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph)
204   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
205   friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) {
206     return lhs.name_ == rhs.name_;
207   }
208 };
209 using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>;
210 }  // namespace prim
211 }  // namespace mindspore
212 
213 #endif  // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_
214