• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2024 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_
20 #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_
21 
22 #include <vector>
23 #include <string>
24 #include <utility>
25 #include <map>
26 #include <set>
27 #include <memory>
28 #include "utils/hash_map.h"
29 #include "frontend/operator/composite/zip_operation.h"
30 #include "frontend/operator/composite/list_operation.h"
31 #include "frontend/operator/composite/do_signature.h"
32 #include "frontend/operator/composite/unpack_call.h"
33 #include "frontend/operator/composite/multitype_funcgraph.h"
34 #include "frontend/operator/composite/starred_operation.h"
35 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
36 #include "utils/misc.h"
37 #include "utils/any.h"
38 #include "ir/dtype.h"
39 #include "ir/meta_func_graph.h"
40 
41 namespace mindspore {
42 // namespace to support composite operators definition
43 namespace prim {
44 using AbstractSlicePtr = abstract::AbstractSlicePtr;
45 using AbstractScalarPtr = abstract::AbstractScalarPtr;
46 using AbstractTensorPtr = abstract::AbstractTensorPtr;
47 using ElemwiseMap = mindspore::HashMap<std::string, PrimitivePtr>;
48 using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
49 using AbstractListPtr = abstract::AbstractListPtr;
50 
51 class HyperMap : public MetaFuncGraph {
52  public:
53   explicit HyperMap(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
54   HyperMap(const HyperMap &h);
55   void Init();
56   HyperMap &operator=(const HyperMap &h) noexcept {
57     if (this != &h) {
58       fn_leaf_ = h.fn_leaf_;
59       reverse_ = h.reverse_;
60       nonleaf_ = h.nonleaf_;
61       if (fn_leaf_) {
62         name_ = "hyper_map[" + fn_leaf_->name() + "]";
63       }
64     }
65     return *this;
66   }
67   ~HyperMap() override = default;
68   MS_DECLARE_PARENT(HyperMap, MetaFuncGraph)
69 
70   abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_abs_list) const override;
71   FuncGraphPtr GenerateFromTypes(const TypePtrList &args_abs_list) override;
GetFnLeaf()72   MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
73   void SetObjectForFnLeaf(const py::object &leaf_object);
74 
75  private:
76   AnfNodePtr FullMake(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) const;
77   AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
78                       const ArgsPairList &arg_map) const;
79   AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
80                       const ArgsPairList &arg_map) const;
81   AnfNodePtr FullMake(const std::shared_ptr<Dictionary> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
82                       const ArgsPairList &arg_map) const;
83   AnfNodePtr Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) const;
84   std::pair<std::string, std::string> GetHyperMapInputIndex(size_t num) const;
85 
86   MultitypeFuncGraphPtr fn_leaf_;
87   bool reverse_;
88   std::set<TypeId> nonleaf_;
89 };
90 using HyperMapPtr = std::shared_ptr<HyperMap>;
91 
92 class HyperMapPy : public HyperMap {
93  public:
94   explicit HyperMapPy(bool reverse = false, const py::object &fn_leaf = py::none())
95       : HyperMap(reverse, fn_leaf.cast<prim::MultitypeFuncGraphPtr>()) {
96     SetObjectForFnLeaf(fn_leaf);
97   }
98   ~HyperMapPy() override = default;
99   MS_DECLARE_PARENT(HyperMapPy, HyperMap)
100 };
101 using HyperMapPyPtr = std::shared_ptr<HyperMapPy>;
102 
103 extern ValuePtr kCompositeHyperMap;
104 
105 enum TailType { kGradAll, kGradFirst, kGradByPosition, kNotGrad };
106 
107 class Tail : public MetaFuncGraph {
108  public:
109   explicit Tail(const std::string &name, TailType tail_type = kNotGrad, bool return_ids = false)
MetaFuncGraph(name)110       : MetaFuncGraph(name), tail_type_(tail_type), enable_tuple_grad_first_(false), return_ids_(return_ids) {}
111   ~Tail() override = default;
112   MS_DECLARE_PARENT(Tail, MetaFuncGraph)
113 
114   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
115 
116   friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
set_enable_tuple_grad_first(bool enable_tuple_grad_first)117   void set_enable_tuple_grad_first(bool enable_tuple_grad_first) { enable_tuple_grad_first_ = enable_tuple_grad_first; }
118 
119  private:
120   FuncGraphPtr GenerateTailFuncGraph(const abstract::AbstractSequencePtr &sequence_arg) const;
121   FuncGraphPtr GenerateGradFuncGraph(const abstract::AbstractTuplePtr &tuple_arg,
122                                      const abstract::AbstractTuplePtr &position = nullptr) const;
123 
124   TailType tail_type_;
125   bool enable_tuple_grad_first_;
126   bool return_ids_;
127 };
128 using TailPtr = std::shared_ptr<Tail>;
129 
130 class MakeTupleGradient : public MetaFuncGraph {
131  public:
MakeTupleGradient(const std::string & name)132   explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {}
133   ~MakeTupleGradient() override = default;
134   MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph)
135   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
136   friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; }
137 };
138 using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
139 
140 class MakeListGradient : public MetaFuncGraph {
141  public:
MakeListGradient(const std::string & name)142   explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {}
143   ~MakeListGradient() override = default;
144   MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph)
145   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
146   friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; }
147 };
148 using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;
149 
150 class MakeDictGradient : public MetaFuncGraph {
151  public:
MakeDictGradient(const std::string & name)152   explicit MakeDictGradient(const std::string &name) : MetaFuncGraph(name) {}
153   ~MakeDictGradient() override = default;
154   MS_DECLARE_PARENT(MakeDictGradient, MetaFuncGraph)
155   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
156   friend bool operator==(const MakeDictGradient &lhs, const MakeDictGradient &rhs) { return lhs.name_ == rhs.name_; }
157 };
158 using MakeDictGradientPtr = std::shared_ptr<MakeDictGradient>;
159 
160 class PyExecuteGradient : public MetaFuncGraph {
161  public:
PyExecuteGradient(const std::string & name)162   explicit PyExecuteGradient(const std::string &name) : MetaFuncGraph(name) {}
163   ~PyExecuteGradient() override = default;
164   MS_DECLARE_PARENT(PyExecuteGradient, MetaFuncGraph)
165   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
166   friend bool operator==(const PyExecuteGradient &lhs, const PyExecuteGradient &rhs) { return lhs.name_ == rhs.name_; }
167 };
168 using PyExecuteGradientPtr = std::shared_ptr<PyExecuteGradient>;
169 
170 class MutableGradient : public MetaFuncGraph {
171  public:
MutableGradient(const std::string & name)172   explicit MutableGradient(const std::string &name) : MetaFuncGraph(name) {}
173   ~MutableGradient() override = default;
174   MS_DECLARE_PARENT(MutableGradient, MetaFuncGraph)
175   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
176   friend bool operator==(const MutableGradient &lhs, const MutableGradient &rhs) { return lhs.name_ == rhs.name_; }
177 };
178 using MutableGradientPtr = std::shared_ptr<MutableGradient>;
179 
180 class GradOperation : public MetaFuncGraph {
181  public:
182   explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
183                          bool sens_param = false, bool get_by_position = false, bool has_aux = false,
184                          bool get_value = false, bool return_ids = false, bool merge_forward = false);
185   ~GradOperation() override = default;
186   MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
187 
188   FuncGraphPtr GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
189                        const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
190                        bool is_weights_none) const;
191 
192   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
193 
sens_param()194   bool sens_param() const { return sens_param_; }
195 
196   bool get_all_;
197   bool get_by_list_;
198   bool sens_param_;
199   bool get_by_position_;
200   bool has_aux_;
201   bool get_value_;
202   bool return_ids_;
203   bool merge_forward_;
204 
205  private:
206   void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
207                        const AnfNodePtr &weights, const AnfNodePtr &position, bool enable_tuple_grad,
208                        bool is_weights_none) const;
209   CNodePtr SetNodeByParameter(const CNodePtr &grad, const FuncGraphPtr &fg) const;
210   AbstractBasePtr weight_value_;
211 };
212 using GradOperationPtr = std::shared_ptr<GradOperation>;
213 
214 class GradAux : public MetaFuncGraph {
215  public:
GradAux(const std::string & name)216   explicit GradAux(const std::string &name) : MetaFuncGraph(name) {}
217   ~GradAux() override = default;
218   MS_DECLARE_PARENT(GradAux, MetaFuncGraph);
219   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
220 };
221 using GradAuxPtr = std::shared_ptr<GradAux>;
222 
223 class TaylorOperation : public MetaFuncGraph {
224  public:
225   explicit TaylorOperation(const std::string &name);
226   ~TaylorOperation() override = default;
227   MS_DECLARE_PARENT(TaylorOperation, MetaFuncGraph);
228   FuncGraphPtr GetTaylorGrad(const AnfNodePtr &k, const std::vector<AnfNodePtr> &forward_graph_params) const;
229 
230   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
231 };
232 using TaylorOperationPtr = std::shared_ptr<TaylorOperation>;
233 
234 class TupleAdd : public MetaFuncGraph {
235  public:
TupleAdd(const std::string & name)236   explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {}
237   ~TupleAdd() override = default;
238   MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph)
239   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
240   friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; }
241 };
242 using TupleAddPtr = std::shared_ptr<TupleAdd>;
243 
244 class ListAdd : public MetaFuncGraph {
245  public:
ListAdd(const std::string & name)246   explicit ListAdd(const std::string &name) : MetaFuncGraph(name) {}
247   ~ListAdd() override = default;
248   MS_DECLARE_PARENT(ListAdd, MetaFuncGraph)
249   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
250   friend bool operator==(const ListAdd &lhs, const ListAdd &rhs) { return lhs.name_ == rhs.name_; }
251 };
252 using ListAddPtr = std::shared_ptr<ListAdd>;
253 
254 class SequenceSlice : public MetaFuncGraph {
255  public:
SequenceSlice(const std::string & name)256   explicit SequenceSlice(const std::string &name) : MetaFuncGraph(name) {}
257   ~SequenceSlice() override = default;
258   MS_DECLARE_PARENT(SequenceSlice, MetaFuncGraph)
259   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) final;
260   friend bool operator==(const SequenceSlice &lhs, const SequenceSlice &rhs) { return lhs.name_ == rhs.name_; }
261 
262  protected:
263   virtual void CheckArgs(const AbstractBasePtrList &args_abs_list) = 0;
264   virtual FuncGraphPtr BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) = 0;
265   abstract::AbstractSequencePtr sequence_ = nullptr;
266   AbstractSlicePtr slice_ = nullptr;
267 };
268 
269 class SequenceSliceGetItem : public SequenceSlice {
270  public:
SequenceSliceGetItem(const std::string & name,const std::string & prim_name,const std::string & get_item_name)271   explicit SequenceSliceGetItem(const std::string &name, const std::string &prim_name, const std::string &get_item_name)
272       : SequenceSlice(name),
273         prim_(std::make_shared<Primitive>(prim_name)),
274         get_item_(std::make_shared<Primitive>(get_item_name)) {}
275   ~SequenceSliceGetItem() override = default;
276   MS_DECLARE_PARENT(SequenceSliceGetItem, MetaFuncGraph)
277   friend bool operator==(const SequenceSliceGetItem &lhs, const SequenceSliceGetItem &rhs) {
278     return lhs.name_ == rhs.name_;
279   }
280 
281  protected:
282   void CheckArgs(const AbstractBasePtrList &args_abs_list) override;
283   FuncGraphPtr BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) override;
284 
285  private:
286   PrimitivePtr prim_;
287   PrimitivePtr get_item_;
288 };
289 
290 class ListSliceSetItem : public SequenceSlice {
291  public:
ListSliceSetItem(const std::string & name)292   explicit ListSliceSetItem(const std::string &name) : SequenceSlice(name) {}
293   ~ListSliceSetItem() override = default;
294   MS_DECLARE_PARENT(ListSliceSetItem, MetaFuncGraph)
295   friend bool operator==(const ListSliceSetItem &lhs, const ListSliceSetItem &rhs) { return lhs.name_ == rhs.name_; }
296 
297  protected:
298   void CheckArgs(const AbstractBasePtrList &args_abs_list) override;
299   FuncGraphPtr BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) override;
300 
301  private:
302   void CheckAssignRange(int64_t start_index, int64_t stop_index, int64_t step_value);
303   AnfNodePtr GetAssignNode(const FuncGraphPtr &func_graph, const AnfNodePtr &assign_node, int64_t step_value);
304   AbstractListPtr value_list_ = nullptr;
305 };
306 
307 class TupleGetItemTensor : public MetaFuncGraph {
308  public:
TupleGetItemTensor(const std::string & name)309   explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
310   ~TupleGetItemTensor() override = default;
311   MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph)
312   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
313   friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) {
314     return lhs.name_ == rhs.name_;
315   }
316 };
317 using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>;
318 
319 class Shard : public MetaFuncGraph {
320  public:
Shard(const string & name)321   explicit Shard(const string &name) : MetaFuncGraph(name) {
322     signatures_ =
323       // def shard(func:read, weight_list:read, in_axes:read, out_axes:read, parameter_plan:read, device:read,
324       // level:read):
325       std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
326                               {"in_axes", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
327                               {"out_axes", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
328                               {"device", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
329                               {"level", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}});
330     kShardInputSize = signatures_.size();
331   }
332   ~Shard() override = default;
333   MS_DECLARE_PARENT(Shard, MetaFuncGraph)
334 
335   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
336 
337  private:
338   size_t kShardInputSize = 0;
339 };
340 
341 class VmapOperation : public MetaFuncGraph {
342  public:
343   explicit VmapOperation(const std::string &name);
344   ~VmapOperation() override = default;
345   MS_DECLARE_PARENT(VmapOperation, MetaFuncGraph)
346 
347   FuncGraphPtr GetVmap(const AnfNodePtr &vmap, int param_number) const;
348 
349   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
350 };
351 using VmapOperationPtr = std::shared_ptr<VmapOperation>;
352 
353 class ZerosLike : public MetaFuncGraph {
354  public:
355   explicit ZerosLike(const std::string &name, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
MetaFuncGraph(name)356       : MetaFuncGraph(name), fn_leaf_(fn_leaf) {}
357   ~ZerosLike() override = default;
358   MS_DECLARE_PARENT(ZerosLike, MetaFuncGraph)
359   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
360   friend bool operator==(const ZerosLike &lhs, const ZerosLike &rhs) { return lhs.name_ == rhs.name_; }
361 
362  private:
363   MultitypeFuncGraphPtr fn_leaf_;
364 };
365 using ZerosLikePtr = std::shared_ptr<ZerosLike>;
366 
367 class IterConverter : public MetaFuncGraph {
368  public:
IterConverter(const std::string & name)369   explicit IterConverter(const std::string &name) : MetaFuncGraph(name) {}
370   ~IterConverter() override = default;
371   MS_DECLARE_PARENT(IterConverter, MetaFuncGraph)
372   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
373   friend bool operator==(const IterConverter &lhs, const IterConverter &rhs) { return lhs.name_ == rhs.name_; }
374 };
375 using IterConverterPtr = std::shared_ptr<IterConverter>;
376 
377 class HasNext : public MetaFuncGraph {
378  public:
HasNext(const std::string & name)379   explicit HasNext(const std::string &name) : MetaFuncGraph(name) {}
380   ~HasNext() override = default;
381   MS_DECLARE_PARENT(HasNext, MetaFuncGraph)
382   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
383   friend bool operator==(const HasNext &lhs, const HasNext &rhs) { return lhs.name_ == rhs.name_; }
384 };
385 using HasNextPtr = std::shared_ptr<HasNext>;
386 
387 class Next : public MetaFuncGraph {
388  public:
Next(const std::string & name)389   explicit Next(const std::string &name) : MetaFuncGraph(name) {}
390   ~Next() override = default;
391   MS_DECLARE_PARENT(Next, MetaFuncGraph)
392   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
393   friend bool operator==(const Next &lhs, const Next &rhs) { return lhs.name_ == rhs.name_; }
394 };
395 using NextPtr = std::shared_ptr<Next>;
396 
397 class TupleFunc : public MetaFuncGraph {
398  public:
TupleFunc(const std::string & name)399   explicit TupleFunc(const std::string &name) : MetaFuncGraph(name) {}
400   ~TupleFunc() override = default;
401   MS_DECLARE_PARENT(TupleFunc, MetaFuncGraph)
402   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
403   friend bool operator==(const TupleFunc &lhs, const TupleFunc &rhs) { return lhs.name_ == rhs.name_; }
404 };
405 using TupleFuncPtr = std::shared_ptr<TupleFunc>;
406 
407 class ListFunc : public MetaFuncGraph {
408  public:
ListFunc(const std::string & name)409   explicit ListFunc(const std::string &name) : MetaFuncGraph(name) {}
410   ~ListFunc() override = default;
411   MS_DECLARE_PARENT(ListFunc, MetaFuncGraph)
412   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
413   friend bool operator==(const ListFunc &lhs, const ListFunc &rhs) { return lhs.name_ == rhs.name_; }
414 };
415 using ListFuncPtr = std::shared_ptr<ListFunc>;
416 }  // namespace prim
417 }  // namespace mindspore
418 
419 #endif  // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_
420