• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/common/optimizer/dynamic_shape/link_custom_op.h"
18 
19 #include <memory>
20 #include <vector>
21 #include "mindspore/core/ops/sequence_ops.h"
22 #include "mindspore/core/ops/framework_ops.h"
23 #include "utils/anf_utils.h"
24 #include "include/backend/anf_runtime_algorithm.h"
25 #include "include/common/utils/anfalgo.h"
26 #include "include/backend/optimizer/helper.h"
27 #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
28 #include "abstract/ops/primitive_infer_map.h"
29 
30 namespace mindspore {
31 namespace opt::dynamic_shape {
// Element of the MakeTuple built in AttachDependNodes that the TupleGetItem selects: slot 0 holds the
// graph's original output.
constexpr size_t kTupleFirstItemIndex{0};
// Input index of the Return node that is read and rewritten via return_node->input()/set_input().
constexpr size_t kFirstDataInputIndex{1};
34 
InsertDepend(const FuncGraphPtr & g,const AnfNodePtr & prev,const AnfNodePtr & next,AnfNodePtrList * depend_nodes)35 void LinkCustomOp::InsertDepend(const FuncGraphPtr &g, const AnfNodePtr &prev, const AnfNodePtr &next,
36                                 AnfNodePtrList *depend_nodes) {
37   MS_EXCEPTION_IF_NULL(g);
38   MS_EXCEPTION_IF_NULL(prev);
39   MS_EXCEPTION_IF_NULL(next);
40   MS_EXCEPTION_IF_NULL(depend_nodes);
41 
42   DependPair cur_pair = std::make_pair(prev, next);
43   if (added_set_.count(cur_pair) > 0) {
44     return;
45   }
46 
47   // add depend from prev to next
48   auto depend_node = g->NewCNode(
49     std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), next, prev});
50   MS_EXCEPTION_IF_NULL(depend_node);
51   depend_nodes->push_back(depend_node);
52   (void)added_set_.insert(cur_pair);
53 }
54 
LinkInternalOp(const FuncGraphPtr & g,const AnfNodePtr & node,AnfNodePtrList * depend_nodes)55 bool LinkCustomOp::LinkInternalOp(const FuncGraphPtr &g, const AnfNodePtr &node, AnfNodePtrList *depend_nodes) {
56   MS_EXCEPTION_IF_NULL(g);
57   MS_EXCEPTION_IF_NULL(node);
58   MS_EXCEPTION_IF_NULL(depend_nodes);
59 
60   bool changed = false;
61   auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(node);
62   if (custom_nodes.infer_node != nullptr && custom_nodes.init_node != nullptr) {
63     InsertDepend(g, custom_nodes.infer_node, custom_nodes.init_node, depend_nodes);  // link infer => init
64     InsertDepend(g, custom_nodes.init_node, node, depend_nodes);                     // link init => launch
65     changed = true;
66   }
67 
68   return changed;
69 }
70 
LinkInputOp(const FuncGraphPtr & g,const CNodePtr & cnode,AnfNodePtrList * depend_nodes)71 bool LinkCustomOp::LinkInputOp(const FuncGraphPtr &g, const CNodePtr &cnode, AnfNodePtrList *depend_nodes) {
72   MS_EXCEPTION_IF_NULL(g);
73   MS_EXCEPTION_IF_NULL(cnode);
74   MS_EXCEPTION_IF_NULL(depend_nodes);
75   bool changed = false;
76   auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(cnode);
77   if (custom_nodes.infer_node == nullptr) {
78     return changed;
79   }
80   size_t input_num = common::AnfAlgo::GetInputNum(cnode);
81   for (size_t i = 0; i < input_num; ++i) {
82     auto prev = common::AnfAlgo::GetPrevNodeOutput(cnode, i);
83     const auto &prev_node = prev.first;
84     if (prev_node == nullptr) {
85       continue;
86     }
87     if (!CustomActorNodeManager::Instance().IsRegistered(prev_node)) {
88       // when its subgraph and its input is a dynamic_shape_parameter, link prev_parameter => curr.infer
89       if (prev_node->isa<Parameter>()) {
90         auto prev_parameter = prev_node->cast<ParameterPtr>();
91         MS_EXCEPTION_IF_NULL(prev_parameter);
92         if (prev_parameter->has_dynamic_shape()) {
93           InsertDepend(g, prev_node, custom_nodes.infer_node, depend_nodes);
94           changed = true;
95         }
96       }
97       continue;
98     }
99     auto prev_custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(prev_node);
100     if (prev_custom_nodes.infer_node != nullptr) {
101       // link prev.infer => curr.infer
102       InsertDepend(g, prev_custom_nodes.infer_node, custom_nodes.infer_node, depend_nodes);
103       changed = true;
104     }
105 
106     // if the shape of prev_node is set after launch, need to link prev_node's launch to cur_node's infer
107     if (AnfAlgo::IsNeedUpdateShapeAndTypeAfterLaunch(prev_node)) {
108       // link prev.launch => curr.infer
109       InsertDepend(g, prev_node, custom_nodes.infer_node, depend_nodes);
110       changed = true;
111     }
112   }
113   return changed;
114 }
115 
LinkDependSync(const FuncGraphPtr & g,const CNodePtr & cnode,AnfNodePtrList * depend_nodes)116 bool LinkCustomOp::LinkDependSync(const FuncGraphPtr &g, const CNodePtr &cnode, AnfNodePtrList *depend_nodes) {
117   MS_EXCEPTION_IF_NULL(g);
118   MS_EXCEPTION_IF_NULL(cnode);
119   MS_EXCEPTION_IF_NULL(depend_nodes);
120   bool changed = false;
121   auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(cnode);
122   if (custom_nodes.infer_node == nullptr) {
123     return changed;
124   }
125 
126   auto dynamic_shape_depends = abstract::GetValueDependArgIndices(cnode);
127   if (dynamic_shape_depends.empty()) {
128     return changed;
129   }
130 
131   for (auto depend_index : dynamic_shape_depends) {
132     auto prev = common::AnfAlgo::GetPrevNodeOutput(cnode, LongToSize(depend_index));
133     const auto &prev_node = prev.first;
134     if (prev_node == nullptr || !CustomActorNodeManager::Instance().IsRegistered(prev_node)) {
135       continue;
136     }
137 
138     // If previous node is dynamic, so it was already link.
139     if (AnfAlgo::IsNeedUpdateShapeAndTypeAfterLaunch(prev_node)) {
140       continue;
141     }
142 
143     // Link prev_node.launch => cur_node.infer.
144     InsertDepend(g, prev_node, custom_nodes.infer_node, depend_nodes);
145     changed = true;
146   }
147   return changed;
148 }
149 
150 /**
151  * @brief Attach Custom's Depend nodes with additional MakeTuple and TupleGetItem before graph return.
152  *
153  *          %0 = A
154  *          return %0
155  *          ---->
156  *          %0 = A
157  *          %1 = MakeTuple(%0, %depend0, %depend1...)
158  *          %2 = TupleGetItem(%1, 0)
159  *          return %2
160  *
161  * @param g Graph.
162  * @param depend_nodes Custom's Depend nodes.
163  */
AttachDependNodes(const FuncGraphPtr & g,const AnfNodePtrList & depend_nodes) const164 void LinkCustomOp::AttachDependNodes(const FuncGraphPtr &g, const AnfNodePtrList &depend_nodes) const {
165   if (depend_nodes.empty()) {
166     return;
167   }
168 
169   MS_EXCEPTION_IF_NULL(g);
170   auto return_node = g->get_return();
171   MS_EXCEPTION_IF_NULL(return_node);
172   auto output_node = return_node->input(kFirstDataInputIndex);
173   MS_EXCEPTION_IF_NULL(output_node);
174 
175   // New MakeTuple node
176   auto mk_inputs = AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), output_node};
177   (void)mk_inputs.insert(mk_inputs.cend(), depend_nodes.cbegin(), depend_nodes.cend());
178   auto make_tuple_node = g->NewCNode(mk_inputs);
179 
180   // Get first element item form that maketuple and return.
181   auto get_1st_item =
182     g->NewCNode(AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(mindspore::kTupleGetItemOpName)),
183                                make_tuple_node, NewValueNode(SizeToLong(kTupleFirstItemIndex))});
184   // The getitem node always obtains the first input of the maketuple, which is the output in the original graph,
185   // so set the abstract of the output to the getitem node.
186   MS_EXCEPTION_IF_NULL(get_1st_item);
187   get_1st_item->set_abstract(output_node->abstract());
188   // Attach back.
189   return_node->set_input(kFirstDataInputIndex, get_1st_item);
190 }
191 
Run(const FuncGraphPtr & func_graph)192 bool LinkCustomOp::Run(const FuncGraphPtr &func_graph) {
193   MS_EXCEPTION_IF_NULL(func_graph);
194   bool changed = false;
195   AnfNodePtrList depend_nodes;
196   const auto &node_list = TopoSort(func_graph->get_return());
197   added_set_.clear();
198   for (const auto &node : node_list) {
199     MS_EXCEPTION_IF_NULL(node);
200     CNodePtr cnode = node->cast<CNodePtr>();
201     if (cnode == nullptr || !CustomActorNodeManager::Instance().IsRegistered(cnode)) {
202       continue;
203     }
204 
205     changed = LinkInternalOp(func_graph, cnode, &depend_nodes) || changed;
206     changed = LinkInputOp(func_graph, cnode, &depend_nodes) || changed;
207     changed = LinkDependSync(func_graph, cnode, &depend_nodes) || changed;
208   }
209 
210   CustomActorNodeManager::Instance().Reset();
211 
212   if (changed) {
213     AttachDependNodes(func_graph, depend_nodes);
214 
215     // Rebuild graph's edge.
216     auto mng = func_graph->manager();
217     if (mng == nullptr) {
218       mng = Manage(func_graph, true);
219       func_graph->set_manager(mng);
220     }
221     MS_EXCEPTION_IF_NULL(mng);
222     mng->RemoveRoots();
223     mng->KeepRoots({func_graph});
224   }
225 
226   return changed;
227 }
228 }  // namespace opt::dynamic_shape
229 }  // namespace mindspore
230