1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h"
17 #include <vector>
18 #include <utility>
19 #include <unordered_map>
20 #include <deque>
21 #include <memory>
22 #include <string>
23 #include <algorithm>
24 #include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
25 #include "backend/kernel_compiler/tbe/ascend_kernel_compile.h"
26 #include "backend/kernel_compiler/kernel_fusion.h"
27 #include "debug/anf_ir_dump.h"
28 #include "backend/session/anf_runtime_algorithm.h"
29 #include "base/core_ops.h"
30 #include "runtime/device/kernel_info.h"
31 #include "utils/ms_context.h"
32 #include "backend/optimizer/common/helper.h"
33
34 namespace mindspore {
35 namespace opt {
36 namespace {
// Pattern / eltwise limits (meaning inferred from their names only).
// NOTE(review): none of the int8_t limits below are referenced in this file.
constexpr int8_t MAX_PATTERN_SIZE = 7;
constexpr int8_t MIN_PATTERN_SIZE = 2;
constexpr int8_t ELTWISE_INPUT_SIZE = 2;
constexpr int8_t ELTWISE_USE = 1;
constexpr int8_t MULTI_ELTWISE_USE = 2;
constexpr int8_t MAX_MULTI_ELTWISE_SIZE = 4;
constexpr int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3;
// A fusion scope needs at least this many compute nodes before ReplaceFusionOp fuses it.
constexpr size_t kFusionNodeNumThreshold = 2;
// Node attribute carrying the fusion-scope id a compute kernel belongs to.
constexpr auto kOpAttrFusionId = "fusion_id";
46
#ifdef DEBUG
// Logs the id, input nodes, output nodes and compute nodes (with fusion type name)
// of a fusion scope at INFO level. Only compiled when DEBUG is defined.
// The namespace is `mindspore::kernel`; the former `kekernel` spelling did not compile.
// NOTE(review): confirm `kernel::tbe::GetFusionTypeName` is the helper exported by
// tbe_convert_utils.h in this version.
void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) {
  MS_LOG(INFO) << "=== Dump FusionScopeInfo start id: " << info.scope_id;
  for (const auto &node : info.input_nodes) {
    MS_LOG(INFO) << "=== Input: " << node->DebugString();
  }
  for (const auto &node : info.output_nodes) {
    MS_LOG(INFO) << "=== Output: " << node->DebugString();
  }
  for (const auto &node : info.compute_nodes) {
    MS_LOG(INFO) << "=== Compute: (" << node->DebugString() << ")-("
                 << mindspore::kernel::tbe::GetFusionTypeName(AnfAlgo::GetFusionType(node)) << ")";
  }
  MS_LOG(INFO) << "=== Dump FusionScopeInfo end";
}
#endif
CreateFusionOp(const std::vector<AnfNodePtr> & inputs_list,const std::vector<AnfNodePtr> & outputs_list,const std::vector<AnfNodePtr> & anf_nodes,session::KernelGraph * kernel_graph)63 CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::vector<AnfNodePtr> &outputs_list,
64 const std::vector<AnfNodePtr> &anf_nodes, session::KernelGraph *kernel_graph) {
65 MS_LOG(DEBUG) << "Start Create FusionOp Kernel";
66 MS_EXCEPTION_IF_NULL(kernel_graph);
67 std::string fusion_op_name = "FusionOp";
68 for (auto &node : anf_nodes) {
69 fusion_op_name += '_' + AnfAlgo::GetCNodeName(node);
70 }
71 auto fusion_op = std::make_shared<Primitive>(fusion_op_name);
72 MS_EXCEPTION_IF_NULL(fusion_op);
73
74 std::vector<std::string> input_names;
75 for (size_t i = 0; i < inputs_list.size(); i++) {
76 (void)input_names.emplace_back("input" + std::to_string(i));
77 }
78 std::vector<std::string> output_names;
79 for (size_t i = 0; i < outputs_list.size(); i++) {
80 (void)output_names.emplace_back("output" + std::to_string(i));
81 }
82
83 ValuePtr input_names_v = MakeValue(input_names);
84 ValuePtr output_names_v = MakeValue(output_names);
85 fusion_op->set_attr("input_names", input_names_v);
86 fusion_op->set_attr("output_names", output_names_v);
87 for (auto &node : anf_nodes) {
88 MS_EXCEPTION_IF_NULL(node);
89 auto cnode = node->cast<CNodePtr>();
90 if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) {
91 auto fracz_group = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFracZGroup);
92 fusion_op->set_attr(kAttrFracZGroup, MakeValue(fracz_group));
93 break;
94 }
95 }
96 std::vector<AnfNodePtr> fusion_inputs_list = inputs_list;
97 auto value_node = std::make_shared<ValueNode>(fusion_op);
98 (void)fusion_inputs_list.insert(fusion_inputs_list.begin(), value_node);
99 auto buffer_fusion_kernel = kernel_graph->NewCNode(fusion_inputs_list);
100 if (buffer_fusion_kernel == nullptr) {
101 MS_LOG(EXCEPTION) << "New FusionOp kernel failed!";
102 }
103 buffer_fusion_kernel->set_scope((anf_nodes.back())->scope());
104
105 return buffer_fusion_kernel;
106 }
107
CreateFusionOpKernelInfo(const std::vector<AnfNodePtr> & inputs_list,const std::vector<AnfNodePtr> & outputs_list)108 kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr> &inputs_list,
109 const std::vector<AnfNodePtr> &outputs_list) {
110 MS_LOG(DEBUG) << "Start Create Kernel Info";
111 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
112 // inputs format and data type
113 std::vector<std::string> inputs_format;
114 std::vector<TypeId> inputs_data_type;
115 for (const auto &input : inputs_list) {
116 auto real_input = AnfAlgo::VisitKernel(input, 0);
117 (void)inputs_format.emplace_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second));
118 (void)inputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second));
119 }
120 // outputs format and data type
121 std::vector<std::string> outputs_format;
122 std::vector<TypeId> outputs_data_type;
123 for (const auto &output : outputs_list) {
124 if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
125 auto tuple_getitem = output->cast<CNodePtr>();
126 MS_EXCEPTION_IF_NULL(tuple_getitem);
127 (void)outputs_format.emplace_back(AnfAlgo::GetOutputFormat(
128 tuple_getitem->input(kIndex1), LongToSize(GetValue<int64_t>(GetValueNode(tuple_getitem->input(kIndex2))))));
129 (void)outputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(
130 tuple_getitem->input(kIndex1), LongToSize(GetValue<int64_t>(GetValueNode(tuple_getitem->input(kIndex2))))));
131 } else {
132 (void)outputs_format.emplace_back(AnfAlgo::GetOutputFormat(output, 0));
133 (void)outputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(output, 0));
134 }
135 }
136 builder.SetInputsFormat(inputs_format);
137 builder.SetInputsDeviceType(inputs_data_type);
138 builder.SetOutputsFormat(outputs_format);
139 builder.SetOutputsDeviceType(outputs_data_type);
140 builder.SetKernelType(KernelType::TBE_KERNEL);
141 return builder.Build();
142 }
143
// Creates `TupleGetItem(buffer_fusion_kernel, output_index)` in kernel_graph.
// The index value node gets an Int64 scalar abstract, and the new node inherits the
// inferred type and shape of output `output_index` of buffer_fusion_kernel.
AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph,
                              size_t output_index) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto value = std::make_shared<ValueNode>(prim::kPrimTupleGetItem);
  MS_EXCEPTION_IF_NULL(value);
  const int64_t index = SizeToLong(output_index);
  auto idx = NewValueNode(index);
  MS_EXCEPTION_IF_NULL(idx);
  auto imm = std::make_shared<Int64Imm>(index);
  idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));

  std::vector<AnfNodePtr> tuple_getitem_inputs_list{value, buffer_fusion_kernel, idx};
  auto tuple_item = kernel_graph->NewCNode(tuple_getitem_inputs_list);
  MS_EXCEPTION_IF_NULL(tuple_item);
  const auto infer_type = AnfAlgo::GetOutputInferDataType(buffer_fusion_kernel, output_index);
  const auto infer_shape = AnfAlgo::GetOutputInferShape(buffer_fusion_kernel, output_index);
  AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {infer_shape}, tuple_item.get());
  return tuple_item;
}
166
ReplaceInputNodeInOtherFusionScope(std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos,int64_t fusion_id,const AnfNodePtr & output_item,const AnfNodePtr & replace_item)167 void ReplaceInputNodeInOtherFusionScope(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos,
168 int64_t fusion_id, const AnfNodePtr &output_item,
169 const AnfNodePtr &replace_item) {
170 for (int64_t id = fusion_id + 1; id <= SizeToLong(buffer_fusion_infos->size()); ++id) {
171 auto itr = std::find((*buffer_fusion_infos)[id].inputs_list.begin(), (*buffer_fusion_infos)[id].inputs_list.end(),
172 output_item);
173 if (itr != (*buffer_fusion_infos)[id].inputs_list.end()) {
174 MS_LOG(DEBUG) << "replace input of other pattern, id = " << id;
175 *itr = replace_item;
176 }
177 }
178 }
179
ReplaceOldNode(std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos,int64_t fusion_id,const AnfNodePtr & buffer_fusion_kernel,session::KernelGraph * kernel_graph)180 void ReplaceOldNode(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos, int64_t fusion_id,
181 const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) {
182 MS_EXCEPTION_IF_NULL(kernel_graph);
183 MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
184 auto manager = kernel_graph->manager();
185 MS_EXCEPTION_IF_NULL(manager);
186 auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
187 if (buffer_fusion_info.outputs_list.size() == 1) { // single output
188 if (kernel_graph != nullptr) {
189 kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
190 }
191 (void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
192 ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0],
193 buffer_fusion_kernel);
194 } else { // multiple output
195 for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) {
196 auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index);
197 if (kernel_graph != nullptr) {
198 kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[index], tuple_item);
199 }
200 (void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item);
201 ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index],
202 tuple_item);
203 }
204 }
205 }
206
GetFusionScopeComputeNodeList(session::KernelGraph * kernel_graph,std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos)207 void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
208 std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
209 MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
210 MS_EXCEPTION_IF_NULL(kernel_graph);
211 auto nodes = TopoSort(kernel_graph->get_return());
212 for (auto &node : nodes) {
213 MS_EXCEPTION_IF_NULL(node);
214 if (!node->isa<CNode>()) {
215 continue;
216 }
217 auto cnode = node->cast<CNodePtr>();
218 if (AnfAlgo::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) {
219 auto fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kOpAttrFusionId);
220 (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode);
221 }
222 }
223 }
224
GetFusionScopeInputNodeList(const session::KernelGraph & kernel_graph,std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos)225 void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
226 std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
227 MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
228 auto manager = kernel_graph.manager();
229 MS_EXCEPTION_IF_NULL(manager);
230
231 for (auto &buffer_fusion_info : *buffer_fusion_infos) {
232 auto fusion_id = buffer_fusion_info.first;
233 const auto &fusion_info = buffer_fusion_info.second;
234 for (const auto &node : fusion_info.anf_nodes) {
235 auto cnode = node->cast<CNodePtr>();
236 MS_EXCEPTION_IF_NULL(cnode);
237 for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) {
238 auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0);
239 if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) ==
240 fusion_info.anf_nodes.end()) {
241 if (!HasAbstractMonad(cnode->input(idx))) {
242 (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx));
243 }
244 }
245 }
246 }
247 }
248 }
249
TupleGetitemNodeCompare(const AnfNodePtr & node1,const AnfNodePtr & node2)250 bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) {
251 MS_EXCEPTION_IF_NULL(node1);
252 MS_EXCEPTION_IF_NULL(node2);
253 auto getitem1 = node1->cast<CNodePtr>();
254 auto getitem2 = node2->cast<CNodePtr>();
255 MS_EXCEPTION_IF_NULL(getitem1);
256 MS_EXCEPTION_IF_NULL(getitem2);
257 if (getitem1->size() < kTupleGetItemInputSize) {
258 MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1["
259 << getitem1->DebugString() << "]";
260 }
261 if (getitem2->size() < kTupleGetItemInputSize) {
262 MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1["
263 << getitem2->DebugString() << "]";
264 }
265 auto output_idx1 = GetValue<int64_t>(GetValueNode(getitem1->input(kIndex2)));
266 auto output_idx2 = GetValue<int64_t>(GetValueNode(getitem2->input(kIndex2)));
267 return output_idx1 < output_idx2;
268 }
269
RemoveNodeFromUpdateState(session::KernelGraph * kernel_graph,const AnfNodePtr & node,const AnfNodePtr & updatestate)270 AnfNodePtr RemoveNodeFromUpdateState(session::KernelGraph *kernel_graph, const AnfNodePtr &node,
271 const AnfNodePtr &updatestate) {
272 MS_EXCEPTION_IF_NULL(kernel_graph);
273 MS_EXCEPTION_IF_NULL(node);
274 MS_EXCEPTION_IF_NULL(updatestate);
275 auto updatestate_cnode = updatestate->cast<CNodePtr>();
276 auto inputs = updatestate_cnode->inputs();
277 std::vector<AnfNodePtr> new_inputs;
278 (void)std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
279 [node](const AnfNodePtr &input) { return node != input; });
280 auto new_updatestate = kernel_graph->NewCNode(new_inputs);
281 new_updatestate->set_scope(updatestate->scope());
282 new_updatestate->set_abstract(updatestate->abstract());
283 return new_updatestate;
284 }
285
GetFusionScopeOutputNodeList(session::KernelGraph * kernel_graph,std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos)286 void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
287 std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
288 MS_EXCEPTION_IF_NULL(kernel_graph);
289 MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
290 auto manager = kernel_graph->manager();
291 MS_EXCEPTION_IF_NULL(manager);
292
293 for (auto &buffer_fusion_info : *buffer_fusion_infos) {
294 auto fusion_id = buffer_fusion_info.first;
295 const auto &fusion_info = buffer_fusion_info.second;
296 for (const auto &node : fusion_info.anf_nodes) {
297 if (AnfAlgo::GetOutputTensorNum(node) == 1) {
298 auto use_nodes = manager->node_users()[node];
299 for (auto use_node : use_nodes) {
300 // Do not think of updatestate as real output,
301 // Ensuring normal fusion requires eliminating the node of the updatestate
302 if (AnfAlgo::CheckPrimitiveType(use_node.first, prim::kPrimUpdateState)) {
303 auto new_updatestate = RemoveNodeFromUpdateState(kernel_graph, node, use_node.first);
304 (void)manager->Replace(use_node.first, new_updatestate);
305 continue;
306 }
307 if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), use_node.first) ==
308 fusion_info.anf_nodes.end()) {
309 (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(node);
310 break;
311 }
312 }
313 } else {
314 int64_t prev_idx = 0;
315 std::vector<AnfNodePtr> tuple_getitem_nodes;
316 auto users = manager->node_users()[node];
317 for (auto &user : users) {
318 if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimUpdateState)) {
319 auto new_updatestate = RemoveNodeFromUpdateState(kernel_graph, node, user.first);
320 (void)manager->Replace(user.first, new_updatestate);
321 continue;
322 }
323 if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimTupleGetItem)) {
324 (void)tuple_getitem_nodes.emplace_back(user.first);
325 }
326 }
327 std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare);
328 for (auto &getitem : tuple_getitem_nodes) {
329 MS_EXCEPTION_IF_NULL(getitem);
330 auto getitem_ptr = getitem->cast<CNodePtr>();
331 MS_EXCEPTION_IF_NULL(getitem_ptr);
332 auto input2 = getitem_ptr->input(kIndex2);
333 auto output_idx = GetValue<int64_t>(GetValueNode(input2));
334 for (int64_t stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) {
335 auto stub_node = CreateTupleGetItem(node, kernel_graph, LongToSize(stub_idx));
336 (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node);
337 }
338 prev_idx = output_idx + 1;
339 for (auto &item_use_node : manager->node_users()[getitem]) {
340 if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) ==
341 fusion_info.anf_nodes.end()) {
342 (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem);
343 break;
344 }
345 }
346 }
347 }
348 }
349 }
350 }
351
SetOutputUsedNumAttr(const session::KernelGraph & kernel_graph,const std::unordered_map<int64_t,BufferFusionInfo_t> & buffer_fusion_infos)352 void SetOutputUsedNumAttr(const session::KernelGraph &kernel_graph,
353 const std::unordered_map<int64_t, BufferFusionInfo_t> &buffer_fusion_infos) {
354 for (auto &fusion_info : buffer_fusion_infos) {
355 auto &fusion_nodes = fusion_info.second.anf_nodes;
356 for (auto iter = fusion_nodes.begin(); iter != fusion_nodes.end() - 1; ++iter) {
357 auto node = *iter;
358 auto output_used_num = GetNodeOutputUsedNum(kernel_graph, node);
359 AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), node);
360 }
361 }
362 }
363
SetFusionOpRefInfos(session::KernelGraph * kernel_graph,const std::vector<AnfNodePtr> & outputs_list,const AnfNodePtr & fusion_kernel)364 void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<AnfNodePtr> &outputs_list,
365 const AnfNodePtr &fusion_kernel) {
366 MS_EXCEPTION_IF_NULL(kernel_graph);
367 auto manager = kernel_graph->manager();
368 MS_EXCEPTION_IF_NULL(manager);
369 for (size_t idx = 0; idx < outputs_list.size(); ++idx) {
370 auto output = outputs_list[idx];
371 MS_EXCEPTION_IF_NULL(output);
372 if (output->isa<CNode>() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
373 auto real_output = AnfAlgo::VisitKernel(output, 0);
374 auto output_cnode = output->cast<CNodePtr>();
375 MS_EXCEPTION_IF_NULL(output_cnode);
376 auto input2 = output_cnode->input(kIndex2);
377 auto output_idx = GetValue<int64_t>(GetValueNode(input2));
378 session::AnfWithOutIndex out_pair(real_output.first, output_idx);
379 if (kernel_graph->IsInRefOutputMap(out_pair)) {
380 auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair);
381 session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx);
382 kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair);
383 }
384 } else {
385 session::AnfWithOutIndex out_pair(output, 0);
386 if (kernel_graph->IsInRefOutputMap(out_pair)) {
387 auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair);
388 session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx);
389 kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair);
390 }
391 }
392 }
393 }
394
CheckCircle(const session::KernelGraph & kernel_graph,const BufferFusionInfo_t & fusion_info)395 bool CheckCircle(const session::KernelGraph &kernel_graph, const BufferFusionInfo_t &fusion_info) {
396 bool has_circle = false;
397 for (auto &inp : fusion_info.inputs_list) {
398 MS_EXCEPTION_IF_NULL(inp);
399 if (!inp->isa<CNode>() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) {
400 continue;
401 }
402
403 if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) {
404 has_circle = true;
405 break;
406 }
407 }
408 return has_circle;
409 }
410
RemoveCircle(const session::KernelGraph & kernel_graph,std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos)411 void RemoveCircle(const session::KernelGraph &kernel_graph,
412 std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
413 MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
414 std::vector<int64_t> fusion_ids;
415 for (auto &[fusion_id, fusion_info] : *buffer_fusion_infos) {
416 bool has_circle = CheckCircle(kernel_graph, fusion_info);
417 if (has_circle) {
418 (void)fusion_ids.emplace_back(fusion_id);
419 }
420 }
421
422 for (auto &fusion_id : fusion_ids) {
423 buffer_fusion_infos->erase(fusion_id);
424 }
425 }
426 } // namespace
427
GetBufferFusionInfo(session::KernelGraph * kernel_graph,std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos) const428 void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
429 std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
430 MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
431 MS_EXCEPTION_IF_NULL(kernel_graph);
432 GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
433 GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos);
434 GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);
435 // Remove the fusion infos which will produce a circle if do fusion
436 RemoveCircle(*kernel_graph, buffer_fusion_infos);
437 SetOutputUsedNumAttr(*kernel_graph, *buffer_fusion_infos);
438
439 for (auto &buffer_fusion_info : *buffer_fusion_infos) {
440 buffer_fusion_info.second.kernel_build_info =
441 CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list);
442 // just for full_name_with_scope for every buffer_fusion_info.
443 auto fusion_node = CreateFusionOp(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list,
444 buffer_fusion_info.second.anf_nodes, kernel_graph);
445 MS_EXCEPTION_IF_NULL(fusion_node);
446 buffer_fusion_info.second.full_name = fusion_node->fullname_with_scope();
447 }
448 }
449
FuseBufferFusionPattern(session::KernelGraph * kernel_graph) const450 bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const {
451 MS_EXCEPTION_IF_NULL(kernel_graph);
452 bool change = false;
453 std::unordered_map<int64_t, BufferFusionInfo_t> buffer_fusion_infos;
454 GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos);
455
456 std::vector<mindspore::kernel::FusionScopeInfo> fusion_scope_infos;
457 std::transform(
458 buffer_fusion_infos.begin(), buffer_fusion_infos.end(), std::back_inserter(fusion_scope_infos),
459 [](const std::pair<int64_t, BufferFusionInfo_t> &buffer_fusion_info) -> mindspore::kernel::FusionScopeInfo {
460 return mindspore::kernel::FusionScopeInfo(
461 buffer_fusion_info.first, buffer_fusion_info.second.full_name, buffer_fusion_info.second.inputs_list,
462 buffer_fusion_info.second.anf_nodes, buffer_fusion_info.second.outputs_list);
463 });
464 std::map<int64_t, kernel::KernelModPtr> kernel_mods;
465 std::string old_build = common::GetEnv("MS_OLD_BUILD_PROCESS");
466 if (!old_build.empty()) {
467 kernel_mods = mindspore::kernel::KernelFusion(fusion_scope_infos);
468 } else if (!fusion_scope_infos.empty()) {
469 auto &build_manager = kernel::ascend::AscendKernelCompileManager::GetInstance();
470 kernel_mods = build_manager.AscendFusionOpCompile(fusion_scope_infos);
471 build_manager.ResetOldTask();
472 }
473 std::set<int64_t> fusion_ids;
474 for (auto &buffer_fusion_info : buffer_fusion_infos) {
475 MS_LOG(DEBUG) << "anf node size: " << buffer_fusion_info.second.anf_nodes.size()
476 << ", inputs_list size: " << buffer_fusion_info.second.inputs_list.size()
477 << ", outputs list size: " << buffer_fusion_info.second.outputs_list.size();
478 fusion_ids.insert(buffer_fusion_info.first);
479 }
480 // Replace fusion op from return to head
481 for (auto &fusion_id : fusion_ids) {
482 // Get kernel mod when supporting tbe
483 if (kernel_mods.find(fusion_id) == kernel_mods.end() || kernel_mods[fusion_id] == nullptr) {
484 MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed";
485 continue;
486 }
487 if (CheckCircle(*kernel_graph, buffer_fusion_infos[fusion_id])) {
488 MS_LOG(DEBUG) << "fusion id: " << fusion_id << " will cause graph circle, pass this fusion.";
489 } else {
490 change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph);
491 }
492 }
493 MS_LOG(DEBUG) << "End Buffer Fusion";
494 return change;
495 }
496
ReplaceFusionOp(std::unordered_map<int64_t,BufferFusionInfo_t> * buffer_fusion_infos,int64_t fusion_id,const kernel::KernelModPtr & kernel_ptr,session::KernelGraph * kernel_graph) const497 bool UbPatternFusion::ReplaceFusionOp(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos,
498 int64_t fusion_id, const kernel::KernelModPtr &kernel_ptr,
499 session::KernelGraph *kernel_graph) const {
500 MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
501 auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
502 if (buffer_fusion_info.anf_nodes.size() < kFusionNodeNumThreshold) {
503 return false;
504 }
505 TraceGuard guard(std::make_shared<TraceOpt>(buffer_fusion_info.anf_nodes[0]->debug_info()));
506 auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list,
507 buffer_fusion_info.anf_nodes, kernel_graph);
508 buffer_fusion->set_fullname_with_scope(buffer_fusion_info.full_name);
509 AnfAlgo::SetSelectKernelBuildInfo(buffer_fusion_info.kernel_build_info, buffer_fusion.get());
510 // Set abstract of fusion_op node
511 std::vector<TypeId> types;
512 std::vector<std::vector<size_t>> shapes;
513 for (const auto &out_node : buffer_fusion_info.outputs_list) {
514 size_t out_num = AnfAlgo::GetOutputTensorNum(out_node);
515 for (size_t idx = 0; idx < out_num; ++idx) {
516 (void)types.emplace_back(AnfAlgo::GetOutputInferDataType(out_node, idx));
517 (void)shapes.emplace_back(AnfAlgo::GetOutputInferShape(out_node, idx));
518 }
519 }
520 if (types.empty() || shapes.empty()) {
521 MS_LOG(WARNING) << "buffer_fusion_info.outputs_list is empty";
522 return false;
523 }
524 AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get());
525 AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get());
526 SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion);
527 ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph);
528 return true;
529 }
530
Run(const FuncGraphPtr & graph)531 bool UbPatternFusion::Run(const FuncGraphPtr &graph) {
532 bool changed = false;
533 MS_EXCEPTION_IF_NULL(graph);
534 auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
535 MS_EXCEPTION_IF_NULL(kernel_graph);
536 changed = FuseBufferFusionPattern(kernel_graph.get());
537 // clear fusion_id attr
538 for (auto &node : graph->nodes()) {
539 if (node != nullptr && node->isa<CNode>()) {
540 AnfAlgo::EraseNodeAttr(kAttrFusionId, node);
541 }
542 }
543 return changed;
544 }
545 } // namespace opt
546 } // namespace mindspore
547