• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_
20 #define MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_
21 
22 #include <cstdint>
23 #include <memory>
24 #include <string>
25 
26 #include "abstract/abstract_value.h"
27 #include "abstract/analysis_context.h"
28 #include "ir/meta_func_graph.h"
29 
30 namespace mindspore {
31 namespace abstract {
32 /// \brief AbstractFuncAtom defines interface for abstract of atom function.
33 class MS_CORE_API AbstractFuncAtom : public AbstractFunction {
34  public:
35   /// \brief Constructor of AbstractFuncAtom.
36   AbstractFuncAtom() = default;
37 
38   /// \brief Destructor of AbstractFuncAtom.
39   ~AbstractFuncAtom() override = default;
MS_DECLARE_PARENT(AbstractFuncAtom,AbstractFunction)40   MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction)
41 
42   AbstractFunctionPtr GetUnique() override { return shared_from_base<AbstractFuncAtom>(); }
43 
44   AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
45 
46   void Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const final;
47 
48   bool operator==(const AbstractFunction &other) const override;
49 
hash()50   std::size_t hash() const override { return tid(); }
51 };
52 
53 /// \brief AbstractFuncUnion defines interface for abstract of union function.
54 class MS_CORE_API AbstractFuncUnion final : public AbstractFunction {
55  public:
56   /// \brief Constructor AbstractFuncUnion from AbstractFuncAtom list.
57   ///
58   /// \param[in] func_list The AbstractFuncAtom list for AbstractFuncUnion.
59   explicit AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list);
60 
61   /// \brief Constructor AbstractFuncUnion from two AbstractFunction.
62   ///
63   /// \param[in] first The first AbstractFunction for AbstractFuncUnion.
64   /// \param[in] second The second AbstractFunction for AbstractFuncUnion.
65   AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second);
66 
67   /// \brief Destructor for AbstractFunction.
68   ~AbstractFuncUnion() override = default;
69   MS_DECLARE_PARENT(AbstractFuncUnion, AbstractFunction)
70 
71   std::string ToString() const override;
72 
73   std::string ToString(bool verbose) const override;
74 
GetUnique()75   AbstractFunctionPtr GetUnique() override { MS_LOG(INTERNAL_EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; }
76 
77   /// \brief Check whether the input AbstractFunction is in AbstractFuncUnion.
78   ///
79   /// \param[in] other The input AbstractFunction for check.
80   ///
81   /// \return Return true if other is in AbstractFuncUnion, otherwise return False.
82   bool IsSuperSet(const AbstractFunctionPtr &other);
83 
84   AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
85 
86   void Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const final;
87 
88   bool operator==(const AbstractFunction &other) const override;
89 
90   std::size_t hash() const override;
91 
Copy()92   AbstractFunctionPtr Copy() const override { MS_LOG(INTERNAL_EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; }
93 
94  private:
95   AbstractFuncAtomPtrList func_list_;
96 };
97 
98 /// \brief PrimitiveAbstractClosure defines interface for abstract of Primitive.
99 class MS_CORE_API PrimitiveAbstractClosure final : public AbstractFuncAtom {
100  public:
101   /// \brief Constructor of PrimitiveAbstractClosure
102   ///
103   /// \param[in] prim The primitive that this PrimitiveAbstractClosure corresponding to.
104   /// \param[in] tracking_node A Node identifies different uses of the prim.
105   explicit PrimitiveAbstractClosure(const PrimitivePtr &prim, const AnfNodePtr &tracking_node = nullptr)
PrimitiveAbstractClosure(prim,ToTrackingId (tracking_node))106       : PrimitiveAbstractClosure(prim, ToTrackingId(tracking_node)) {}
107 
108   // For internal usage only, make it public so that make_shared can work on it.
PrimitiveAbstractClosure(const PrimitivePtr & prim,std::uintptr_t tracking_id)109   PrimitiveAbstractClosure(const PrimitivePtr &prim, std::uintptr_t tracking_id)
110       : prim_(prim), tracking_id_(tracking_id) {}
111 
112   /// \brief Destructor of PrimitiveAbstractClosure
113   ~PrimitiveAbstractClosure() override = default;
MS_DECLARE_PARENT(PrimitiveAbstractClosure,AbstractFuncAtom)114   MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom)
115 
116   /// \brief Get the Primitive that this PrimitiveAbstractClosure corresponding to.
117   ///
118   /// \return The Primitive that this PrimitiveAbstractClosure corresponding to.
119   const PrimitivePtr &prim() const { return prim_; }
120 
tracking_id()121   std::uintptr_t tracking_id() const override { return tracking_id_; }
122 
Copy()123   AbstractFunctionPtr Copy() const override { return std::make_shared<PrimitiveAbstractClosure>(prim_, tracking_id_); }
124 
CopyWithoutTrackingId()125   AbstractFunctionPtr CopyWithoutTrackingId() const override {
126     return std::make_shared<PrimitiveAbstractClosure>(prim_, 0);
127   }
128 
129   bool operator==(const AbstractFunction &other) const override;
130 
131   std::size_t hash() const override;
132 
ToString()133   std::string ToString() const override { return "PrimitiveAbstractClosure: " + prim_->name(); }
134 
135   std::string ToString(bool verbose) const override;
136 
RealBuildValue()137   ValuePtr RealBuildValue() const override { return prim_; }
138 
139  private:
140   PrimitivePtr prim_;
141   // To discriminate different usage of same primitive calls,
142   // store it as the memory address of the user node.
143   std::uintptr_t tracking_id_;
144 };
145 using PrimitiveAbstractClosurePtr = std::shared_ptr<PrimitiveAbstractClosure>;
146 
147 /// \brief FuncGraphAbstractClosure defines interface for abstract of FuncGraph.
148 class MS_CORE_API FuncGraphAbstractClosure final : public AbstractFuncAtom {
149  public:
150   /// \brief Constructor of FuncGraphAbstractClosure.
151   ///
152   /// \param[in] func_graph The function graph that this PrimitiveAbstractClosure corresponding to.
153   /// \param[in] context The context that func_graph corresponding to.
154   /// \param[in] tracking_node A Node identifies different uses of the func_graph.
155   FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
156                            const AnfNodePtr &tracking_node = nullptr, bool specialized = false)
FuncGraphAbstractClosure(func_graph,context,ToTrackingId (tracking_node),specialized)157       : FuncGraphAbstractClosure(func_graph, context, ToTrackingId(tracking_node), specialized) {}
158 
159   // For internal usage only, make it public so that make_shared can work on it.
FuncGraphAbstractClosure(const FuncGraphPtr & func_graph,const AnalysisContextPtr & context,std::uintptr_t tracking_id,bool specialized)160   FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
161                            std::uintptr_t tracking_id, bool specialized)
162       : func_graph_(FuncGraphWeakPtr(func_graph)),
163         context_(context),
164         tracking_id_(tracking_id),
165         specialized_(specialized) {
166     MS_EXCEPTION_IF_NULL(func_graph);
167     MS_EXCEPTION_IF_NULL(context);
168   }
169 
170   /// \brief Destructor of FuncGraphAbstractClosure.
171   ~FuncGraphAbstractClosure() override = default;
MS_DECLARE_PARENT(FuncGraphAbstractClosure,AbstractFuncAtom)172   MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom)
173 
174   /// \brief Get the FuncGraph that this FuncGraphAbstractClosure corresponding to.
175   ///
176   /// \return The FuncGraph that this FuncGraphAbstractClosure corresponding to.
177   FuncGraphPtr func_graph() const { return func_graph_.lock(); }
178 
context()179   AnalysisContextPtr context() const override { return context_; }
180 
tracking_id()181   std::uintptr_t tracking_id() const override { return tracking_id_; }
182 
specialized()183   bool specialized() const { return specialized_; }
184 
Copy()185   AbstractFunctionPtr Copy() const override {
186     return std::make_shared<FuncGraphAbstractClosure>(func_graph(), context_, tracking_id_, specialized_);
187   }
188 
CopyWithoutTrackingId()189   AbstractFunctionPtr CopyWithoutTrackingId() const override {
190     return std::make_shared<FuncGraphAbstractClosure>(func_graph(), context_, 0, specialized_);
191   }
192 
193   bool operator==(const AbstractFunction &other) const override;
194 
195   std::size_t hash() const override;
196 
197   std::string ToString() const override;
198 
199   std::string ToString(bool verbose) const override;
200 
201   bool IsEqualExceptTrackingId(const FuncGraphAbstractClosure &other) const;
202 
203   std::size_t HashWithoutTrackingId() const;
204 
205  private:
206   FuncGraphWeakPtr func_graph_;
207   AnalysisContextPtr context_;
208   // To discriminate different usage of same graph by using this tracking_id,
209   // so different tracking_id will produce different FuncGraphAbstractClosure,
210   // different FuncGraphEvaluator.
211   // Especially useful for recursive func graph call, so it will not mess up
212   // the `context_` in FuncGraphEvaluator.
213   // store it as the memory address of the user node.
214   std::uintptr_t tracking_id_;
215   // If the func_graph_ member is the specialized func_graph_ in current IR or
216   // it's a old func_graph of IR before renormalized.
217   bool specialized_{false};
218 };
219 using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>;
220 
221 /// \brief MetaFuncGraphAbstractClosure defines interface for abstract of MetaFuncGraph.
222 class MS_CORE_API MetaFuncGraphAbstractClosure final : public AbstractFuncAtom {
223  public:
224   /// \brief Constructor of FuncGraphAbstractClosure.
225   ///
226   /// \param[in] meta_func_graph The function graph that this MetaFuncGraphAbstractClosure corresponding to.
227   /// \param[in] tracking_node A Node identifies different uses of the meta_func_graph.
228   /// \param[in] scope The scope to which the tracking_id belong to.
229   explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph,
230                                         const AnfNodePtr &tracking_node = nullptr,
231                                         const ScopePtr &scope = kDefaultScope)
MetaFuncGraphAbstractClosure(meta_func_graph,ToTrackingId (tracking_node),scope)232       : MetaFuncGraphAbstractClosure(meta_func_graph, ToTrackingId(tracking_node), scope) {}
233 
234   // For internal usage only, make it public so that make_shared can work on it.
MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr & meta_func_graph,std::uintptr_t tracking_id,const ScopePtr & scope)235   MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, std::uintptr_t tracking_id,
236                                const ScopePtr &scope)
237       : meta_func_graph_(meta_func_graph), tracking_id_(tracking_id), scope_(scope) {}
238 
239   /// \brief Destructor of MetaFuncGraphAbstractClosure.
240   ~MetaFuncGraphAbstractClosure() override = default;
MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure,AbstractFuncAtom)241   MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom)
242 
243   /// \brief Get the MetaFuncGraph that this MetaFuncGraphAbstractClosure corresponding to.
244   ///
245   /// \return The MetaFuncGraph that this MetaFuncGraphAbstractClosure corresponding to.
246   const MetaFuncGraphPtr &meta_func_graph() const { return meta_func_graph_; }
247 
context()248   AnalysisContextPtr context() const override { return AnalysisContext::DummyContext(); }
249 
250   /// \brief Get the Scope that this MetaFuncGraphAbstractClosure corresponding to.
251   ///
252   /// \return The Scope that this MetaFuncGraphAbstractClosure corresponding to.
GetScope()253   const ScopePtr &GetScope() const { return scope_; }
254 
tracking_id()255   std::uintptr_t tracking_id() const override { return tracking_id_; }
256 
Copy()257   AbstractFunctionPtr Copy() const override {
258     return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_, tracking_id_, kDefaultScope);
259   }
260 
CopyWithoutTrackingId()261   AbstractFunctionPtr CopyWithoutTrackingId() const override {
262     return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_, 0, kDefaultScope);
263   }
264 
265   bool operator==(const AbstractFunction &other) const override;
266 
267   std::size_t hash() const override;
268 
269   std::string ToString() const override;
270 
271  private:
272   MetaFuncGraphPtr meta_func_graph_;
273   // Refer the comment in FuncGraphAbstractClosure;
274   // Store it as memory address of the user node.
275   std::uintptr_t tracking_id_;
276   ScopePtr scope_;
277 };
278 using MetaFuncGraphAbstractClosurePtr = std::shared_ptr<MetaFuncGraphAbstractClosure>;
279 
280 /// \brief PartialAbstractClosure defines the abstract AbstractFuncAtom interface provided by some args in advance.
281 class MS_CORE_API PartialAbstractClosure final : public AbstractFuncAtom {
282  public:
283   /// \brief Constructor of PartialAbstractClosure.
284   ///
285   /// \param[in] fn The AbstractFuncAtom this PartialAbstractClosure corresponding to.
286   /// \param[in] args_abs_list The first few parameters provided for fn in advance.
287   /// \param[in] node The CNode which this PartialAbstractClosure evaluated from.
288   PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_abs_list,
289                          const AnfNodePtr &node = nullptr)
fn_(fn)290       : fn_(fn), args_abs_list_(args_abs_list), node_(AnfNodePtr(node)) {}
291 
292   /// \brief Destructor of PartialAbstractClosure.
293   ~PartialAbstractClosure() override = default;
MS_DECLARE_PARENT(PartialAbstractClosure,AbstractFuncAtom)294   MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom)
295 
296   /// \brief Get the AbstractFuncAtom that this PartialAbstractClosure corresponding to.
297   ///
298   /// \return The AbstractFuncAtom that this PartialAbstractClosure corresponding to.
299   AbstractFuncAtomPtr fn() { return fn_; }
300 
301   /// \brief Set the AbstractFuncAtom that this PartialAbstractClosure corresponding to.
302   ///
303   /// \param[in] fn The AbstractFuncAtom that this PartialAbstractClosure corresponding to.
set_fn(const AbstractFuncAtomPtr & fn)304   void set_fn(const AbstractFuncAtomPtr &fn) { fn_ = fn; }
305 
306   /// \brief Get the pre-provided arguments.
307   ///
308   /// \return The pre-provided arguments.
args()309   const AbstractBasePtrList &args() const { return args_abs_list_; }
310 
311   /// \brief Get the CNode this PartialAbstractClosure evaluated from.
312   ///
313   /// \return The CNode this PartialAbstractClosure evaluated from.
node()314   AnfNodePtr node() const { return node_.lock(); }
315 
316   /// \brief Set the CNode this PartialAbstractClosure evaluated from.
317   ///
318   /// \param[in] node The CNode this PartialAbstractClosure evaluated from.
set_node(const AnfNodePtr & node)319   void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
320 
321   /// \brief Get whether the args need to be appended to the end.
322   ///
323   /// \return Whether the args need to be appended to the end.
need_append_to_end()324   bool need_append_to_end() const { return need_append_to_end_; }
325 
326   /// \brief Set whether the args need to be appended to the end.
327   ///
328   /// \param[in] flag Whether the args need to be appended to the end.
set_need_append_to_end(bool flag)329   void set_need_append_to_end(bool flag) { need_append_to_end_ = flag; }
330 
Copy()331   AbstractFunctionPtr Copy() const override {
332     auto abs = std::make_shared<PartialAbstractClosure>(fn_, args_abs_list_, node_.lock());
333     abs->set_need_append_to_end(need_append_to_end_);
334     return abs;
335   }
336 
337   bool operator==(const AbstractFunction &other) const override;
338 
339   std::size_t hash() const override;
340 
341   std::string ToString() const override;
342 
343   std::string ToString(bool verbose) const override;
344 
345  protected:
RealBuildValue()346   ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
347 
348  private:
349   AbstractFuncAtomPtr fn_;
350   AbstractBasePtrList args_abs_list_;
351   // The ANFNode which this PartialAbstractClosure evaluated from.
352   AnfNodeWeakPtr node_;
353   bool need_append_to_end_{false};
354 };
355 using PartialAbstractClosurePtr = std::shared_ptr<PartialAbstractClosure>;
356 
357 /// \brief JTransformedAbstractClosure defines interface for abstract of Function
358 /// transformed through the application of J.
359 class MS_CORE_API JTransformedAbstractClosure final : public AbstractFuncAtom {
360  public:
361   /// \brief Constructor of JTransformedAbstractClosure
362   ///
363   /// \param[in] fn The AbstractFuncAtom transformed through the application of J.
JTransformedAbstractClosure(const AbstractFuncAtomPtr & fn)364   explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
365 
366   /// \brief Destructor of JTransformedAbstractClosure
367   ~JTransformedAbstractClosure() override = default;
MS_DECLARE_PARENT(JTransformedAbstractClosure,AbstractFuncAtom)368   MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom)
369 
370   /// \brief Get the AbstractFuncAtom JTransformedAbstractClosure corresponding to.
371   ///
372   /// \return The AbstractFuncAtom JTransformedAbstractClosure corresponding to.
373   const AbstractFuncAtomPtr &fn() const { return fn_; }
374 
Copy()375   AbstractFunctionPtr Copy() const override { return std::make_shared<JTransformedAbstractClosure>(fn_); }
376 
377   bool operator==(const AbstractFunction &other) const override;
378 
379   std::size_t hash() const override;
380 
ToString()381   std::string ToString() const override { return "J(" + fn_->ToString() + ")"; }
382 
383  private:
384   AbstractFuncAtomPtr fn_;
385 };
386 
387 /// \brief TaylorTransformedAbstractClosure defines interface for abstract of Function
388 /// transformed through the application of Taylor.
389 class MS_CORE_API TaylorTransformedAbstractClosure final : public AbstractFuncAtom {
390  public:
391   /// \brief Constructor of TaylorTransformedAbstractClosure
392   ///
393   /// \param[in] fn The AbstractFuncAtom transformed through the application of Taylor.
TaylorTransformedAbstractClosure(const AbstractFuncAtomPtr & fn)394   explicit TaylorTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
395 
396   /// \brief Destructor of TaylorTransformedAbstractClosure
397   ~TaylorTransformedAbstractClosure() override = default;
MS_DECLARE_PARENT(TaylorTransformedAbstractClosure,AbstractFuncAtom)398   MS_DECLARE_PARENT(TaylorTransformedAbstractClosure, AbstractFuncAtom)
399 
400   /// \brief Get the AbstractFuncAtom TaylorTransformedAbstractClosure corresponding to.
401   ///
402   /// \return The AbstractFuncAtom TaylorTransformedAbstractClosure corresponding to.
403   const AbstractFuncAtomPtr &fn() const { return fn_; }
404 
Copy()405   AbstractFunctionPtr Copy() const override { return std::make_shared<TaylorTransformedAbstractClosure>(fn_); }
406 
407   bool operator==(const AbstractFunction &other) const override;
408 
409   std::size_t hash() const override;
410 
ToString()411   std::string ToString() const override { return "Taylor(" + fn_->ToString() + ")"; }
412 
413  private:
414   AbstractFuncAtomPtr fn_;
415 };
416 
417 /// \brief ShardTransformedAbstractClosure defines interface for abstract of Function
418 /// transformed through the application of Shard.
419 class MS_CORE_API ShardTransformedAbstractClosure final : public AbstractFuncAtom {
420  public:
421   /// \brief Constructor of ShardTransformedAbstractClosure
422   ///
423   /// \param[in] fn The AbstractFuncAtom transformed through the application of Shard.
ShardTransformedAbstractClosure(const AbstractFuncAtomPtr & fn)424   explicit ShardTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
425 
426   /// \brief Destructor of ShardTransformedAbstractClosure
427   ~ShardTransformedAbstractClosure() override = default;
MS_DECLARE_PARENT(ShardTransformedAbstractClosure,AbstractFuncAtom)428   MS_DECLARE_PARENT(ShardTransformedAbstractClosure, AbstractFuncAtom)
429 
430   /// \brief Get the AbstractFuncAtom ShardTransformedAbstractClosure corresponding to.
431   ///
432   /// \return The AbstractFuncAtom ShardTransformedAbstractClosure corresponding to.
433   const AbstractFuncAtomPtr &fn() const { return fn_; }
434 
Copy()435   AbstractFunctionPtr Copy() const override { return std::make_shared<ShardTransformedAbstractClosure>(fn_); }
436 
437   bool operator==(const AbstractFunction &other) const override;
438 
439   std::size_t hash() const override;
440 
ToString()441   std::string ToString() const override { return "Shard(" + fn_->ToString() + ")"; }
442 
443  private:
444   AbstractFuncAtomPtr fn_;
445 };
446 
447 /// \brief VmapTransformedAbstractClosure defines interface for abstract of Function
448 /// transformed through the application of Vmap.
449 class MS_CORE_API VmapTransformedAbstractClosure final : public AbstractFuncAtom {
450  public:
451   /// \brief Constructor of VmapTransformedAbstractClosure
452   ///
453   /// \param[in] fn The AbstractFuncAtom transformed through the application of Vmap.
VmapTransformedAbstractClosure(const AbstractFuncAtomPtr & fn,const ValuePtr & in_axes,const ValuePtr & out_axes,size_t cell_size)454   explicit VmapTransformedAbstractClosure(const AbstractFuncAtomPtr &fn, const ValuePtr &in_axes,
455                                           const ValuePtr &out_axes, size_t cell_size)
456       : fn_(fn), in_axes_(in_axes), out_axes_(out_axes), cell_size_(cell_size) {}
457 
458   /// \brief Destructor of VmapTransformedAbstractClosure
459   ~VmapTransformedAbstractClosure() override = default;
MS_DECLARE_PARENT(VmapTransformedAbstractClosure,AbstractFuncAtom)460   MS_DECLARE_PARENT(VmapTransformedAbstractClosure, AbstractFuncAtom)
461 
462   /// \brief Get the AbstractFuncAtom VmapTransformedAbstractClosure corresponding to.
463   ///
464   /// \return The AbstractFuncAtom VmapTransformedAbstractClosure corresponding to.
465   const AbstractFuncAtomPtr &fn() const { return fn_; }
466 
in_axes()467   const ValuePtr &in_axes() const { return in_axes_; }
468 
out_axes()469   const ValuePtr &out_axes() const { return out_axes_; }
470 
cell_size()471   size_t cell_size() const { return cell_size_; }
472 
Copy()473   AbstractFunctionPtr Copy() const override {
474     return std::make_shared<VmapTransformedAbstractClosure>(fn_, in_axes_, out_axes_, cell_size_);
475   }
476 
477   bool operator==(const AbstractFunction &other) const override;
478 
479   std::size_t hash() const override;
480 
ToString()481   std::string ToString() const override { return "Vmap(" + fn_->ToString() + ")"; }
482 
483  private:
484   AbstractFuncAtomPtr fn_;
485   ValuePtr in_axes_;
486   ValuePtr out_axes_;
487   size_t cell_size_;
488 };
489 
490 /// \brief VirtualAbstractClosure defines interface for function with an explicitly
491 /// fixed type signature.
492 class MS_CORE_API VirtualAbstractClosure final : public AbstractFuncAtom {
493  public:
494   /// \brief Constructor of VirtualAbstractClosure.
495   ///
496   /// \param[in] args_abs_list The abstract values of the arguments to the function.
497   /// \param[in] output_spec The abstract value of output.
VirtualAbstractClosure(const AbstractBasePtrList & args_abs_list,const AbstractBasePtr & output_spec)498   VirtualAbstractClosure(const AbstractBasePtrList &args_abs_list, const AbstractBasePtr &output_spec)
499       : args_abs_list_(args_abs_list), output_(output_spec) {}
500 
501   /// \brief Constructor of VirtualAbstractClosure.
502   ///
503   /// \param[in] args_abs The abstract value of argument to the function.
504   /// \param[in] output_spec The abstract value of output.
VirtualAbstractClosure(const AbstractBasePtr & args_abs,const AbstractBasePtr & output_spec)505   VirtualAbstractClosure(const AbstractBasePtr &args_abs, const AbstractBasePtr &output_spec)
506       : args_abs_list_({args_abs}), output_(output_spec) {}
507 
508   /// \brief Destructor of VirtualAbstractClosure.
509   ~VirtualAbstractClosure() override = default;
MS_DECLARE_PARENT(VirtualAbstractClosure,AbstractFuncAtom)510   MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom)
511 
512   /// \brief Get the abstract values of arguments.
513   ///
514   /// \return The abstract values of arguments.
515   const AbstractBasePtrList &args_abs_list() const { return args_abs_list_; }
516 
517   /// \brief Get the abstract value of output.
518   ///
519   /// \return The abstract value of output.
output()520   const AbstractBasePtr &output() const { return output_; }
521 
Copy()522   AbstractFunctionPtr Copy() const override {
523     return std::make_shared<VirtualAbstractClosure>(args_abs_list_, output_);
524   }
525 
526   bool operator==(const AbstractFunction &other) const override;
527 
528   std::size_t hash() const override;
529 
530   std::string ToString() const override;
531 
532  private:
533   AbstractBasePtrList args_abs_list_;
534   AbstractBasePtr output_;
535 };
536 using VirtualAbstractClosurePtr = std::shared_ptr<VirtualAbstractClosure>;
537 
538 /// \brief TypedPrimitiveAbstractClosure defines interface for Primitive with an explicitly
539 /// fixed type signature.
540 class MS_CORE_API TypedPrimitiveAbstractClosure final : public AbstractFuncAtom {
541  public:
542   /// \brief Constructor of TypedPrimitiveAbstractClosure.
543   ///
544   /// \param[in] prim The Primitive with an explicitly fixed type signature.
545   /// \param[in] args_abs_list The abstract values of arguments to the Primitive.
546   /// \param[in] output_spec The abstract value of output.
TypedPrimitiveAbstractClosure(const PrimitivePtr prim,const AbstractBasePtrList & args_abs_list,const AbstractBasePtr & output_spec)547   TypedPrimitiveAbstractClosure(const PrimitivePtr prim, const AbstractBasePtrList &args_abs_list,
548                                 const AbstractBasePtr &output_spec)
549       : prim_(prim), args_abs_list_(args_abs_list), output_(output_spec) {}
550 
551   /// \brief Destructor of TypedPrimitiveAbstractClosure.
552   ~TypedPrimitiveAbstractClosure() override = default;
MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure,AbstractFuncAtom)553   MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom)
554 
555   /// \brief Get the Primitive that this TypedPrimitiveAbstractClosure corresponding to.
556   ///
557   /// \return The Primitive that this TypedPrimitiveAbstractClosure corresponding to.
558   const PrimitivePtr &prim() const { return prim_; }
559 
560   /// \brief Get the abstract values of arguments this TypedPrimitiveAbstractClosure corresponding to.
561   ///
562   /// \return The abstract values of arguments this TypedPrimitiveAbstractClosure corresponding to.
args_abs_list()563   const AbstractBasePtrList &args_abs_list() const { return args_abs_list_; }
564 
565   /// \brief Get the abstract value of output this TypedPrimitiveAbstractClosure corresponding to.
566   ///
567   /// \return The abstract value of output this TypedPrimitiveAbstractClosure corresponding to.
output()568   const AbstractBasePtr &output() const { return output_; }
569 
Copy()570   AbstractFunctionPtr Copy() const override {
571     return std::make_shared<TypedPrimitiveAbstractClosure>(prim_, args_abs_list_, output_);
572   }
573 
574   bool operator==(const AbstractFunction &other) const override;
575 
576   std::size_t hash() const override;
577 
578   std::string ToString() const override;
579 
580  private:
581   PrimitivePtr prim_;
582   AbstractBasePtrList args_abs_list_;
583   AbstractBasePtr output_;
584 };
585 
586 /// \brief Hash operator for AbstractFunction.
587 struct MS_CORE_API AbstractFunctionHasher {
588   /// \brief Implementation of hash operation.
589   ///
590   /// \param[in] t The AbstractFunction needs to hash.
591   ///
592   /// \return The hash result.
operatorAbstractFunctionHasher593   std::size_t operator()(const AbstractFunctionPtr &t) const {
594     std::size_t hash = t->hash();
595     return hash;
596   }
597 };
598 
599 /// \brief Equal operator for AbstractFunction.
600 struct MS_CORE_API AbstractFunctionEqual {
601   /// \brief Implementation of Equal operation.
602   ///
603   /// \param[in] lhs The left AbstractFunction for compare.
604   /// \param[in] rhs The right AbstractFunction for compare.
605   ///
606   /// \return Return True if the comparison result is equal, otherwise return False.
operatorAbstractFunctionEqual607   bool operator()(const AbstractFunctionPtr &lhs, const AbstractFunctionPtr &rhs) const { return *lhs == *rhs; }
608 };
609 }  // namespace abstract
610 }  // namespace mindspore
611 #endif  // MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_
612