1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2021 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_MINDSPORE_CORE_IR_FUNC_GRAPH_TRANSFORM_H_ 20 #define MINDSPORE_MINDSPORE_CORE_IR_FUNC_GRAPH_TRANSFORM_H_ 21 #include "ir/anf.h" 22 23 namespace mindspore { 24 // ANF transform class. 25 // Either a primitive or a func_graph. 26 class MS_CORE_API FuncGraphTransform { 27 public: 28 enum Type { kGtPrimitive, kGtFuncGraph }; 29 30 explicit FuncGraphTransform(const PrimitivePtr &prim, const FuncGraphPtr &func_graph = nullptr, 31 const CNodePtr &primal_cnode = nullptr) prim_(prim)32 : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)), primal_cnode_(primal_cnode) {} 33 34 explicit FuncGraphTransform(const FuncGraphPtr &func_graph, const PrimitivePtr &prim = func_graph_prim_, 35 const CNodePtr &primal_cnode = nullptr); 36 37 explicit FuncGraphTransform(const CNodePtr &primal_cnode, const PrimitivePtr &prim = func_graph_prim_, 38 const FuncGraphPtr &func_graph = nullptr) prim_(prim)39 : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)), primal_cnode_(primal_cnode) {} 40 FuncGraphTransform(const FuncGraphTransform & t)41 FuncGraphTransform(const FuncGraphTransform &t) 42 : prim_(t.prim_), func_graph_(t.func_graph_), primal_cnode_(t.primal_cnode_) {} 43 44 ~FuncGraphTransform() = default; 45 type()46 Type type() const { 47 if (IsFuncGraph()) { 48 return kGtFuncGraph; 49 } else { 50 return kGtPrimitive; 51 } 52 } 53 IsPrimitive()54 bool IsPrimitive() const { return (func_graph_.lock() == nullptr); } IsFuncGraph()55 bool IsFuncGraph() const { return (func_graph_.lock() != nullptr); } func_graph()56 FuncGraphPtr func_graph() const { return func_graph_.lock(); } primitive()57 PrimitivePtr primitive() const { return prim_; } primal_cnode()58 CNodePtr primal_cnode() const { return primal_cnode_; } 59 60 FuncGraphTransform &operator=(const FuncGraphTransform &t) { 61 if (this != &t) { 62 prim_ = t.prim_; 63 func_graph_ = t.func_graph_; 64 primal_cnode_ = t.primal_cnode_; 65 } 66 return *this; 67 } 68 69 private: 70 PrimitivePtr prim_; 71 // FuncGraph will be hold by FuncGraphManager, so weak_ptr is enough here. 72 // And use weak_ptr can break the reference cycle between "primal" and "grad" graph in 73 // FPropRemapper::FinalizeGraph(). 74 FuncGraphWeakPtr func_graph_; 75 static const PrimitivePtr func_graph_prim_; 76 CNodePtr primal_cnode_; 77 }; 78 } // namespace mindspore 79 #endif // MINDSPORE_MINDSPORE_CORE_IR_FUNC_GRAPH_TRANSFORM_H_ 80