1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/common/optimizer/dynamic_shape/link_custom_op.h"
18
19 #include <memory>
20 #include <vector>
21 #include "mindspore/core/ops/sequence_ops.h"
22 #include "mindspore/core/ops/framework_ops.h"
23 #include "utils/anf_utils.h"
24 #include "include/backend/anf_runtime_algorithm.h"
25 #include "include/common/utils/anfalgo.h"
26 #include "include/backend/optimizer/helper.h"
27 #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
28 #include "abstract/ops/primitive_infer_map.h"
29
30 namespace mindspore {
31 namespace opt::dynamic_shape {
// Position of the original graph output inside the MakeTuple built by AttachDependNodes.
constexpr size_t kTupleFirstItemIndex{0};
// Index of the first data input of a CNode; input 0 is the primitive value node (see the input lists built below).
constexpr size_t kFirstDataInputIndex{1};
34
InsertDepend(const FuncGraphPtr & g,const AnfNodePtr & prev,const AnfNodePtr & next,AnfNodePtrList * depend_nodes)35 void LinkCustomOp::InsertDepend(const FuncGraphPtr &g, const AnfNodePtr &prev, const AnfNodePtr &next,
36 AnfNodePtrList *depend_nodes) {
37 MS_EXCEPTION_IF_NULL(g);
38 MS_EXCEPTION_IF_NULL(prev);
39 MS_EXCEPTION_IF_NULL(next);
40 MS_EXCEPTION_IF_NULL(depend_nodes);
41
42 DependPair cur_pair = std::make_pair(prev, next);
43 if (added_set_.count(cur_pair) > 0) {
44 return;
45 }
46
47 // add depend from prev to next
48 auto depend_node = g->NewCNode(
49 std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), next, prev});
50 MS_EXCEPTION_IF_NULL(depend_node);
51 depend_nodes->push_back(depend_node);
52 (void)added_set_.insert(cur_pair);
53 }
54
LinkInternalOp(const FuncGraphPtr & g,const AnfNodePtr & node,AnfNodePtrList * depend_nodes)55 bool LinkCustomOp::LinkInternalOp(const FuncGraphPtr &g, const AnfNodePtr &node, AnfNodePtrList *depend_nodes) {
56 MS_EXCEPTION_IF_NULL(g);
57 MS_EXCEPTION_IF_NULL(node);
58 MS_EXCEPTION_IF_NULL(depend_nodes);
59
60 bool changed = false;
61 auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(node);
62 if (custom_nodes.infer_node != nullptr && custom_nodes.init_node != nullptr) {
63 InsertDepend(g, custom_nodes.infer_node, custom_nodes.init_node, depend_nodes); // link infer => init
64 InsertDepend(g, custom_nodes.init_node, node, depend_nodes); // link init => launch
65 changed = true;
66 }
67
68 return changed;
69 }
70
LinkInputOp(const FuncGraphPtr & g,const CNodePtr & cnode,AnfNodePtrList * depend_nodes)71 bool LinkCustomOp::LinkInputOp(const FuncGraphPtr &g, const CNodePtr &cnode, AnfNodePtrList *depend_nodes) {
72 MS_EXCEPTION_IF_NULL(g);
73 MS_EXCEPTION_IF_NULL(cnode);
74 MS_EXCEPTION_IF_NULL(depend_nodes);
75 bool changed = false;
76 auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(cnode);
77 if (custom_nodes.infer_node == nullptr) {
78 return changed;
79 }
80 size_t input_num = common::AnfAlgo::GetInputNum(cnode);
81 for (size_t i = 0; i < input_num; ++i) {
82 auto prev = common::AnfAlgo::GetPrevNodeOutput(cnode, i);
83 const auto &prev_node = prev.first;
84 if (prev_node == nullptr) {
85 continue;
86 }
87 if (!CustomActorNodeManager::Instance().IsRegistered(prev_node)) {
88 // when its subgraph and its input is a dynamic_shape_parameter, link prev_parameter => curr.infer
89 if (prev_node->isa<Parameter>()) {
90 auto prev_parameter = prev_node->cast<ParameterPtr>();
91 MS_EXCEPTION_IF_NULL(prev_parameter);
92 if (prev_parameter->has_dynamic_shape()) {
93 InsertDepend(g, prev_node, custom_nodes.infer_node, depend_nodes);
94 changed = true;
95 }
96 }
97 continue;
98 }
99 auto prev_custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(prev_node);
100 if (prev_custom_nodes.infer_node != nullptr) {
101 // link prev.infer => curr.infer
102 InsertDepend(g, prev_custom_nodes.infer_node, custom_nodes.infer_node, depend_nodes);
103 changed = true;
104 }
105
106 // if the shape of prev_node is set after launch, need to link prev_node's launch to cur_node's infer
107 if (AnfAlgo::IsNeedUpdateShapeAndTypeAfterLaunch(prev_node)) {
108 // link prev.launch => curr.infer
109 InsertDepend(g, prev_node, custom_nodes.infer_node, depend_nodes);
110 changed = true;
111 }
112 }
113 return changed;
114 }
115
LinkDependSync(const FuncGraphPtr & g,const CNodePtr & cnode,AnfNodePtrList * depend_nodes)116 bool LinkCustomOp::LinkDependSync(const FuncGraphPtr &g, const CNodePtr &cnode, AnfNodePtrList *depend_nodes) {
117 MS_EXCEPTION_IF_NULL(g);
118 MS_EXCEPTION_IF_NULL(cnode);
119 MS_EXCEPTION_IF_NULL(depend_nodes);
120 bool changed = false;
121 auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(cnode);
122 if (custom_nodes.infer_node == nullptr) {
123 return changed;
124 }
125
126 auto dynamic_shape_depends = abstract::GetValueDependArgIndices(cnode);
127 if (dynamic_shape_depends.empty()) {
128 return changed;
129 }
130
131 for (auto depend_index : dynamic_shape_depends) {
132 auto prev = common::AnfAlgo::GetPrevNodeOutput(cnode, LongToSize(depend_index));
133 const auto &prev_node = prev.first;
134 if (prev_node == nullptr || !CustomActorNodeManager::Instance().IsRegistered(prev_node)) {
135 continue;
136 }
137
138 // If previous node is dynamic, so it was already link.
139 if (AnfAlgo::IsNeedUpdateShapeAndTypeAfterLaunch(prev_node)) {
140 continue;
141 }
142
143 // Link prev_node.launch => cur_node.infer.
144 InsertDepend(g, prev_node, custom_nodes.infer_node, depend_nodes);
145 changed = true;
146 }
147 return changed;
148 }
149
150 /**
151 * @brief Attach Custom's Depend nodes with additional MakeTuple and TupleGetItem before graph return.
152 *
153 * %0 = A
154 * return %0
155 * ---->
156 * %0 = A
157 * %1 = MakeTuple(%0, %depend0, %depend1...)
158 * %2 = TupleGetItem(%1, 0)
159 * return %2
160 *
161 * @param g Graph.
162 * @param depend_nodes Custom's Depend nodes.
163 */
AttachDependNodes(const FuncGraphPtr & g,const AnfNodePtrList & depend_nodes) const164 void LinkCustomOp::AttachDependNodes(const FuncGraphPtr &g, const AnfNodePtrList &depend_nodes) const {
165 if (depend_nodes.empty()) {
166 return;
167 }
168
169 MS_EXCEPTION_IF_NULL(g);
170 auto return_node = g->get_return();
171 MS_EXCEPTION_IF_NULL(return_node);
172 auto output_node = return_node->input(kFirstDataInputIndex);
173 MS_EXCEPTION_IF_NULL(output_node);
174
175 // New MakeTuple node
176 auto mk_inputs = AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), output_node};
177 (void)mk_inputs.insert(mk_inputs.cend(), depend_nodes.cbegin(), depend_nodes.cend());
178 auto make_tuple_node = g->NewCNode(mk_inputs);
179
180 // Get first element item form that maketuple and return.
181 auto get_1st_item =
182 g->NewCNode(AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(mindspore::kTupleGetItemOpName)),
183 make_tuple_node, NewValueNode(SizeToLong(kTupleFirstItemIndex))});
184 // The getitem node always obtains the first input of the maketuple, which is the output in the original graph,
185 // so set the abstract of the output to the getitem node.
186 MS_EXCEPTION_IF_NULL(get_1st_item);
187 get_1st_item->set_abstract(output_node->abstract());
188 // Attach back.
189 return_node->set_input(kFirstDataInputIndex, get_1st_item);
190 }
191
Run(const FuncGraphPtr & func_graph)192 bool LinkCustomOp::Run(const FuncGraphPtr &func_graph) {
193 MS_EXCEPTION_IF_NULL(func_graph);
194 bool changed = false;
195 AnfNodePtrList depend_nodes;
196 const auto &node_list = TopoSort(func_graph->get_return());
197 added_set_.clear();
198 for (const auto &node : node_list) {
199 MS_EXCEPTION_IF_NULL(node);
200 CNodePtr cnode = node->cast<CNodePtr>();
201 if (cnode == nullptr || !CustomActorNodeManager::Instance().IsRegistered(cnode)) {
202 continue;
203 }
204
205 changed = LinkInternalOp(func_graph, cnode, &depend_nodes) || changed;
206 changed = LinkInputOp(func_graph, cnode, &depend_nodes) || changed;
207 changed = LinkDependSync(func_graph, cnode, &depend_nodes) || changed;
208 }
209
210 CustomActorNodeManager::Instance().Reset();
211
212 if (changed) {
213 AttachDependNodes(func_graph, depend_nodes);
214
215 // Rebuild graph's edge.
216 auto mng = func_graph->manager();
217 if (mng == nullptr) {
218 mng = Manage(func_graph, true);
219 func_graph->set_manager(mng);
220 }
221 MS_EXCEPTION_IF_NULL(mng);
222 mng->RemoveRoots();
223 mng->KeepRoots({func_graph});
224 }
225
226 return changed;
227 }
228 } // namespace opt::dynamic_shape
229 } // namespace mindspore
230