• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_MINDSPORE_CORE_IR_FUNC_GRAPH_TRANSFORM_H_
20 #define MINDSPORE_MINDSPORE_CORE_IR_FUNC_GRAPH_TRANSFORM_H_
21 #include "ir/anf.h"
22 
23 namespace mindspore {
24 // ANF transform class.
25 // Either a primitive or a func_graph.
26 class MS_CORE_API FuncGraphTransform {
27  public:
28   enum Type { kGtPrimitive, kGtFuncGraph };
29 
30   explicit FuncGraphTransform(const PrimitivePtr &prim, const FuncGraphPtr &func_graph = nullptr,
31                               const CNodePtr &primal_cnode = nullptr)
prim_(prim)32       : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)), primal_cnode_(primal_cnode) {}
33 
34   explicit FuncGraphTransform(const FuncGraphPtr &func_graph, const PrimitivePtr &prim = func_graph_prim_,
35                               const CNodePtr &primal_cnode = nullptr);
36 
37   explicit FuncGraphTransform(const CNodePtr &primal_cnode, const PrimitivePtr &prim = func_graph_prim_,
38                               const FuncGraphPtr &func_graph = nullptr)
prim_(prim)39       : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)), primal_cnode_(primal_cnode) {}
40 
FuncGraphTransform(const FuncGraphTransform & t)41   FuncGraphTransform(const FuncGraphTransform &t)
42       : prim_(t.prim_), func_graph_(t.func_graph_), primal_cnode_(t.primal_cnode_) {}
43 
44   ~FuncGraphTransform() = default;
45 
type()46   Type type() const {
47     if (IsFuncGraph()) {
48       return kGtFuncGraph;
49     } else {
50       return kGtPrimitive;
51     }
52   }
53 
IsPrimitive()54   bool IsPrimitive() const { return (func_graph_.lock() == nullptr); }
IsFuncGraph()55   bool IsFuncGraph() const { return (func_graph_.lock() != nullptr); }
func_graph()56   FuncGraphPtr func_graph() const { return func_graph_.lock(); }
primitive()57   PrimitivePtr primitive() const { return prim_; }
primal_cnode()58   CNodePtr primal_cnode() const { return primal_cnode_; }
59 
60   FuncGraphTransform &operator=(const FuncGraphTransform &t) {
61     if (this != &t) {
62       prim_ = t.prim_;
63       func_graph_ = t.func_graph_;
64       primal_cnode_ = t.primal_cnode_;
65     }
66     return *this;
67   }
68 
69  private:
70   PrimitivePtr prim_;
71   // FuncGraph will be hold by FuncGraphManager, so weak_ptr is enough here.
72   // And use weak_ptr can break the reference cycle between "primal" and "grad" graph in
73   // FPropRemapper::FinalizeGraph().
74   FuncGraphWeakPtr func_graph_;
75   static const PrimitivePtr func_graph_prim_;
76   CNodePtr primal_cnode_;
77 };
78 }  // namespace mindspore
79 #endif  // MINDSPORE_MINDSPORE_CORE_IR_FUNC_GRAPH_TRANSFORM_H_
80