• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_
20 #define MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_
21 
22 #include <memory>
23 #include <string>
24 
25 #include "abstract/abstract_value.h"
26 #include "abstract/analysis_context.h"
27 #include "ir/meta_func_graph.h"
28 
29 namespace mindspore {
30 namespace abstract {
31 class MS_CORE_API AbstractFuncAtom : public AbstractFunction {
32  public:
33   AbstractFuncAtom() = default;
34   ~AbstractFuncAtom() override = default;
MS_DECLARE_PARENT(AbstractFuncAtom,AbstractFunction)35   MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction)
36 
37   AbstractFunctionPtr GetUnique() override { return shared_from_base<AbstractFuncAtom>(); }
38   AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
39   void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
40   bool operator==(const AbstractFunction &other) const override;
41 
hash()42   std::size_t hash() const override { return tid(); }
43 };
44 
45 class MS_CORE_API AbstractFuncUnion : public AbstractFunction {
46  public:
47   explicit AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list);
48   AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second);
49   ~AbstractFuncUnion() override = default;
50   MS_DECLARE_PARENT(AbstractFuncUnion, AbstractFunction)
51 
52   std::string ToString() const override;
53 
GetUnique()54   AbstractFunctionPtr GetUnique() override {
55     MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion";
56     AbstractFunctionPtr result;
57     return result;
58   }
59   bool IsSuperSet(const AbstractFunctionPtr &other);
60   AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
61   void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
62   bool operator==(const AbstractFunction &other) const override;
63   std::size_t hash() const override;
Copy()64   AbstractFunctionPtr Copy() const override {
65     MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion";
66     AbstractFunctionPtr result;
67     return result;
68   }
69 
70  private:
71   AbstractFuncAtomPtrList func_list_;
72 };
73 
74 class MS_CORE_API PrimitiveAbstractClosure : public AbstractFuncAtom {
75  public:
76   // Represents a Primitive.
77   // prim: The primitive
78   // tracking_id: Identifies different uses of the same primitive.
79   explicit PrimitiveAbstractClosure(const PrimitivePtr &prim, const AnfNodePtr &tracking_id = nullptr)
prim_(prim)80       : prim_(prim), tracking_id_(AnfNodeWeakPtr(tracking_id)) {}
81   ~PrimitiveAbstractClosure() override = default;
MS_DECLARE_PARENT(PrimitiveAbstractClosure,AbstractFuncAtom)82   MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom)
83 
84   PrimitivePtr prim() { return prim_; }
85 
tracking_id()86   AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
87 
set_tracking_id(AnfNodePtr node)88   void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); }
89 
Copy()90   AbstractFunctionPtr Copy() const override { return std::make_shared<PrimitiveAbstractClosure>(prim_, tracking_id()); }
91 
92   bool operator==(const AbstractFunction &other) const override;
93   std::size_t hash() const override;
94 
ToString()95   std::string ToString() const override { return "Prim: " + prim_->name(); }
96 
RealBuildValue()97   ValuePtr RealBuildValue() const override { return prim_; }
98 
99  private:
100   PrimitivePtr prim_;
101   // store it as weak_ptr to break reference cycle.
102   // one reference cycle example is Graph::set_output() input0 local variable.
103   AnfNodeWeakPtr tracking_id_;
104 };
105 using PrimitiveAbstractClosurePtr = std::shared_ptr<PrimitiveAbstractClosure>;
106 
107 class MS_CORE_API FuncGraphAbstractClosure : public AbstractFuncAtom {
108  public:
109   // Represents a Graph in a certain Context.
110   // context: The context, or Context.empty()
111   FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
112                            const AnfNodePtr &tracking_id = nullptr)
func_graph_(func_graph)113       : func_graph_(func_graph), context_(context), tracking_id_(AnfNodeWeakPtr(tracking_id)) {
114     MS_EXCEPTION_IF_NULL(func_graph);
115     MS_EXCEPTION_IF_NULL(context);
116   }
117   ~FuncGraphAbstractClosure() override = default;
MS_DECLARE_PARENT(FuncGraphAbstractClosure,AbstractFuncAtom)118   MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom)
119 
120   FuncGraphPtr func_graph() { return func_graph_; }
121 
context()122   AnalysisContextPtr context() const override { return context_; }
123 
tracking_id()124   AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
125 
set_tracking_id(AnfNodePtr node)126   void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); }
127 
Copy()128   AbstractFunctionPtr Copy() const override {
129     return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_, tracking_id());
130   }
131 
132   bool operator==(const AbstractFunction &other) const override;
133   std::size_t hash() const override;
134 
135   std::string ToString() const override;
136 
137  private:
138   FuncGraphPtr func_graph_;
139   AnalysisContextPtr context_;
140   // To discriminate different usage of same graph by using this tracking_id,
141   // so different tracking_id will produce different FuncGraphAbstractClosure,
142   // different FuncGraphEvaluator.
143   // Espcecially useful for recursive func graph call, so it will not mess up
144   // the `context_` in FuncGraphEvaluator.
145   // Notes: Be careful to use nullptr for this variable.
146   // store it as weak_ptr to break reference cycle.
147   AnfNodeWeakPtr tracking_id_;
148 };
149 using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>;
150 
151 class MS_CORE_API MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
152  public:
153   explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph,
154                                         const AnfNodePtr &tracking_id = nullptr, const ScopePtr &scope = kDefaultScope)
meta_func_graph_(meta_func_graph)155       : meta_func_graph_(meta_func_graph), tracking_id_(AnfNodeWeakPtr(tracking_id)), scope_(scope) {}
156   ~MetaFuncGraphAbstractClosure() override = default;
MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure,AbstractFuncAtom)157   MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom)
158 
159   MetaFuncGraphPtr meta_func_graph() { return meta_func_graph_; }
160 
context()161   AnalysisContextPtr context() const override { return kDummyAnalysisContext; }
162 
GetScope()163   ScopePtr GetScope() { return scope_; }
164 
tracking_id()165   AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
166 
Copy()167   AbstractFunctionPtr Copy() const override {
168     return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_, tracking_id());
169   }
170   bool operator==(const AbstractFunction &other) const override;
171   std::size_t hash() const override;
172 
173   std::string ToString() const override;
174 
175  private:
176   MetaFuncGraphPtr meta_func_graph_;
177   // refer the comment in FuncGraphAbstractClosure;
178   // store it as weak_ptr to break reference cycle.
179   AnfNodeWeakPtr tracking_id_;
180   ScopePtr scope_;
181 };
182 using MetaFuncGraphAbstractClosurePtr = std::shared_ptr<MetaFuncGraphAbstractClosure>;
183 
184 class MS_CORE_API PartialAbstractClosure : public AbstractFuncAtom {
185  public:
186   // Represents a partial application.
187   // args_spec_list: The first few arguments of that function
188   PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list,
189                          const AnfNodePtr &node = nullptr)
fn_(fn)190       : fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {}
191   ~PartialAbstractClosure() override = default;
MS_DECLARE_PARENT(PartialAbstractClosure,AbstractFuncAtom)192   MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom)
193 
194   AbstractFunctionPtr fn() { return fn_; }
args()195   const AbstractBasePtrList &args() { return args_spec_list_; }
RealBuildValue()196   ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
node()197   AnfNodePtr node() { return node_.lock(); }
set_node(const AnfNodePtr & node)198   void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
Copy()199   AbstractFunctionPtr Copy() const override {
200     return std::make_shared<PartialAbstractClosure>(fn_, args_spec_list_, node_.lock());
201   }
202   bool operator==(const AbstractFunction &other) const override;
203   std::size_t hash() const override;
204 
205   std::string ToString() const override;
206 
207  private:
208   AbstractFuncAtomPtr fn_;
209   AbstractBasePtrList args_spec_list_;
210   // The CNode which this PartialAbstractClosure evaluated from.
211   AnfNodeWeakPtr node_;
212 };
213 using PartialAbstractClosurePtr = std::shared_ptr<PartialAbstractClosure>;
214 
215 class MS_CORE_API JTransformedAbstractClosure : public AbstractFuncAtom {
216  public:
217   // Represents a Function transformed through the application of J.
JTransformedAbstractClosure(const AbstractFuncAtomPtr & fn)218   explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
219   ~JTransformedAbstractClosure() override = default;
MS_DECLARE_PARENT(JTransformedAbstractClosure,AbstractFuncAtom)220   MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom)
221 
222   AbstractFuncAtomPtr fn() { return fn_; }
Copy()223   AbstractFunctionPtr Copy() const override { return std::make_shared<JTransformedAbstractClosure>(fn_); }
224   bool operator==(const AbstractFunction &other) const override;
225   std::size_t hash() const override;
226 
ToString()227   std::string ToString() const override { return "J(" + fn_->ToString() + ")"; }
228 
229  private:
230   AbstractFuncAtomPtr fn_;
231 };
232 
233 class MS_CORE_API VirtualAbstractClosure : public AbstractFuncAtom {
234  public:
235   // Represents some function with an explicitly fixed type signature.
236   // args_spec_list: The arguments as abstract value given to the function
237   // output: The output which is abstract value.
VirtualAbstractClosure(const AbstractBasePtrList & args_spec_list,const AbstractBasePtr & output_spec)238   VirtualAbstractClosure(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output_spec)
239       : args_spec_list_(args_spec_list), output_(output_spec) {}
VirtualAbstractClosure(const AbstractBasePtr & args_spec,const AbstractBasePtr & output_spec)240   VirtualAbstractClosure(const AbstractBasePtr &args_spec, const AbstractBasePtr &output_spec)
241       : args_spec_list_({args_spec}), output_(output_spec) {}
242   ~VirtualAbstractClosure() override = default;
MS_DECLARE_PARENT(VirtualAbstractClosure,AbstractFuncAtom)243   MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom)
244 
245   AbstractBasePtrList args_spec_list() { return args_spec_list_; }
246 
output()247   AbstractBasePtr output() { return output_; }
Copy()248   AbstractFunctionPtr Copy() const override {
249     return std::make_shared<VirtualAbstractClosure>(args_spec_list_, output_);
250   }
251   bool operator==(const AbstractFunction &other) const override;
252   std::size_t hash() const override;
253 
254   std::string ToString() const override;
255 
256  private:
257   AbstractBasePtrList args_spec_list_;
258   AbstractBasePtr output_;
259 };
260 using VirtualAbstractClosurePtr = std::shared_ptr<VirtualAbstractClosure>;
261 
262 class MS_CORE_API TypedPrimitiveAbstractClosure : public AbstractFuncAtom {
263  public:
264   // Represents a Primitive with an explicitly fixed type signature.
265   // args_spec_list: The arguments as abstract value given to the Primitive
266   // output: The output which is abstract value.
TypedPrimitiveAbstractClosure(const PrimitivePtr prim,const AbstractBasePtrList & args_spec_list,const AbstractBasePtr & output_spec)267   TypedPrimitiveAbstractClosure(const PrimitivePtr prim, const AbstractBasePtrList &args_spec_list,
268                                 const AbstractBasePtr &output_spec)
269       : prim_(prim), args_spec_list_(args_spec_list), output_(output_spec) {}
270   ~TypedPrimitiveAbstractClosure() override = default;
MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure,AbstractFuncAtom)271   MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom)
272 
273   PrimitivePtr prim() { return prim_; }
args_spec_list()274   AbstractBasePtrList args_spec_list() { return args_spec_list_; }
output()275   AbstractBasePtr output() { return output_; }
Copy()276   AbstractFunctionPtr Copy() const override {
277     return std::make_shared<TypedPrimitiveAbstractClosure>(prim_, args_spec_list_, output_);
278   }
279   bool operator==(const AbstractFunction &other) const override;
280   std::size_t hash() const override;
281 
282   std::string ToString() const override;
283 
284  private:
285   PrimitivePtr prim_;
286   AbstractBasePtrList args_spec_list_;
287   AbstractBasePtr output_;
288 };
289 
290 class PyInterpretAbstractClosure : public AbstractFuncAtom {
291  public:
292   PyInterpretAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list,
293                              const AnfNodePtr &node = nullptr)
fn_(fn)294       : fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {}
295   ~PyInterpretAbstractClosure() override = default;
MS_DECLARE_PARENT(PyInterpretAbstractClosure,AbstractFuncAtom)296   MS_DECLARE_PARENT(PyInterpretAbstractClosure, AbstractFuncAtom)
297 
298   AbstractFunctionPtr fn() { return fn_; }
args()299   AbstractBasePtrList args() { return args_spec_list_; }
RealBuildValue()300   ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
node()301   AnfNodePtr node() { return node_.lock(); }
set_node(const AnfNodePtr & node)302   void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
Copy()303   AbstractFunctionPtr Copy() const override {
304     return std::make_shared<PyInterpretAbstractClosure>(fn_, args_spec_list_, node_.lock());
305   }
306   bool operator==(const AbstractFunction &other) const override;
307   std::size_t hash() const override;
308 
309   std::string ToString() const override;
310 
311  private:
312   AbstractFuncAtomPtr fn_;
313   AbstractBasePtrList args_spec_list_;
314   AnfNodeWeakPtr node_;
315 };
316 using PyInterpretAbstractClosurePtr = std::shared_ptr<PyInterpretAbstractClosure>;
317 
318 // Represents a function that can't be called.
319 class MS_CORE_API DummyAbstractClosure : public AbstractFuncAtom {
320  public:
321   DummyAbstractClosure() = default;
322   ~DummyAbstractClosure() override = default;
MS_DECLARE_PARENT(DummyAbstractClosure,AbstractFuncAtom)323   MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom)
324 
325   AbstractFunctionPtr Copy() const override { return std::make_shared<DummyAbstractClosure>(); }
326   bool operator==(const AbstractFunction &other) const override;
327 
ToString()328   std::string ToString() const override { return "DummyAbstractClosure()"; }
329 };
330 
331 struct MS_CORE_API AbstractFunctionHasher {
operatorAbstractFunctionHasher332   std::size_t operator()(const AbstractFunctionPtr &t) const {
333     std::size_t hash = t->hash();
334     return hash;
335   }
336 };
337 
338 struct MS_CORE_API AbstractFunctionEqual {
operatorAbstractFunctionEqual339   bool operator()(const AbstractFunctionPtr &lhs, const AbstractFunctionPtr &rhs) const { return *lhs == *rhs; }
340 };
341 }  // namespace abstract
342 }  // namespace mindspore
343 #endif  // MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_
344