1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/ascend/ir_fission/transdata_split.h"
17 #include "backend/optimizer/ascend/ascend_helper.h"
18 #include "backend/session/anf_runtime_algorithm.h"
19 #include "debug/anf_ir_dump.h"
20 #include "utils/trace_base.h"
21
22 namespace mindspore {
23 namespace opt {
24 const std::set<std::pair<string, string>> invalid_formats_pair = {
25 {kOpFormat_C1HWNCoC0, kOpFormat_NCHW}, {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
26 {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM},
27 {kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
28
// Entry point of the pass for one matched node.
//
// Returns the replacement node produced by DoSplit() when `node` is a TransData CNode whose
// (input format, output format) pair is listed in invalid_formats_pair. Returns nullptr in every other
// case: null node, non-CNode, a different operator, or a format pair that is not in the set.
// Throws (via MS_EXCEPTION_IF_NULL) when `func_graph` is null.
const AnfNodePtr TransDataSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);

  // Only CNodes can be TransData operators; everything else is ignored.
  if (node == nullptr || !node->isa<CNode>()) {
    return nullptr;
  }
  if (AnfAlgo::GetCNodeName(node) != kTransDataOpName) {
    return nullptr;
  }

  // Validate the input count before inspecting formats.
  CheckCNodeInputSize(node->cast<CNodePtr>(), kTransOpInputTensorNum);
  if (!IsFormatInvaild(node)) {
    return nullptr;
  }

  // The guard stays alive for the duration of the DoSplit() call.
  TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
  return DoSplit(func_graph, node);
}
41
IsFormatInvaild(const AnfNodePtr & node) const42 bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) const {
43 MS_EXCEPTION_IF_NULL(node);
44 auto cnode = node->cast<CNodePtr>();
45 MS_EXCEPTION_IF_NULL(cnode);
46 auto input_format = AnfAlgo::GetInputFormat(node, 0);
47 auto output_format = AnfAlgo::GetOutputFormat(node, 0);
48 auto format_pair = std::make_pair(input_format, output_format);
49
50 return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end();
51 }
52
DefinePattern() const53 const BaseRef TransDataSplit::DefinePattern() const {
54 VarPtr X = std::make_shared<Var>();
55 return VectorRef({prim::kPrimTransData, X});
56 }
57
58 // transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW)
DoSplit(const FuncGraphPtr & func_graph,const AnfNodePtr & node) const59 CNodePtr TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
60 MS_EXCEPTION_IF_NULL(func_graph);
61 MS_EXCEPTION_IF_NULL(node);
62 auto cnode = node->cast<CNodePtr>();
63 MS_EXCEPTION_IF_NULL(cnode);
64 auto input_node = cnode->input(kIndex1);
65 MS_EXCEPTION_IF_NULL(input_node);
66
67 auto input_format = AnfAlgo::GetInputFormat(node, 0);
68 auto output_format = AnfAlgo::GetOutputFormat(node, 0);
69 CNodePtr new_transdata_node = nullptr;
70 CNodePtr new_transpose_node = nullptr;
71 CNodePtr new_replace_node = nullptr;
72 auto padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
73 // if output_format=default transdata need split transdata->transpose else transpose->transdata
74 if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
75 // trans input_format to hwcn
76 new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
77 false, prim::kPrimTransData->name());
78 RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node, padding_axis);
79 // trans hwcn to default_format
80 new_transpose_node = NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false,
81 prim::kPrimTranspose->name(), std::vector<int64_t>{3, 2, 0, 1});
82 RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node);
83 new_replace_node = new_transpose_node;
84 } else {
85 // trans default to hwcn
86 new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
87 false, prim::kPrimTranspose->name(), std::vector<int64_t>{2, 3, 1, 0});
88 if (output_format == kOpFormat_FRACTAL_ZN_LSTM) {
89 AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node);
90 }
91 RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node);
92
93 // trans hwcn to output_format
94 new_transdata_node =
95 NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::kPrimTransData->name());
96 RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node, padding_axis);
97 new_transdata_node->set_abstract(node->abstract());
98 new_replace_node = new_transdata_node;
99 }
100 MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success.";
101 return new_replace_node;
102 }
103 } // namespace opt
104 } // namespace mindspore
105