1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <algorithm>
18 #include "mindapi/ir/func_graph.h"
19 #include "mindapi/src/helper.h"
20 #include "ir/anf.h"
21 #include "ir/value.h"
22 #include "ir/func_graph.h"
23 #include "ir/manager.h"
24 #include "ir/primitive.h"
25 #include "ir/graph_utils.h"
26
27 namespace mindspore::api {
28 using ValueImpl = mindspore::Value;
29 using AnfNodeImpl = mindspore::AnfNode;
30 using CNodeImpl = mindspore::CNode;
31 using PrimitiveImpl = mindspore::Primitive;
32 using ParameterImpl = mindspore::Parameter;
33 using FuncGraphImpl = mindspore::FuncGraph;
34 using FuncGraphManagerImpl = mindspore::FuncGraphManager;
35
36 MIND_API_BASE_IMPL(FuncGraph, FuncGraphImpl, Value);
37
get_inputs() const38 std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
39 auto &inputs = ToRef<FuncGraphImpl>(impl_).get_inputs();
40 return ToWrapperVector<AnfNode>(inputs);
41 }
42
parameters() const43 std::vector<AnfNodePtr> FuncGraph::parameters() const {
44 auto ¶ms = ToRef<FuncGraphImpl>(impl_).parameters();
45 return ToWrapperVector<AnfNode>(params);
46 }
47
add_parameter(const ParameterPtr & p)48 void FuncGraph::add_parameter(const ParameterPtr &p) {
49 auto param_impl = ToImpl<ParameterImpl>(p);
50 ToRef<FuncGraphImpl>(impl_).add_parameter(param_impl);
51 }
52
add_parameter()53 ParameterPtr FuncGraph::add_parameter() {
54 auto param_impl = ToRef<FuncGraphImpl>(impl_).add_parameter();
55 return ToWrapper<Parameter>(param_impl);
56 }
57
output() const58 AnfNodePtr FuncGraph::output() const {
59 auto output = ToRef<FuncGraphImpl>(impl_).output();
60 return ToWrapper<AnfNode>(output);
61 }
62
get_return() const63 CNodePtr FuncGraph::get_return() const {
64 auto ret = ToRef<FuncGraphImpl>(impl_).get_return();
65 return ToWrapper<CNode>(ret);
66 }
67
set_output(const AnfNodePtr & value,bool force_new_ret)68 void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
69 auto output = ToImpl<AnfNodeImpl>(value);
70 ToRef<FuncGraphImpl>(impl_).set_output(output, force_new_ret);
71 }
72
set_return(const CNodePtr & cnode)73 void FuncGraph::set_return(const CNodePtr &cnode) {
74 MS_EXCEPTION_IF_NULL(cnode);
75 auto cnode_impl = ToImpl<CNodeImpl>(cnode);
76 ToRef<FuncGraphImpl>(impl_).set_return(cnode_impl);
77 }
78
NewCNode(const std::vector<AnfNodePtr> & inputs)79 CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
80 auto inputs_impl = ToImplVector<AnfNodeImpl>(inputs);
81 auto cnode_impl = ToRef<FuncGraphImpl>(impl_).NewCNode(std::move(inputs_impl));
82 return ToWrapper<CNode>(cnode_impl);
83 }
84
NewCNode(const PrimitivePtr & primitive,const std::vector<AnfNodePtr> & prim_inputs)85 CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs) {
86 auto prim_impl = ToImpl<PrimitiveImpl>(primitive);
87 auto prim_inputs_impl = ToImplVector<AnfNodeImpl>(prim_inputs);
88 auto cnode_impl = ToRef<FuncGraphImpl>(impl_).NewCNode(prim_impl, prim_inputs_impl);
89 return ToWrapper<CNode>(cnode_impl);
90 }
91
nodes() const92 std::vector<AnfNodePtr> FuncGraph::nodes() const {
93 auto &nodes = ToRef<FuncGraphImpl>(impl_).nodes();
94 return ToWrapperVector<AnfNode>(nodes);
95 }
96
has_attr(const std::string & key) const97 bool FuncGraph::has_attr(const std::string &key) const { return ToRef<FuncGraphImpl>(impl_).has_attr(key); }
98
get_attr(const std::string & key) const99 ValuePtr FuncGraph::get_attr(const std::string &key) const {
100 auto v = ToRef<FuncGraphImpl>(impl_).get_attr(key);
101 return ToWrapper<Value>(v);
102 }
103
set_attr(const std::string & key,const ValuePtr & value)104 void FuncGraph::set_attr(const std::string &key, const ValuePtr &value) {
105 auto value_impl = ToImpl<ValueImpl>(value);
106 ToRef<FuncGraphImpl>(impl_).set_attr(key, value_impl);
107 }
108
manager() const109 FuncGraphManagerPtr FuncGraph::manager() const {
110 auto manager = ToRef<FuncGraphImpl>(impl_).manager();
111 if (manager == nullptr) {
112 return nullptr;
113 }
114 return MakeShared<FuncGraphManager>(manager);
115 }
116
Create()117 FuncGraphPtr FuncGraph::Create() {
118 auto fg = std::make_shared<FuncGraphImpl>();
119 return ToWrapper<FuncGraph>(fg);
120 }
121
TopoSort(const AnfNodePtr & node)122 std::vector<AnfNodePtr> FuncGraph::TopoSort(const AnfNodePtr &node) {
123 auto node_impl = ToImpl<AnfNodeImpl>(node);
124 if (node_impl == nullptr) {
125 return {};
126 }
127 auto sorted = mindspore::TopoSort(node_impl);
128 return ToWrapperVector<AnfNode>(sorted);
129 }
130
131 // FuncGraphManager is not derived from Base, we implement it directly.
FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> & impl)132 FuncGraphManager::FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> &impl) : impl_(impl) {
133 MS_EXCEPTION_IF_NULL(impl_);
134 }
135
Replace(const AnfNodePtr & old_node,const AnfNodePtr & new_node)136 bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
137 return impl_->Replace(ToImpl<AnfNodeImpl>(old_node), ToImpl<AnfNodeImpl>(new_node));
138 }
139
SetEdge(const AnfNodePtr & node,int index,const AnfNodePtr & value)140 void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
141 return impl_->SetEdge(ToImpl<AnfNodeImpl>(node), index, ToImpl<AnfNodeImpl>(value));
142 }
143
AddEdge(const AnfNodePtr & node,const AnfNodePtr & value)144 void FuncGraphManager::AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) {
145 return impl_->AddEdge(ToImpl<AnfNodeImpl>(node), ToImpl<AnfNodeImpl>(value));
146 }
147
GetUsers(const AnfNodePtr & node) const148 std::vector<std::pair<AnfNodePtr, int>> FuncGraphManager::GetUsers(const AnfNodePtr &node) const {
149 auto &node_users = impl_->node_users();
150 auto iter = node_users.find(ToImpl<AnfNodeImpl>(node));
151 if (iter == node_users.end()) {
152 return {};
153 }
154 auto &users_impl = iter->second;
155 std::vector<std::pair<AnfNodePtr, int>> users;
156 users.reserve(users_impl.size());
157 (void)std::transform(users_impl.begin(), users_impl.end(), std::back_inserter(users),
158 [](const auto &user) { return std::make_pair(ToWrapper<AnfNode>(user.first), user.second); });
159 return users;
160 }
161
Manage(const FuncGraphPtr & func_graph,bool manage)162 FuncGraphManagerPtr FuncGraphManager::Manage(const FuncGraphPtr &func_graph, bool manage) {
163 auto fg_impl = ToImpl<FuncGraphImpl>(func_graph);
164 auto mgr_impl = mindspore::Manage(fg_impl, manage);
165 if (mgr_impl == nullptr) {
166 return nullptr;
167 }
168 return MakeShared<FuncGraphManager>(mgr_impl);
169 }
170 } // namespace mindspore::api
171