1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h"
18 #include <utility>
19 #include <vector>
20 #include <memory>
21 #include <string>
22 #include "backend/kernel_compiler/oplib/oplib.h"
23 #include "backend/session/anf_runtime_algorithm.h"
24 #include "backend/session/kernel_graph.h"
25 #include "backend/optimizer/common/helper.h"
26
27 namespace mindspore {
28 namespace opt {
FindRefOriginNode(const AnfNodePtr & node) const29 session::KernelWithIndex DealRefAndSpiltUnSupportedTransdata::FindRefOriginNode(const AnfNodePtr &node) const {
30 session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0);
31 AnfNodePtr cur_node = kernel_with_index.first;
32 size_t cur_out_index = kernel_with_index.second;
33 MS_EXCEPTION_IF_NULL(cur_node);
34 if (cur_node->isa<CNode>()) {
35 auto cnode = cur_node->cast<CNodePtr>();
36 MS_EXCEPTION_IF_NULL(cnode);
37 std::string op_name = AnfAlgo::GetCNodeName(cnode);
38 auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
39 // deal ref op
40 if (op_info != nullptr && op_info->is_ref()) {
41 auto ref_infos = op_info->ref_infos();
42 if (ref_infos.count(cur_out_index) != 0) {
43 auto in_index = ref_infos.at(cur_out_index);
44 if (in_index > cnode->inputs().size()) {
45 MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size()
46 << ", ref info is " << cur_out_index;
47 }
48 AnfNodePtr next_node = cnode->input(in_index + 1);
49 return FindRefOriginNode(next_node);
50 }
51 }
52
53 // deal special (trans,cast,reshape) op and nop-node
54 if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() ||
55 op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName || opt::IsNopNode(cnode)) {
56 AnfNodePtr next_node = cnode->input(1);
57 return FindRefOriginNode(next_node);
58 }
59 }
60
61 return kernel_with_index;
62 }
63
AddRefNodePairToKernelGraph(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const size_t output_index,const size_t input_index) const64 void DealRefAndSpiltUnSupportedTransdata::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph,
65 const CNodePtr &cnode, const size_t output_index,
66 const size_t input_index) const {
67 // record the ref_pair
68 auto kernel_graph = func_graph->cast<KernelGraphPtr>();
69 MS_EXCEPTION_IF_NULL(kernel_graph);
70 session::AnfWithOutIndex final_pair = std::make_pair(cnode, output_index);
71 session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cnode, input_index), 0);
72 kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index);
73 }
74
AddRefPairToKernelGraph(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const AnfNodePtr & get_item,const AnfNodePtr & final_node,size_t final_index,const session::KernelWithIndex & origin_pair) const75 void DealRefAndSpiltUnSupportedTransdata::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
76 const AnfNodePtr &get_item,
77 const AnfNodePtr &final_node, size_t final_index,
78 const session::KernelWithIndex &origin_pair) const {
79 // record the ref_pair
80 auto kernel_graph = func_graph->cast<KernelGraphPtr>();
81 MS_EXCEPTION_IF_NULL(kernel_graph);
82 // if the final node is get item, means no trans or cast op is added, the final node is itself
83 // so add the pair for itself, because the get item will removed later
84 auto final_ref = (final_node == get_item ? cnode : final_node);
85 session::AnfWithOutIndex final_pair = std::make_pair(final_ref, final_index);
86 if (kernel_graph->IsInRefOutputMap(final_pair)) {
87 MS_LOG(EXCEPTION) << "ref_pair is already in ref map, node is " << final_ref->DebugString() << ", index is "
88 << final_index;
89 }
90 MS_LOG(DEBUG) << "Add Ref pair, final {node ptr " << final_pair.first.get() << " , info is "
91 << final_pair.first->DebugString() << " , index is " << final_pair.second << "}, origin {node ptr "
92 << origin_pair.first.get() << ", info is " << origin_pair.first->DebugString() << " : index "
93 << origin_pair.second << "}";
94 kernel_graph->AddRefCorrespondPairs(final_pair, origin_pair);
95 }
96
97 // if get_item is nullptr, the additional node will link to the cnode
98 // else the additional node will link to the get_item node (the get_item node link to cnode)
AddAdditionalToRefOutput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t output_index,size_t input_index,const CNodePtr & get_item) const99 CNodePtr DealRefAndSpiltUnSupportedTransdata::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph,
100 const CNodePtr &cnode, size_t output_index,
101 size_t input_index,
102 const CNodePtr &get_item) const {
103 CNodePtr final_node = (get_item == nullptr ? cnode : get_item);
104 bool need_refresh_ref_addr = false;
105 size_t final_index = output_index;
106 AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index);
107 session::KernelWithIndex origin_pair = FindRefOriginNode(input_node);
108 MS_EXCEPTION_IF_NULL(origin_pair.first);
109 if (!origin_pair.first->isa<Parameter>()) {
110 MS_LOG(WARNING) << "ref op origin node is not parameter";
111 }
112 MS_LOG(DEBUG) << "DealRefTransAndCast the node input index " << input_index << ", find origin op is "
113 << origin_pair.first->DebugString() << ", index is " << origin_pair.second;
114 auto origin_format = AnfAlgo::GetOutputFormat(origin_pair.first, origin_pair.second);
115 auto origin_type = AnfAlgo::GetOutputDeviceDataType(origin_pair.first, origin_pair.second);
116 auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index);
117 auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index);
118 auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
119 auto detail_shape = AnfAlgo::GetOutputDetailShape(cnode, output_index);
120 // insert trans
121 if (origin_format != cur_format && cur_shape.size() > 1) {
122 auto kernel_select = std::make_shared<KernelSelect>();
123 final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::kPrimTransData->name());
124 RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type);
125 final_node = SplitTransdataIfNotSupported(func_graph, final_node);
126 final_index = 0;
127 need_refresh_ref_addr = true;
128 MS_EXCEPTION_IF_NULL(final_node);
129 MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
130 }
131 // insert cast
132 if (origin_type != cur_type) {
133 final_node =
134 AddCastOpNodeToGraph(func_graph, final_node, origin_format, cur_type, origin_type, detail_shape, cur_type);
135 MS_EXCEPTION_IF_NULL(final_node);
136 final_node->set_scope(cnode->scope());
137 final_index = 0;
138 need_refresh_ref_addr = true;
139 MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString();
140 }
141 // add ref pair
142 AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair);
143 if (need_refresh_ref_addr) {
144 AddRefNodePairToKernelGraph(func_graph, cnode, output_index, input_index);
145 }
146 // insert depend
147 if (origin_format != cur_format || origin_type != cur_type) {
148 final_node = MakeDependency(get_item, final_node, cnode, func_graph);
149 MS_LOG(INFO) << "DealRefTranshwAndCast add denpend, op debug info is " << final_node->DebugString();
150 }
151 return final_node;
152 }
153
// Create a Depend CNode in `func_graph` with inputs (Depend primitive, source, final_node). The source is
// `get_item` when it is non-null and `cnode` otherwise.
CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node,
                                                             const CNodePtr &cnode,
                                                             const FuncGraphPtr &func_graph) const {
  const CNodePtr &depend_source = (get_item == nullptr) ? cnode : get_item;
  std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), depend_source, final_node};
  return func_graph->NewCNode(depend_inputs);
}
// Handle a ref op whose output is a tuple.
// When `orig_cnode` has UpdateState users, those edges are re-pointed to a fresh copy of the node.
// Each output is then extracted with a TupleGetItem; ref outputs get trans/cast nodes appended via
// AddAdditionalToRefOutput. The per-output results are repacked into a MakeTuple, which is returned.
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
  const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(orig_cnode);
  // Validate op_info before any graph mutation below.
  MS_EXCEPTION_IF_NULL(op_info);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto cnode = orig_cnode;
  auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode);
  if (!update_states.empty()) {
    // NOTE(review): UpdateState users are moved onto a clone of the ref op with the same inputs;
    // confirm the intent is to keep them off the outputs rebuilt below.
    auto kernel_graph = func_graph->cast<KernelGraphPtr>();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    cnode = kernel_graph->NewCNode(orig_cnode);
    MS_EXCEPTION_IF_NULL(cnode);
    cnode->set_inputs(orig_cnode->inputs());
    for (auto &update_state : update_states) {
      manager->SetEdge(update_state.first, update_state.second, cnode);
    }
  }
  const auto &ref_infos = op_info->ref_infos();
  size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
  std::vector<AnfNodePtr> make_tuple_inputs;
  make_tuple_inputs.reserve(output_num + 1);
  AbstractBasePtrList abstract_list;
  abstract_list.reserve(output_num);
  (void)make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  for (size_t output_index = 0; output_index < output_num; ++output_index) {
    CNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index);
    // deal with ref output
    auto ref_iter = ref_infos.find(output_index);
    if (ref_iter != ref_infos.end()) {
      final_node = AddAdditionalToRefOutput(func_graph, cnode, output_index, ref_iter->second, final_node);
    }
    MS_EXCEPTION_IF_NULL(final_node);
    abstract_list.push_back(final_node->abstract());
    make_tuple_inputs.push_back(final_node);
  }
  CNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  return make_tuple;
}
205
// Handle a ref op with a single (non-tuple) output.
// Returns nullptr when op_info has no ref entries; otherwise returns the result of AddAdditionalToRefOutput
// for the first ref entry (output index -> input index).
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefSingleOutput(
  const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(op_info);
  const auto &ref_infos = op_info->ref_infos();
  if (ref_infos.empty()) {
    return nullptr;
  }
  // NOTE(review): only the entry at begin() is used; assumes a single-output op has exactly one ref entry.
  auto ref_info = *(ref_infos.begin());
  // inputs()[0] is the primitive, so real input `ref_info.second` is inputs()[ref_info.second + 1] and must exist.
  if (ref_info.second + 1 >= cnode->inputs().size()) {
    MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size() << ", ref info is "
                      << ref_info.second;
  }
  return AddAdditionalToRefOutput(func_graph, cnode, ref_info.first, ref_info.second, nullptr);
}
221
DefinePattern() const222 const BaseRef DealRefAndSpiltUnSupportedTransdata::DefinePattern() const {
223 VarPtr V = std::make_shared<CondVar>(UnVisited);
224 VarPtr Xs = std::make_shared<SeqVar>();
225 return VectorRef({V, Xs});
226 }
227
DealBroadCastAsRef(const FuncGraphPtr & func_graph,const CNodePtr & cnode) const228 void DealRefAndSpiltUnSupportedTransdata::DealBroadCastAsRef(const FuncGraphPtr &func_graph,
229 const CNodePtr &cnode) const {
230 if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
231 auto input_size = AnfAlgo::GetInputTensorNum(cnode);
232 for (size_t i = 0; i < input_size; ++i) {
233 auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i, true);
234 auto input_node = input_node_with_index.first;
235 MS_EXCEPTION_IF_NULL(input_node);
236 MS_LOG(INFO) << "origin node:" << input_node->fullname_with_scope();
237 AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index);
238 }
239 }
240 }
241
// Pass entry point for each matched node.
// Marks the node as visited, records Broadcast ref pairs, and for TBE ref ops dispatches to the single- or
// multiple-output handler based on whether the node's type is a Tuple.
// Returns nullptr when the node is left unchanged.
const AnfNodePtr DealRefAndSpiltUnSupportedTransdata::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                              const EquivPtr &) const {
  if (node == nullptr || !node->isa<CNode>()) {
    return nullptr;
  }
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
    return nullptr;
  }

  DealBroadCastAsRef(graph, cnode);

  auto op_name = AnfAlgo::GetCNodeName(cnode);
  auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
  if (op_info == nullptr || !op_info->is_ref()) {
    return nullptr;
  }
  // op_info is a ref op from here on (guarded above).
  auto type = cnode->Type();
  MS_EXCEPTION_IF_NULL(type);
  if (type->isa<Tuple>()) {
    return DealRefForMultipleOutput(graph, cnode, op_info);
  }
  return DealRefSingleOutput(graph, cnode, op_info);
}
272
// Make a TransData node executable.
// - When at most one of its input/output formats is in kHWSpecialFormatSet, the node is split via DoSplit
//   (into transpose and transdata) only if IsFormatInvaild reports it; otherwise it is returned as-is.
// - When both formats are special, `cnode` is retargeted to output the default format and a second
//   TransData (default -> original output format) is chained after it. Each half is split if invalid.
// Returns the last node of the resulting chain.
CNodePtr DealRefAndSpiltUnSupportedTransdata::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph,
                                                                           const CNodePtr &cnode) const {
  MS_EXCEPTION_IF_NULL(cnode);
  auto build_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
  MS_EXCEPTION_IF_NULL(build_info);
  const bool both_formats_special =
    kHWSpecialFormatSet.find(build_info->GetInputFormat(0)) != kHWSpecialFormatSet.end() &&
    kHWSpecialFormatSet.find(build_info->GetOutputFormat(0)) != kHWSpecialFormatSet.end();
  if (!both_formats_special) {
    if (!IsFormatInvaild(cnode)) {
      return cnode;
    }
    return DoSplit(func_graph, cnode);
  }

  // Both sides are special formats: cnode becomes (special -> default) and a new node (default -> special).
  auto to_default_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(build_info);
  MS_EXCEPTION_IF_NULL(to_default_builder);
  auto to_special_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(build_info);
  MS_EXCEPTION_IF_NULL(to_special_builder);
  to_default_builder->SetOutputsFormat({kOpFormat_DEFAULT});
  to_special_builder->SetInputsFormat({kOpFormat_DEFAULT});
  std::vector<AnfNodePtr> second_trans_inputs = {
    NewValueNode(std::make_shared<Primitive>(prim::kPrimTransData->name())), cnode};
  MS_EXCEPTION_IF_NULL(func_graph);
  auto second_trans = func_graph->NewCNode(second_trans_inputs);
  MS_EXCEPTION_IF_NULL(second_trans);
  second_trans->set_abstract(cnode->abstract());
  AnfAlgo::SetSelectKernelBuildInfo(to_default_builder->Build(), cnode.get());
  AnfAlgo::SetSelectKernelBuildInfo(to_special_builder->Build(), second_trans.get());
  RefreshKernelBuildInfo(AnfAlgo::GetInputFormat(cnode, 0), kOpFormat_DEFAULT, cnode);
  RefreshKernelBuildInfo(kOpFormat_DEFAULT, AnfAlgo::GetOutputFormat(second_trans, 0), second_trans);
  if (IsFormatInvaild(cnode)) {
    // Re-wire the second TransData onto the tail of the split first half.
    auto split_first_half = DoSplit(func_graph, cnode);
    AnfAlgo::SetNodeInput(second_trans, split_first_half, 0);
  }
  if (!IsFormatInvaild(second_trans)) {
    return second_trans;
  }
  return DoSplit(func_graph, second_trans);
}
313 } // namespace opt
314 } // namespace mindspore
315