• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h"
18 #include <utility>
19 #include <vector>
20 #include <memory>
21 #include <string>
22 #include "backend/kernel_compiler/oplib/oplib.h"
23 #include "backend/session/anf_runtime_algorithm.h"
24 #include "backend/session/kernel_graph.h"
25 #include "backend/optimizer/common/helper.h"
26 
27 namespace mindspore {
28 namespace opt {
FindRefOriginNode(const AnfNodePtr & node) const29 session::KernelWithIndex DealRefAndSpiltUnSupportedTransdata::FindRefOriginNode(const AnfNodePtr &node) const {
30   session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0);
31   AnfNodePtr cur_node = kernel_with_index.first;
32   size_t cur_out_index = kernel_with_index.second;
33   MS_EXCEPTION_IF_NULL(cur_node);
34   if (cur_node->isa<CNode>()) {
35     auto cnode = cur_node->cast<CNodePtr>();
36     MS_EXCEPTION_IF_NULL(cnode);
37     std::string op_name = AnfAlgo::GetCNodeName(cnode);
38     auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
39     // deal ref op
40     if (op_info != nullptr && op_info->is_ref()) {
41       auto ref_infos = op_info->ref_infos();
42       if (ref_infos.count(cur_out_index) != 0) {
43         auto in_index = ref_infos.at(cur_out_index);
44         if (in_index > cnode->inputs().size()) {
45           MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size()
46                             << ", ref info is " << cur_out_index;
47         }
48         AnfNodePtr next_node = cnode->input(in_index + 1);
49         return FindRefOriginNode(next_node);
50       }
51     }
52 
53     // deal special (trans,cast,reshape) op and nop-node
54     if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() ||
55         op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName || opt::IsNopNode(cnode)) {
56       AnfNodePtr next_node = cnode->input(1);
57       return FindRefOriginNode(next_node);
58     }
59   }
60 
61   return kernel_with_index;
62 }
63 
AddRefNodePairToKernelGraph(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const size_t output_index,const size_t input_index) const64 void DealRefAndSpiltUnSupportedTransdata::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph,
65                                                                       const CNodePtr &cnode, const size_t output_index,
66                                                                       const size_t input_index) const {
67   // record the ref_pair
68   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
69   MS_EXCEPTION_IF_NULL(kernel_graph);
70   session::AnfWithOutIndex final_pair = std::make_pair(cnode, output_index);
71   session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cnode, input_index), 0);
72   kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index);
73 }
74 
AddRefPairToKernelGraph(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const AnfNodePtr & get_item,const AnfNodePtr & final_node,size_t final_index,const session::KernelWithIndex & origin_pair) const75 void DealRefAndSpiltUnSupportedTransdata::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
76                                                                   const AnfNodePtr &get_item,
77                                                                   const AnfNodePtr &final_node, size_t final_index,
78                                                                   const session::KernelWithIndex &origin_pair) const {
79   // record the ref_pair
80   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
81   MS_EXCEPTION_IF_NULL(kernel_graph);
82   // if the final node is get item, means no trans or cast op is added, the final node is itself
83   // so add the pair for itself, because the get item will removed later
84   auto final_ref = (final_node == get_item ? cnode : final_node);
85   session::AnfWithOutIndex final_pair = std::make_pair(final_ref, final_index);
86   if (kernel_graph->IsInRefOutputMap(final_pair)) {
87     MS_LOG(EXCEPTION) << "ref_pair is already in ref map, node is " << final_ref->DebugString() << ", index is "
88                       << final_index;
89   }
90   MS_LOG(DEBUG) << "Add Ref pair, final {node ptr " << final_pair.first.get() << " , info is "
91                 << final_pair.first->DebugString() << " , index is " << final_pair.second << "}, origin {node ptr "
92                 << origin_pair.first.get() << ", info is " << origin_pair.first->DebugString() << " : index "
93                 << origin_pair.second << "}";
94   kernel_graph->AddRefCorrespondPairs(final_pair, origin_pair);
95 }
96 
97 // if get_item is nullptr, the additional node will link to the cnode
98 // else the additional node will link to the get_item node (the get_item node link to cnode)
AddAdditionalToRefOutput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t output_index,size_t input_index,const CNodePtr & get_item) const99 CNodePtr DealRefAndSpiltUnSupportedTransdata::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph,
100                                                                        const CNodePtr &cnode, size_t output_index,
101                                                                        size_t input_index,
102                                                                        const CNodePtr &get_item) const {
103   CNodePtr final_node = (get_item == nullptr ? cnode : get_item);
104   bool need_refresh_ref_addr = false;
105   size_t final_index = output_index;
106   AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index);
107   session::KernelWithIndex origin_pair = FindRefOriginNode(input_node);
108   MS_EXCEPTION_IF_NULL(origin_pair.first);
109   if (!origin_pair.first->isa<Parameter>()) {
110     MS_LOG(WARNING) << "ref op origin node is not parameter";
111   }
112   MS_LOG(DEBUG) << "DealRefTransAndCast the node input index " << input_index << ", find origin op is "
113                 << origin_pair.first->DebugString() << ", index is " << origin_pair.second;
114   auto origin_format = AnfAlgo::GetOutputFormat(origin_pair.first, origin_pair.second);
115   auto origin_type = AnfAlgo::GetOutputDeviceDataType(origin_pair.first, origin_pair.second);
116   auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index);
117   auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index);
118   auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
119   auto detail_shape = AnfAlgo::GetOutputDetailShape(cnode, output_index);
120   // insert trans
121   if (origin_format != cur_format && cur_shape.size() > 1) {
122     auto kernel_select = std::make_shared<KernelSelect>();
123     final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::kPrimTransData->name());
124     RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type);
125     final_node = SplitTransdataIfNotSupported(func_graph, final_node);
126     final_index = 0;
127     need_refresh_ref_addr = true;
128     MS_EXCEPTION_IF_NULL(final_node);
129     MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
130   }
131   // insert cast
132   if (origin_type != cur_type) {
133     final_node =
134       AddCastOpNodeToGraph(func_graph, final_node, origin_format, cur_type, origin_type, detail_shape, cur_type);
135     MS_EXCEPTION_IF_NULL(final_node);
136     final_node->set_scope(cnode->scope());
137     final_index = 0;
138     need_refresh_ref_addr = true;
139     MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString();
140   }
141   // add ref pair
142   AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair);
143   if (need_refresh_ref_addr) {
144     AddRefNodePairToKernelGraph(func_graph, cnode, output_index, input_index);
145   }
146   // insert depend
147   if (origin_format != cur_format || origin_type != cur_type) {
148     final_node = MakeDependency(get_item, final_node, cnode, func_graph);
149     MS_LOG(INFO) << "DealRefTranshwAndCast add denpend, op debug info is " << final_node->DebugString();
150   }
151   return final_node;
152 }
153 
// Builds a Depend node in `func_graph` with inputs {Depend primitive, anchor, final_node}.
// The anchor is `get_item` when it is non-null, otherwise `cnode`.
CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node,
                                                             const CNodePtr &cnode,
                                                             const FuncGraphPtr &func_graph) const {
  const CNodePtr &anchor = (get_item != nullptr) ? get_item : cnode;
  std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), anchor, final_node};
  return func_graph->NewCNode(depend_inputs);
}
// Handles a ref op whose result is a tuple.
//
// If `orig_cnode` has UpdateState users, a new CNode is created from it in the kernel graph, given the same
// inputs, and those users are re-pointed to the new node; all further work uses that node. A TupleGetItem
// is then created for every output. Outputs listed in the op's ref_infos are extended through
// AddAdditionalToRefOutput. The resulting nodes are packed into a MakeTuple node, which is returned with a
// tuple abstract built from the element abstracts.
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
  const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto graph_manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(graph_manager);

  CNodePtr ref_cnode = orig_cnode;
  auto state_users = AnfAlgo::GetUpdateStateUsers(graph_manager, orig_cnode);
  if (!state_users.empty()) {
    auto kernel_graph = func_graph->cast<KernelGraphPtr>();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    ref_cnode = kernel_graph->NewCNode(orig_cnode);
    MS_EXCEPTION_IF_NULL(ref_cnode);
    ref_cnode->set_inputs(orig_cnode->inputs());
    for (auto &state_user : state_users) {
      graph_manager->SetEdge(state_user.first, state_user.second, ref_cnode);
    }
  }

  MS_EXCEPTION_IF_NULL(op_info);
  const auto &output_to_input = op_info->ref_infos();
  std::vector<AnfNodePtr> tuple_inputs;
  AbstractBasePtrList element_abstracts;
  (void)tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));

  const size_t total_outputs = AnfAlgo::GetOutputTensorNum(ref_cnode);
  for (size_t idx = 0; idx < total_outputs; ++idx) {
    CNodePtr element = CreatTupleGetItemNode(func_graph, ref_cnode, idx);
    // deal with ref output
    const auto ref_iter = output_to_input.find(idx);
    if (ref_iter != output_to_input.end()) {
      element = AddAdditionalToRefOutput(func_graph, ref_cnode, idx, ref_iter->second, element);
    }
    MS_EXCEPTION_IF_NULL(element);
    element_abstracts.push_back(element->abstract());
    tuple_inputs.push_back(element);
  }

  CNodePtr tuple_node = func_graph->NewCNode(tuple_inputs);
  MS_EXCEPTION_IF_NULL(tuple_node);
  tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
  return tuple_node;
}
205 
// Handles a ref op with a single (non-tuple) output.
//
// Takes the first entry of the op's ref_infos (output index -> input index) and forwards it to
// AddAdditionalToRefOutput without a get_item node. Returns nullptr when the op has no ref info.
// An exception is reported when the ref input index does not address an existing input of `cnode`.
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefSingleOutput(
  const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(op_info);
  // Bind by reference: only the first entry is needed, so there is no reason to copy the map.
  const auto &ref_infos = op_info->ref_infos();
  if (ref_infos.empty()) {
    return nullptr;
  }
  const auto ref_info = *(ref_infos.begin());
  // FindRefOriginNode reads ref input i from inputs()[i + 1], so the largest valid index is
  // inputs().size() - 2; the previous `>` comparison let two out-of-range values through.
  // Presumably AnfAlgo::GetInputNode uses the same offset — TODO confirm.
  const size_t input_num = cnode->inputs().size();
  if (input_num == 0 || ref_info.second >= input_num - 1) {
    MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << input_num << ", ref info is "
                      << ref_info.second;
  }
  return AddAdditionalToRefOutput(func_graph, cnode, ref_info.first, ref_info.second, nullptr);
}
221 
DefinePattern() const222 const BaseRef DealRefAndSpiltUnSupportedTransdata::DefinePattern() const {
223   VarPtr V = std::make_shared<CondVar>(UnVisited);
224   VarPtr Xs = std::make_shared<SeqVar>();
225   return VectorRef({V, Xs});
226 }
227 
DealBroadCastAsRef(const FuncGraphPtr & func_graph,const CNodePtr & cnode) const228 void DealRefAndSpiltUnSupportedTransdata::DealBroadCastAsRef(const FuncGraphPtr &func_graph,
229                                                              const CNodePtr &cnode) const {
230   if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
231     auto input_size = AnfAlgo::GetInputTensorNum(cnode);
232     for (size_t i = 0; i < input_size; ++i) {
233       auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i, true);
234       auto input_node = input_node_with_index.first;
235       MS_EXCEPTION_IF_NULL(input_node);
236       MS_LOG(INFO) << "origin node:" << input_node->fullname_with_scope();
237       AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index);
238     }
239   }
240 }
241 
// Pass entry point for one matched node.
//
// Marks the node as visited, then for real CNode kernels:
//   * registers ref pairs for Broadcast ops (DealBroadCastAsRef);
//   * if the op info found for the node is a ref op, dispatches to DealRefSingleOutput or
//     DealRefForMultipleOutput depending on whether the node's type is a Tuple.
// Returns the node produced by that handler, or nullptr when the node is not a CNode, is not a real
// kernel, or is not a ref op.
const AnfNodePtr DealRefAndSpiltUnSupportedTransdata::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                              const EquivPtr &) const {
  if (node == nullptr || !node->isa<CNode>()) {
    return nullptr;
  }
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
    return nullptr;
  }

  DealBroadCastAsRef(graph, cnode);

  auto op_name = AnfAlgo::GetCNodeName(cnode);
  auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
  if (op_info == nullptr || !op_info->is_ref()) {
    return nullptr;
  }
  // Past this point the op is known to be a ref op, so the former second is_ref() test and the
  // trailing `return nullptr` were dead code and have been removed.
  auto type = cnode->Type();
  MS_EXCEPTION_IF_NULL(type);
  if (!type->isa<Tuple>()) {
    return DealRefSingleOutput(graph, cnode, op_info);
  }
  return DealRefForMultipleOutput(graph, cnode, op_info);
}
272 
// Splits a TransData node whose selected formats are not directly supported.
//
// * If the input format and the output format are not both in kHWSpecialFormatSet, the node is returned
//   unchanged, or passed through DoSplit when IsFormatInvaild reports it.
// * If both are in the set, the conversion is broken into two TransData nodes joined by
//   kOpFormat_DEFAULT: `cnode` now outputs the default format and a new node converts from the default
//   format to the original output format. Each half is passed through DoSplit when IsFormatInvaild
//   reports it. The tail of the resulting chain is returned.
CNodePtr DealRefAndSpiltUnSupportedTransdata::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph,
                                                                           const CNodePtr &cnode) const {
  MS_EXCEPTION_IF_NULL(cnode);
  auto build_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
  MS_EXCEPTION_IF_NULL(build_info);

  // Evaluated left to right with short-circuit, same as the original "either one is not special" test.
  const bool both_formats_special =
    kHWSpecialFormatSet.find(build_info->GetInputFormat(0)) != kHWSpecialFormatSet.end() &&
    kHWSpecialFormatSet.find(build_info->GetOutputFormat(0)) != kHWSpecialFormatSet.end();
  if (!both_formats_special) {
    // At most one side uses a special format: a transpose + transdata split is enough when needed.
    if (IsFormatInvaild(cnode)) {
      return DoSplit(func_graph, cnode);
    }
    return cnode;
  }

  // Both sides use a special format: go through the default format with two transdata nodes.
  auto to_default_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(build_info);
  MS_EXCEPTION_IF_NULL(to_default_builder);
  auto to_special_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(build_info);
  MS_EXCEPTION_IF_NULL(to_special_builder);
  to_default_builder->SetOutputsFormat({kOpFormat_DEFAULT});
  to_special_builder->SetInputsFormat({kOpFormat_DEFAULT});

  std::vector<AnfNodePtr> second_trans_inputs = {
    NewValueNode(std::make_shared<Primitive>(prim::kPrimTransData->name())), cnode};
  MS_EXCEPTION_IF_NULL(func_graph);
  auto second_trans = func_graph->NewCNode(second_trans_inputs);
  MS_EXCEPTION_IF_NULL(second_trans);
  second_trans->set_abstract(cnode->abstract());

  AnfAlgo::SetSelectKernelBuildInfo(to_default_builder->Build(), cnode.get());
  AnfAlgo::SetSelectKernelBuildInfo(to_special_builder->Build(), second_trans.get());
  RefreshKernelBuildInfo(AnfAlgo::GetInputFormat(cnode, 0), kOpFormat_DEFAULT, cnode);
  RefreshKernelBuildInfo(kOpFormat_DEFAULT, AnfAlgo::GetOutputFormat(second_trans, 0), second_trans);

  if (IsFormatInvaild(cnode)) {
    // The first half was split: feed the second transdata from the split result.
    auto first_half_tail = DoSplit(func_graph, cnode);
    AnfAlgo::SetNodeInput(second_trans, first_half_tail, 0);
  }
  if (IsFormatInvaild(second_trans)) {
    return DoSplit(func_graph, second_trans);
  }
  return second_trans;
}
313 }  // namespace opt
314 }  // namespace mindspore
315