• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include <algorithm>
#include <iterator>
#include <utility>
#include <vector>
#include "mindapi/ir/func_graph.h"
#include "mindapi/src/helper.h"
#include "ir/anf.h"
#include "ir/value.h"
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "ir/primitive.h"
#include "ir/graph_utils.h"
26 
27 namespace mindspore::api {
28 using ValueImpl = mindspore::Value;
29 using AnfNodeImpl = mindspore::AnfNode;
30 using CNodeImpl = mindspore::CNode;
31 using PrimitiveImpl = mindspore::Primitive;
32 using ParameterImpl = mindspore::Parameter;
33 using FuncGraphImpl = mindspore::FuncGraph;
34 using FuncGraphManagerImpl = mindspore::FuncGraphManager;
35 
36 MIND_API_BASE_IMPL(FuncGraph, FuncGraphImpl, Value);
37 
get_inputs() const38 std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
39   auto &inputs = ToRef<FuncGraphImpl>(impl_).get_inputs();
40   return ToWrapperVector<AnfNode>(inputs);
41 }
42 
parameters() const43 std::vector<AnfNodePtr> FuncGraph::parameters() const {
44   auto &params = ToRef<FuncGraphImpl>(impl_).parameters();
45   return ToWrapperVector<AnfNode>(params);
46 }
47 
add_parameter(const ParameterPtr & p)48 void FuncGraph::add_parameter(const ParameterPtr &p) {
49   auto param_impl = ToImpl<ParameterImpl>(p);
50   ToRef<FuncGraphImpl>(impl_).add_parameter(param_impl);
51 }
52 
add_parameter()53 ParameterPtr FuncGraph::add_parameter() {
54   auto param_impl = ToRef<FuncGraphImpl>(impl_).add_parameter();
55   return ToWrapper<Parameter>(param_impl);
56 }
57 
output() const58 AnfNodePtr FuncGraph::output() const {
59   auto output = ToRef<FuncGraphImpl>(impl_).output();
60   return ToWrapper<AnfNode>(output);
61 }
62 
get_return() const63 CNodePtr FuncGraph::get_return() const {
64   auto ret = ToRef<FuncGraphImpl>(impl_).get_return();
65   return ToWrapper<CNode>(ret);
66 }
67 
set_output(const AnfNodePtr & value,bool force_new_ret)68 void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
69   auto output = ToImpl<AnfNodeImpl>(value);
70   ToRef<FuncGraphImpl>(impl_).set_output(output, force_new_ret);
71 }
72 
set_return(const CNodePtr & cnode)73 void FuncGraph::set_return(const CNodePtr &cnode) {
74   MS_EXCEPTION_IF_NULL(cnode);
75   auto cnode_impl = ToImpl<CNodeImpl>(cnode);
76   ToRef<FuncGraphImpl>(impl_).set_return(cnode_impl);
77 }
78 
NewCNode(const std::vector<AnfNodePtr> & inputs)79 CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
80   auto inputs_impl = ToImplVector<AnfNodeImpl>(inputs);
81   auto cnode_impl = ToRef<FuncGraphImpl>(impl_).NewCNode(std::move(inputs_impl));
82   return ToWrapper<CNode>(cnode_impl);
83 }
84 
NewCNode(const PrimitivePtr & primitive,const std::vector<AnfNodePtr> & prim_inputs)85 CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs) {
86   auto prim_impl = ToImpl<PrimitiveImpl>(primitive);
87   auto prim_inputs_impl = ToImplVector<AnfNodeImpl>(prim_inputs);
88   auto cnode_impl = ToRef<FuncGraphImpl>(impl_).NewCNode(prim_impl, prim_inputs_impl);
89   return ToWrapper<CNode>(cnode_impl);
90 }
91 
nodes() const92 std::vector<AnfNodePtr> FuncGraph::nodes() const {
93   auto &nodes = ToRef<FuncGraphImpl>(impl_).nodes();
94   return ToWrapperVector<AnfNode>(nodes);
95 }
96 
has_attr(const std::string & key) const97 bool FuncGraph::has_attr(const std::string &key) const { return ToRef<FuncGraphImpl>(impl_).has_attr(key); }
98 
get_attr(const std::string & key) const99 ValuePtr FuncGraph::get_attr(const std::string &key) const {
100   auto v = ToRef<FuncGraphImpl>(impl_).get_attr(key);
101   return ToWrapper<Value>(v);
102 }
103 
set_attr(const std::string & key,const ValuePtr & value)104 void FuncGraph::set_attr(const std::string &key, const ValuePtr &value) {
105   auto value_impl = ToImpl<ValueImpl>(value);
106   ToRef<FuncGraphImpl>(impl_).set_attr(key, value_impl);
107 }
108 
manager() const109 FuncGraphManagerPtr FuncGraph::manager() const {
110   auto manager = ToRef<FuncGraphImpl>(impl_).manager();
111   if (manager == nullptr) {
112     return nullptr;
113   }
114   return MakeShared<FuncGraphManager>(manager);
115 }
116 
Create()117 FuncGraphPtr FuncGraph::Create() {
118   auto fg = std::make_shared<FuncGraphImpl>();
119   return ToWrapper<FuncGraph>(fg);
120 }
121 
TopoSort(const AnfNodePtr & node)122 std::vector<AnfNodePtr> FuncGraph::TopoSort(const AnfNodePtr &node) {
123   auto node_impl = ToImpl<AnfNodeImpl>(node);
124   if (node_impl == nullptr) {
125     return {};
126   }
127   auto sorted = mindspore::TopoSort(node_impl);
128   return ToWrapperVector<AnfNode>(sorted);
129 }
130 
131 // FuncGraphManager is not derived from Base, we implement it directly.
FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> & impl)132 FuncGraphManager::FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> &impl) : impl_(impl) {
133   MS_EXCEPTION_IF_NULL(impl_);
134 }
135 
Replace(const AnfNodePtr & old_node,const AnfNodePtr & new_node)136 bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
137   return impl_->Replace(ToImpl<AnfNodeImpl>(old_node), ToImpl<AnfNodeImpl>(new_node));
138 }
139 
SetEdge(const AnfNodePtr & node,int index,const AnfNodePtr & value)140 void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
141   return impl_->SetEdge(ToImpl<AnfNodeImpl>(node), index, ToImpl<AnfNodeImpl>(value));
142 }
143 
AddEdge(const AnfNodePtr & node,const AnfNodePtr & value)144 void FuncGraphManager::AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) {
145   return impl_->AddEdge(ToImpl<AnfNodeImpl>(node), ToImpl<AnfNodeImpl>(value));
146 }
147 
GetUsers(const AnfNodePtr & node) const148 std::vector<std::pair<AnfNodePtr, int>> FuncGraphManager::GetUsers(const AnfNodePtr &node) const {
149   auto &node_users = impl_->node_users();
150   auto iter = node_users.find(ToImpl<AnfNodeImpl>(node));
151   if (iter == node_users.end()) {
152     return {};
153   }
154   auto &users_impl = iter->second;
155   std::vector<std::pair<AnfNodePtr, int>> users;
156   users.reserve(users_impl.size());
157   (void)std::transform(users_impl.begin(), users_impl.end(), std::back_inserter(users),
158                        [](const auto &user) { return std::make_pair(ToWrapper<AnfNode>(user.first), user.second); });
159   return users;
160 }
161 
Manage(const FuncGraphPtr & func_graph,bool manage)162 FuncGraphManagerPtr FuncGraphManager::Manage(const FuncGraphPtr &func_graph, bool manage) {
163   auto fg_impl = ToImpl<FuncGraphImpl>(func_graph);
164   auto mgr_impl = mindspore::Manage(fg_impl, manage);
165   if (mgr_impl == nullptr) {
166     return nullptr;
167   }
168   return MakeShared<FuncGraphManager>(mgr_impl);
169 }
170 }  // namespace mindspore::api
171