• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h"
17 #include <vector>
18 #include <utility>
19 #include <unordered_map>
20 #include <deque>
21 #include <memory>
22 #include <string>
23 #include <algorithm>
24 #include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
25 #include "backend/kernel_compiler/tbe/ascend_kernel_compile.h"
26 #include "backend/kernel_compiler/kernel_fusion.h"
27 #include "debug/anf_ir_dump.h"
28 #include "backend/session/anf_runtime_algorithm.h"
29 #include "base/core_ops.h"
30 #include "runtime/device/kernel_info.h"
31 #include "utils/ms_context.h"
32 #include "backend/optimizer/common/helper.h"
33 
34 namespace mindspore {
35 namespace opt {
36 namespace {
// Pattern-size thresholds. None of these is referenced by the functions in this file;
// presumably kept for parity with sibling fusion passes - TODO confirm before removing.
constexpr int8_t MAX_PATTERN_SIZE = 7;
constexpr int8_t MIN_PATTERN_SIZE = 2;
constexpr int8_t ELTWISE_INPUT_SIZE = 2;
constexpr int8_t ELTWISE_USE = 1;
constexpr int8_t MULTI_ELTWISE_USE = 2;
constexpr int8_t MAX_MULTI_ELTWISE_SIZE = 4;
constexpr int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3;

// A fusion scope with fewer compute nodes than this is not replaced by a fused op (see ReplaceFusionOp).
constexpr size_t kFusionNodeNumThreshold = 2;
// Name of the node attribute that stores the id of the fusion scope a kernel belongs to.
constexpr auto kOpAttrFusionId = "fusion_id";
46 
47 #ifdef DEBUG
DumpFusionScopeInfo(const kernel::FusionScopeInfo & info)48 void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) {
49   MS_LOG(INFO) << "=== Dump FusionScopeInfo start id: " << info.scope_id;
50   for (auto &node : info.input_nodes) {
51     MS_LOG(INFO) << "=== Input: " << node->DebugString();
52   }
53   for (auto &node : info.output_nodes) {
54     MS_LOG(INFO) << "=== Output: " << node->DebugString();
55   }
56   for (auto &node : info.compute_nodes) {
57     MS_LOG(INFO) << "=== Compute: (" << node->DebugString() << ")-("
58                  << mindspore::kekernel::tbe::GetFusionTypeName(AnfAlgo::GetFusionType(node)) << ")";
59   }
60   MS_LOG(INFO) << "=== Dump FusionScopeInfo end";
61 }
62 #endif
CreateFusionOp(const std::vector<AnfNodePtr> & inputs_list,const std::vector<AnfNodePtr> & outputs_list,const std::vector<AnfNodePtr> & anf_nodes,session::KernelGraph * kernel_graph)63 CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::vector<AnfNodePtr> &outputs_list,
64                         const std::vector<AnfNodePtr> &anf_nodes, session::KernelGraph *kernel_graph) {
65   MS_LOG(DEBUG) << "Start Create FusionOp Kernel";
66   MS_EXCEPTION_IF_NULL(kernel_graph);
67   std::string fusion_op_name = "FusionOp";
68   for (auto &node : anf_nodes) {
69     fusion_op_name += '_' + AnfAlgo::GetCNodeName(node);
70   }
71   auto fusion_op = std::make_shared<Primitive>(fusion_op_name);
72   MS_EXCEPTION_IF_NULL(fusion_op);
73 
74   std::vector<std::string> input_names;
75   for (size_t i = 0; i < inputs_list.size(); i++) {
76     (void)input_names.emplace_back("input" + std::to_string(i));
77   }
78   std::vector<std::string> output_names;
79   for (size_t i = 0; i < outputs_list.size(); i++) {
80     (void)output_names.emplace_back("output" + std::to_string(i));
81   }
82 
83   ValuePtr input_names_v = MakeValue(input_names);
84   ValuePtr output_names_v = MakeValue(output_names);
85   fusion_op->set_attr("input_names", input_names_v);
86   fusion_op->set_attr("output_names", output_names_v);
87   for (auto &node : anf_nodes) {
88     MS_EXCEPTION_IF_NULL(node);
89     auto cnode = node->cast<CNodePtr>();
90     if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) {
91       auto fracz_group = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFracZGroup);
92       fusion_op->set_attr(kAttrFracZGroup, MakeValue(fracz_group));
93       break;
94     }
95   }
96   std::vector<AnfNodePtr> fusion_inputs_list = inputs_list;
97   auto value_node = std::make_shared<ValueNode>(fusion_op);
98   (void)fusion_inputs_list.insert(fusion_inputs_list.begin(), value_node);
99   auto buffer_fusion_kernel = kernel_graph->NewCNode(fusion_inputs_list);
100   if (buffer_fusion_kernel == nullptr) {
101     MS_LOG(EXCEPTION) << "New FusionOp kernel failed!";
102   }
103   buffer_fusion_kernel->set_scope((anf_nodes.back())->scope());
104 
105   return buffer_fusion_kernel;
106 }
107 
CreateFusionOpKernelInfo(const std::vector<AnfNodePtr> & inputs_list,const std::vector<AnfNodePtr> & outputs_list)108 kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr> &inputs_list,
109                                                     const std::vector<AnfNodePtr> &outputs_list) {
110   MS_LOG(DEBUG) << "Start Create Kernel Info";
111   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
112   // inputs format and data type
113   std::vector<std::string> inputs_format;
114   std::vector<TypeId> inputs_data_type;
115   for (const auto &input : inputs_list) {
116     auto real_input = AnfAlgo::VisitKernel(input, 0);
117     (void)inputs_format.emplace_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second));
118     (void)inputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second));
119   }
120   // outputs format and data type
121   std::vector<std::string> outputs_format;
122   std::vector<TypeId> outputs_data_type;
123   for (const auto &output : outputs_list) {
124     if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
125       auto tuple_getitem = output->cast<CNodePtr>();
126       MS_EXCEPTION_IF_NULL(tuple_getitem);
127       (void)outputs_format.emplace_back(AnfAlgo::GetOutputFormat(
128         tuple_getitem->input(kIndex1), LongToSize(GetValue<int64_t>(GetValueNode(tuple_getitem->input(kIndex2))))));
129       (void)outputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(
130         tuple_getitem->input(kIndex1), LongToSize(GetValue<int64_t>(GetValueNode(tuple_getitem->input(kIndex2))))));
131     } else {
132       (void)outputs_format.emplace_back(AnfAlgo::GetOutputFormat(output, 0));
133       (void)outputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(output, 0));
134     }
135   }
136   builder.SetInputsFormat(inputs_format);
137   builder.SetInputsDeviceType(inputs_data_type);
138   builder.SetOutputsFormat(outputs_format);
139   builder.SetOutputsDeviceType(outputs_data_type);
140   builder.SetKernelType(KernelType::TBE_KERNEL);
141   return builder.Build();
142 }
143 
// Creates TupleGetItem(buffer_fusion_kernel, output_index) in `kernel_graph`.
// The index value node gets an Int64 scalar abstract, and the new node's inferred type/shape are
// copied from output `output_index` of `buffer_fusion_kernel`.
AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph,
                              size_t output_index) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto prim_node = std::make_shared<ValueNode>(prim::kPrimTupleGetItem);
  MS_EXCEPTION_IF_NULL(prim_node);

  // Index operand, annotated with a scalar abstract holding the same value.
  auto index_node = NewValueNode(SizeToLong(output_index));
  MS_EXCEPTION_IF_NULL(index_node);
  int64_t index_value = SizeToLong(output_index);
  auto index_imm = std::make_shared<Int64Imm>(index_value);
  auto index_abstract = std::make_shared<abstract::AbstractScalar>(index_imm);
  index_node->set_abstract(index_abstract);

  std::vector<AnfNodePtr> getitem_inputs{prim_node, buffer_fusion_kernel, index_node};
  auto getitem = kernel_graph->NewCNode(getitem_inputs);
  MS_EXCEPTION_IF_NULL(getitem);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(buffer_fusion_kernel, output_index)},
                                      {AnfAlgo::GetOutputInferShape(buffer_fusion_kernel, output_index)},
                                      getitem.get());
  return getitem;
}
166 
ReplaceInputNodeInOtherFusionScope(std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos,int64_t fusion_id,const AnfNodePtr & output_item,const AnfNodePtr & replace_item)167 void ReplaceInputNodeInOtherFusionScope(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos,
168                                         int64_t fusion_id, const AnfNodePtr &output_item,
169                                         const AnfNodePtr &replace_item) {
170   for (int64_t id = fusion_id + 1; id <= SizeToLong(buffer_fusion_infos->size()); ++id) {
171     auto itr = std::find((*buffer_fusion_infos)[id].inputs_list.begin(), (*buffer_fusion_infos)[id].inputs_list.end(),
172                          output_item);
173     if (itr != (*buffer_fusion_infos)[id].inputs_list.end()) {
174       MS_LOG(DEBUG) << "replace input of other pattern, id = " << id;
175       *itr = replace_item;
176     }
177   }
178 }
179 
ReplaceOldNode(std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos,int64_t fusion_id,const AnfNodePtr & buffer_fusion_kernel,session::KernelGraph * kernel_graph)180 void ReplaceOldNode(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos, int64_t fusion_id,
181                     const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) {
182   MS_EXCEPTION_IF_NULL(kernel_graph);
183   MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
184   auto manager = kernel_graph->manager();
185   MS_EXCEPTION_IF_NULL(manager);
186   auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
187   if (buffer_fusion_info.outputs_list.size() == 1) {  // single output
188     if (kernel_graph != nullptr) {
189       kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
190     }
191     (void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
192     ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0],
193                                        buffer_fusion_kernel);
194   } else {  // multiple output
195     for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) {
196       auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index);
197       if (kernel_graph != nullptr) {
198         kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[index], tuple_item);
199       }
200       (void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item);
201       ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index],
202                                          tuple_item);
203     }
204   }
205 }
206 
GetFusionScopeComputeNodeList(session::KernelGraph * kernel_graph,std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos)207 void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
208                                    std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
209   MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
210   MS_EXCEPTION_IF_NULL(kernel_graph);
211   auto nodes = TopoSort(kernel_graph->get_return());
212   for (auto &node : nodes) {
213     MS_EXCEPTION_IF_NULL(node);
214     if (!node->isa<CNode>()) {
215       continue;
216     }
217     auto cnode = node->cast<CNodePtr>();
218     if (AnfAlgo::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) {
219       auto fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kOpAttrFusionId);
220       (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode);
221     }
222   }
223 }
224 
GetFusionScopeInputNodeList(const session::KernelGraph & kernel_graph,std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos)225 void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
226                                  std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
227   MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
228   auto manager = kernel_graph.manager();
229   MS_EXCEPTION_IF_NULL(manager);
230 
231   for (auto &buffer_fusion_info : *buffer_fusion_infos) {
232     auto fusion_id = buffer_fusion_info.first;
233     const auto &fusion_info = buffer_fusion_info.second;
234     for (const auto &node : fusion_info.anf_nodes) {
235       auto cnode = node->cast<CNodePtr>();
236       MS_EXCEPTION_IF_NULL(cnode);
237       for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) {
238         auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0);
239         if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) ==
240             fusion_info.anf_nodes.end()) {
241           if (!HasAbstractMonad(cnode->input(idx))) {
242             (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx));
243           }
244         }
245       }
246     }
247   }
248 }
249 
TupleGetitemNodeCompare(const AnfNodePtr & node1,const AnfNodePtr & node2)250 bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) {
251   MS_EXCEPTION_IF_NULL(node1);
252   MS_EXCEPTION_IF_NULL(node2);
253   auto getitem1 = node1->cast<CNodePtr>();
254   auto getitem2 = node2->cast<CNodePtr>();
255   MS_EXCEPTION_IF_NULL(getitem1);
256   MS_EXCEPTION_IF_NULL(getitem2);
257   if (getitem1->size() < kTupleGetItemInputSize) {
258     MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1["
259                       << getitem1->DebugString() << "]";
260   }
261   if (getitem2->size() < kTupleGetItemInputSize) {
262     MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1["
263                       << getitem2->DebugString() << "]";
264   }
265   auto output_idx1 = GetValue<int64_t>(GetValueNode(getitem1->input(kIndex2)));
266   auto output_idx2 = GetValue<int64_t>(GetValueNode(getitem2->input(kIndex2)));
267   return output_idx1 < output_idx2;
268 }
269 
RemoveNodeFromUpdateState(session::KernelGraph * kernel_graph,const AnfNodePtr & node,const AnfNodePtr & updatestate)270 AnfNodePtr RemoveNodeFromUpdateState(session::KernelGraph *kernel_graph, const AnfNodePtr &node,
271                                      const AnfNodePtr &updatestate) {
272   MS_EXCEPTION_IF_NULL(kernel_graph);
273   MS_EXCEPTION_IF_NULL(node);
274   MS_EXCEPTION_IF_NULL(updatestate);
275   auto updatestate_cnode = updatestate->cast<CNodePtr>();
276   auto inputs = updatestate_cnode->inputs();
277   std::vector<AnfNodePtr> new_inputs;
278   (void)std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
279                      [node](const AnfNodePtr &input) { return node != input; });
280   auto new_updatestate = kernel_graph->NewCNode(new_inputs);
281   new_updatestate->set_scope(updatestate->scope());
282   new_updatestate->set_abstract(updatestate->abstract());
283   return new_updatestate;
284 }
285 
GetFusionScopeOutputNodeList(session::KernelGraph * kernel_graph,std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos)286 void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
287                                   std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
288   MS_EXCEPTION_IF_NULL(kernel_graph);
289   MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
290   auto manager = kernel_graph->manager();
291   MS_EXCEPTION_IF_NULL(manager);
292 
293   for (auto &buffer_fusion_info : *buffer_fusion_infos) {
294     auto fusion_id = buffer_fusion_info.first;
295     const auto &fusion_info = buffer_fusion_info.second;
296     for (const auto &node : fusion_info.anf_nodes) {
297       if (AnfAlgo::GetOutputTensorNum(node) == 1) {
298         auto use_nodes = manager->node_users()[node];
299         for (auto use_node : use_nodes) {
300           // Do not think of updatestate as real output,
301           // Ensuring normal fusion requires eliminating the node of the updatestate
302           if (AnfAlgo::CheckPrimitiveType(use_node.first, prim::kPrimUpdateState)) {
303             auto new_updatestate = RemoveNodeFromUpdateState(kernel_graph, node, use_node.first);
304             (void)manager->Replace(use_node.first, new_updatestate);
305             continue;
306           }
307           if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), use_node.first) ==
308               fusion_info.anf_nodes.end()) {
309             (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(node);
310             break;
311           }
312         }
313       } else {
314         int64_t prev_idx = 0;
315         std::vector<AnfNodePtr> tuple_getitem_nodes;
316         auto users = manager->node_users()[node];
317         for (auto &user : users) {
318           if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimUpdateState)) {
319             auto new_updatestate = RemoveNodeFromUpdateState(kernel_graph, node, user.first);
320             (void)manager->Replace(user.first, new_updatestate);
321             continue;
322           }
323           if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimTupleGetItem)) {
324             (void)tuple_getitem_nodes.emplace_back(user.first);
325           }
326         }
327         std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare);
328         for (auto &getitem : tuple_getitem_nodes) {
329           MS_EXCEPTION_IF_NULL(getitem);
330           auto getitem_ptr = getitem->cast<CNodePtr>();
331           MS_EXCEPTION_IF_NULL(getitem_ptr);
332           auto input2 = getitem_ptr->input(kIndex2);
333           auto output_idx = GetValue<int64_t>(GetValueNode(input2));
334           for (int64_t stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) {
335             auto stub_node = CreateTupleGetItem(node, kernel_graph, LongToSize(stub_idx));
336             (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node);
337           }
338           prev_idx = output_idx + 1;
339           for (auto &item_use_node : manager->node_users()[getitem]) {
340             if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) ==
341                 fusion_info.anf_nodes.end()) {
342               (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem);
343               break;
344             }
345           }
346         }
347       }
348     }
349   }
350 }
351 
SetOutputUsedNumAttr(const session::KernelGraph & kernel_graph,const std::unordered_map<int64_t,BufferFusionInfo_t> & buffer_fusion_infos)352 void SetOutputUsedNumAttr(const session::KernelGraph &kernel_graph,
353                           const std::unordered_map<int64_t, BufferFusionInfo_t> &buffer_fusion_infos) {
354   for (auto &fusion_info : buffer_fusion_infos) {
355     auto &fusion_nodes = fusion_info.second.anf_nodes;
356     for (auto iter = fusion_nodes.begin(); iter != fusion_nodes.end() - 1; ++iter) {
357       auto node = *iter;
358       auto output_used_num = GetNodeOutputUsedNum(kernel_graph, node);
359       AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), node);
360     }
361   }
362 }
363 
SetFusionOpRefInfos(session::KernelGraph * kernel_graph,const std::vector<AnfNodePtr> & outputs_list,const AnfNodePtr & fusion_kernel)364 void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<AnfNodePtr> &outputs_list,
365                          const AnfNodePtr &fusion_kernel) {
366   MS_EXCEPTION_IF_NULL(kernel_graph);
367   auto manager = kernel_graph->manager();
368   MS_EXCEPTION_IF_NULL(manager);
369   for (size_t idx = 0; idx < outputs_list.size(); ++idx) {
370     auto output = outputs_list[idx];
371     MS_EXCEPTION_IF_NULL(output);
372     if (output->isa<CNode>() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
373       auto real_output = AnfAlgo::VisitKernel(output, 0);
374       auto output_cnode = output->cast<CNodePtr>();
375       MS_EXCEPTION_IF_NULL(output_cnode);
376       auto input2 = output_cnode->input(kIndex2);
377       auto output_idx = GetValue<int64_t>(GetValueNode(input2));
378       session::AnfWithOutIndex out_pair(real_output.first, output_idx);
379       if (kernel_graph->IsInRefOutputMap(out_pair)) {
380         auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair);
381         session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx);
382         kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair);
383       }
384     } else {
385       session::AnfWithOutIndex out_pair(output, 0);
386       if (kernel_graph->IsInRefOutputMap(out_pair)) {
387         auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair);
388         session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx);
389         kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair);
390       }
391     }
392   }
393 }
394 
CheckCircle(const session::KernelGraph & kernel_graph,const BufferFusionInfo_t & fusion_info)395 bool CheckCircle(const session::KernelGraph &kernel_graph, const BufferFusionInfo_t &fusion_info) {
396   bool has_circle = false;
397   for (auto &inp : fusion_info.inputs_list) {
398     MS_EXCEPTION_IF_NULL(inp);
399     if (!inp->isa<CNode>() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) {
400       continue;
401     }
402 
403     if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) {
404       has_circle = true;
405       break;
406     }
407   }
408   return has_circle;
409 }
410 
RemoveCircle(const session::KernelGraph & kernel_graph,std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos)411 void RemoveCircle(const session::KernelGraph &kernel_graph,
412                   std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
413   MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
414   std::vector<int64_t> fusion_ids;
415   for (auto &[fusion_id, fusion_info] : *buffer_fusion_infos) {
416     bool has_circle = CheckCircle(kernel_graph, fusion_info);
417     if (has_circle) {
418       (void)fusion_ids.emplace_back(fusion_id);
419     }
420   }
421 
422   for (auto &fusion_id : fusion_ids) {
423     buffer_fusion_infos->erase(fusion_id);
424   }
425 }
426 }  // namespace
427 
GetBufferFusionInfo(session::KernelGraph * kernel_graph,std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos) const428 void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
429                                           std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
430   MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
431   MS_EXCEPTION_IF_NULL(kernel_graph);
432   GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
433   GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos);
434   GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);
435   // Remove the fusion infos which will produce a circle if do fusion
436   RemoveCircle(*kernel_graph, buffer_fusion_infos);
437   SetOutputUsedNumAttr(*kernel_graph, *buffer_fusion_infos);
438 
439   for (auto &buffer_fusion_info : *buffer_fusion_infos) {
440     buffer_fusion_info.second.kernel_build_info =
441       CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list);
442     // just for full_name_with_scope for every buffer_fusion_info.
443     auto fusion_node = CreateFusionOp(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list,
444                                       buffer_fusion_info.second.anf_nodes, kernel_graph);
445     MS_EXCEPTION_IF_NULL(fusion_node);
446     buffer_fusion_info.second.full_name = fusion_node->fullname_with_scope();
447   }
448 }
449 
FuseBufferFusionPattern(session::KernelGraph * kernel_graph) const450 bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const {
451   MS_EXCEPTION_IF_NULL(kernel_graph);
452   bool change = false;
453   std::unordered_map<int64_t, BufferFusionInfo_t> buffer_fusion_infos;
454   GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos);
455 
456   std::vector<mindspore::kernel::FusionScopeInfo> fusion_scope_infos;
457   std::transform(
458     buffer_fusion_infos.begin(), buffer_fusion_infos.end(), std::back_inserter(fusion_scope_infos),
459     [](const std::pair<int64_t, BufferFusionInfo_t> &buffer_fusion_info) -> mindspore::kernel::FusionScopeInfo {
460       return mindspore::kernel::FusionScopeInfo(
461         buffer_fusion_info.first, buffer_fusion_info.second.full_name, buffer_fusion_info.second.inputs_list,
462         buffer_fusion_info.second.anf_nodes, buffer_fusion_info.second.outputs_list);
463     });
464   std::map<int64_t, kernel::KernelModPtr> kernel_mods;
465   std::string old_build = common::GetEnv("MS_OLD_BUILD_PROCESS");
466   if (!old_build.empty()) {
467     kernel_mods = mindspore::kernel::KernelFusion(fusion_scope_infos);
468   } else if (!fusion_scope_infos.empty()) {
469     auto &build_manager = kernel::ascend::AscendKernelCompileManager::GetInstance();
470     kernel_mods = build_manager.AscendFusionOpCompile(fusion_scope_infos);
471     build_manager.ResetOldTask();
472   }
473   std::set<int64_t> fusion_ids;
474   for (auto &buffer_fusion_info : buffer_fusion_infos) {
475     MS_LOG(DEBUG) << "anf node size: " << buffer_fusion_info.second.anf_nodes.size()
476                   << ", inputs_list size: " << buffer_fusion_info.second.inputs_list.size()
477                   << ", outputs list size: " << buffer_fusion_info.second.outputs_list.size();
478     fusion_ids.insert(buffer_fusion_info.first);
479   }
480   // Replace fusion op from return to head
481   for (auto &fusion_id : fusion_ids) {
482     // Get kernel mod when supporting tbe
483     if (kernel_mods.find(fusion_id) == kernel_mods.end() || kernel_mods[fusion_id] == nullptr) {
484       MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed";
485       continue;
486     }
487     if (CheckCircle(*kernel_graph, buffer_fusion_infos[fusion_id])) {
488       MS_LOG(DEBUG) << "fusion id: " << fusion_id << " will cause graph circle, pass this fusion.";
489     } else {
490       change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph);
491     }
492   }
493   MS_LOG(DEBUG) << "End Buffer Fusion";
494   return change;
495 }
496 
ReplaceFusionOp(std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos,int64_t fusion_id,const kernel::KernelModPtr & kernel_ptr,session::KernelGraph * kernel_graph) const497 bool UbPatternFusion::ReplaceFusionOp(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos,
498                                       int64_t fusion_id, const kernel::KernelModPtr &kernel_ptr,
499                                       session::KernelGraph *kernel_graph) const {
500   MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
501   auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
502   if (buffer_fusion_info.anf_nodes.size() < kFusionNodeNumThreshold) {
503     return false;
504   }
505   TraceGuard guard(std::make_shared<TraceOpt>(buffer_fusion_info.anf_nodes[0]->debug_info()));
506   auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list,
507                                       buffer_fusion_info.anf_nodes, kernel_graph);
508   buffer_fusion->set_fullname_with_scope(buffer_fusion_info.full_name);
509   AnfAlgo::SetSelectKernelBuildInfo(buffer_fusion_info.kernel_build_info, buffer_fusion.get());
510   // Set abstract of fusion_op node
511   std::vector<TypeId> types;
512   std::vector<std::vector<size_t>> shapes;
513   for (const auto &out_node : buffer_fusion_info.outputs_list) {
514     size_t out_num = AnfAlgo::GetOutputTensorNum(out_node);
515     for (size_t idx = 0; idx < out_num; ++idx) {
516       (void)types.emplace_back(AnfAlgo::GetOutputInferDataType(out_node, idx));
517       (void)shapes.emplace_back(AnfAlgo::GetOutputInferShape(out_node, idx));
518     }
519   }
520   if (types.empty() || shapes.empty()) {
521     MS_LOG(WARNING) << "buffer_fusion_info.outputs_list is empty";
522     return false;
523   }
524   AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get());
525   AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get());
526   SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion);
527   ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph);
528   return true;
529 }
530 
Run(const FuncGraphPtr & graph)531 bool UbPatternFusion::Run(const FuncGraphPtr &graph) {
532   bool changed = false;
533   MS_EXCEPTION_IF_NULL(graph);
534   auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
535   MS_EXCEPTION_IF_NULL(kernel_graph);
536   changed = FuseBufferFusionPattern(kernel_graph.get());
537   // clear fusion_id attr
538   for (auto &node : graph->nodes()) {
539     if (node != nullptr && node->isa<CNode>()) {
540       AnfAlgo::EraseNodeAttr(kAttrFusionId, node);
541     }
542   }
543   return changed;
544 }
545 }  // namespace opt
546 }  // namespace mindspore
547