• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/ascend/ir_fission/transdata_split.h"
17 #include "backend/optimizer/ascend/ascend_helper.h"
18 #include "backend/session/anf_runtime_algorithm.h"
19 #include "debug/anf_ir_dump.h"
20 #include "utils/trace_base.h"
21 
22 namespace mindspore {
23 namespace opt {
24 const std::set<std::pair<string, string>> invalid_formats_pair = {
25   {kOpFormat_C1HWNCoC0, kOpFormat_NCHW},          {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
26   {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},       {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM},
27   {kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
28 
// Pattern-pass entry point.
// Returns the split replacement produced by DoSplit when `node` is a TransData CNode whose
// (input, output) format pair is listed in invalid_formats_pair; otherwise returns nullptr.
const AnfNodePtr TransDataSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Only TransData CNodes are candidates; the checks short-circuit in this order.
  const bool is_transdata_cnode =
    node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName;
  if (!is_transdata_cnode) {
    return nullptr;
  }
  CheckCNodeInputSize(node->cast<CNodePtr>(), kTransOpInputTensorNum);
  if (!IsFormatInvaild(node)) {
    return nullptr;
  }
  // Attach the original node's debug info to the nodes created while splitting.
  TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
  return DoSplit(func_graph, node);
}
41 
IsFormatInvaild(const AnfNodePtr & node) const42 bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) const {
43   MS_EXCEPTION_IF_NULL(node);
44   auto cnode = node->cast<CNodePtr>();
45   MS_EXCEPTION_IF_NULL(cnode);
46   auto input_format = AnfAlgo::GetInputFormat(node, 0);
47   auto output_format = AnfAlgo::GetOutputFormat(node, 0);
48   auto format_pair = std::make_pair(input_format, output_format);
49 
50   return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end();
51 }
52 
DefinePattern() const53 const BaseRef TransDataSplit::DefinePattern() const {
54   VarPtr X = std::make_shared<Var>();
55   return VectorRef({prim::kPrimTransData, X});
56 }
57 
58 // transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW)
DoSplit(const FuncGraphPtr & func_graph,const AnfNodePtr & node) const59 CNodePtr TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
60   MS_EXCEPTION_IF_NULL(func_graph);
61   MS_EXCEPTION_IF_NULL(node);
62   auto cnode = node->cast<CNodePtr>();
63   MS_EXCEPTION_IF_NULL(cnode);
64   auto input_node = cnode->input(kIndex1);
65   MS_EXCEPTION_IF_NULL(input_node);
66 
67   auto input_format = AnfAlgo::GetInputFormat(node, 0);
68   auto output_format = AnfAlgo::GetOutputFormat(node, 0);
69   CNodePtr new_transdata_node = nullptr;
70   CNodePtr new_transpose_node = nullptr;
71   CNodePtr new_replace_node = nullptr;
72   auto padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
73   // if output_format=default transdata need split transdata->transpose else transpose->transdata
74   if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
75     // trans input_format to hwcn
76     new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
77                                         false, prim::kPrimTransData->name());
78     RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node, padding_axis);
79     // trans hwcn to default_format
80     new_transpose_node = NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false,
81                                         prim::kPrimTranspose->name(), std::vector<int64_t>{3, 2, 0, 1});
82     RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node);
83     new_replace_node = new_transpose_node;
84   } else {
85     // trans default to hwcn
86     new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
87                                         false, prim::kPrimTranspose->name(), std::vector<int64_t>{2, 3, 1, 0});
88     if (output_format == kOpFormat_FRACTAL_ZN_LSTM) {
89       AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node);
90     }
91     RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node);
92 
93     // trans hwcn to output_format
94     new_transdata_node =
95       NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::kPrimTransData->name());
96     RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node, padding_axis);
97     new_transdata_node->set_abstract(node->abstract());
98     new_replace_node = new_transdata_node;
99   }
100   MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success.";
101   return new_replace_node;
102 }
103 }  // namespace opt
104 }  // namespace mindspore
105