/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/common/graph_kernel/core/parallel_matmul_concatenate.h"
#include "base/base.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"

namespace mindspore::graphkernel {
namespace {
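// Reads the transpose_a / transpose_b attributes from a MatMul node.
// Falls back to (false, false) with a warning when either attribute is missing.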
MMAttr GetMatMulTransposeAttr(const CNodePtr &matmul) {
  auto mm_attrs = common::AnfAlgo::GetCNodePrimitive(matmul)->attrs();
  if (mm_attrs.count(kTransposeA) == 0 || mm_attrs.count(kTransposeB) == 0) {
    MS_LOG(WARNING) << "Can not find attr 'transpose_a' or 'transpose_b' in node " << matmul->fullname_with_scope();
    return std::make_pair(false, false);
  }
  auto trans_a = GetValue<bool>(mm_attrs[kTransposeA]);
  auto trans_b = GetValue<bool>(mm_attrs[kTransposeB]);
  return std::make_pair(trans_a, trans_b);
}

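// Builds the combined MatMul CNode: copies the output format of the first data input (when
// available) and the transpose attributes of the original MatMul, then sets the inferred
// dtype and the enlarged output shape and attaches a fresh KernelInfo.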
CNodePtr NewMatMulNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs, const CNodePtr &orig_matmul,
                       ShapeVector new_out_shape) {
  auto matmul = func_graph->NewCNode(matmul_inputs);
  MS_EXCEPTION_IF_NULL(matmul);
  func_graph->AddNode(matmul);
  MS_EXCEPTION_IF_NULL(matmul_inputs[1]);
  auto orig_cnode = matmul_inputs[1]->cast<CNodePtr>();
  if (orig_cnode != nullptr && orig_cnode->HasAttr(kOutputsFormat)) {
    auto input_format = GetValue<std::vector<std::string>>(orig_cnode->GetAttr(kOutputsFormat))[0];
    std::vector<std::string> outputs_formats(AnfUtils::GetOutputTensorNum(matmul), input_format);
    matmul->AddAttr(kOutputsFormat, MakeValue(outputs_formats));
  }
  auto [trans_a, trans_b] = GetMatMulTransposeAttr(orig_matmul);
  matmul->AddAttr(kTransposeA, MakeValue(trans_a));
  matmul->AddAttr(kTransposeB, MakeValue(trans_b));
  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(matmul_inputs[1], 0)};
  std::vector<ShapeVector> shapes = {new_out_shape};
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, matmul.get());
  matmul->set_kernel_info(std::make_shared<device::KernelInfo>());
  return matmul;
}

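// Derives (batch, M, N, K) from the inferred input shapes of a MatMul, honoring its transpose
// attributes. A leading dimension shared by both 3-D inputs is treated as the batch dimension;
// otherwise batch is 1. Illustrative example (hypothetical shapes): with trans_a = false and
// trans_b = true, A:[b, m, k] and B:[b, n, k] yield (b, m, n, k).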
BMNK GetBatchMNK(const CNodePtr &matmul) {
  int64_t b = 0;
  int64_t m = 0;
  int64_t n = 0;
  int64_t k = 0;
  auto shape_a = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex0);
  auto shape_b = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex1);
  auto [trans_a, trans_b] = GetMatMulTransposeAttr(matmul);
  if (shape_a.size() == kDim3 && shape_b.size() == kDim3 && shape_a[kIndex0] == shape_b[kIndex0]) {
    b = shape_a[kIndex0];
    (void)shape_a.erase(shape_a.begin());
    (void)shape_b.erase(shape_b.begin());
  } else {
    b = 1;
  }
  m = trans_a ? shape_a[kIndex1] : shape_a[kIndex0];
  k = trans_a ? shape_a[kIndex0] : shape_a[kIndex1];
  n = trans_b ? shape_b[kIndex0] : shape_b[kIndex1];
  return std::tuple(b, m, n, k);
}
}  // namespace

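// Plans the concatenation for a group of parallel branches. When the branches share the A
// operand, their B operands are concatenated along N and the combined output is split along
// its last axis; when they share the B operand, the A operands are concatenated along M and
// the output is split along the M axis. The plan records the concat/split axes and the shape
// of the combined output. Illustrative sketch (hypothetical shapes): two branches sharing
// A:[m, k] with B1, B2:[k, n] plan a combined output of shape [m, 2 * n].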
ConcatenatePlan ParallelMatMulConcatenater::Analyse(const Group &branches) const {
  ConcatenatePlan target_op_res;
  Branch b0 = branches[kIndex0];
  AnfNodePtr shared_input = b0.GetRootData();
  target_op_res.in_shape = Callback::Instance()->GetOutputInferShape(shared_input, kIndex0);
  auto matmul = b0.GetTargetOp()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(matmul);
  bool is_a_shared = false;
  for (size_t i = 1; i < matmul->size(); ++i) {
    auto in = matmul->input(i);
    if (in == shared_input) {
      is_a_shared = i == kIndex1;
      break;
    }
  }

  auto [trans_a, trans_b] = GetMatMulTransposeAttr(matmul);
  int64_t b = 0;
  int64_t m = 0;
  int64_t n = 0;
  int64_t k = 0;
  std::tie(b, m, n, k) = GetBatchMNK(matmul);
  if (is_a_shared) {
    auto shape_b = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex1);
    size_t rank_b = shape_b.size();
    auto n_idx = trans_b ? rank_b - kIndex2 : rank_b - kIndex1;
    target_op_res.concat_in_idx = SizeToInt(n_idx);
    target_op_res.split_out_idx = SizeToInt(rank_b - kIndex1);
    int64_t new_n = n * SizeToLong(branches.size());
    if (rank_b == kDim3) {
      target_op_res.out_shape = ShapeVector({b, m, new_n});
    } else {
      target_op_res.out_shape = ShapeVector({m, new_n});
    }
  } else {
    auto shape_a = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex0);
    size_t rank_a = shape_a.size();
    auto m_idx = trans_a ? rank_a - kIndex1 : rank_a - kIndex2;
    target_op_res.concat_in_idx = SizeToInt(m_idx);
    target_op_res.split_out_idx = SizeToInt(rank_a - kIndex2);
    auto new_m = m * SizeToLong(branches.size());
    if (rank_a == kDim3) {
      target_op_res.out_shape = ShapeVector({b, new_m, n});
    } else {
      target_op_res.out_shape = ShapeVector({new_m, n});
    }
  }
  return target_op_res;
}

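// Two MatMuls are combinable only if their transpose_a / transpose_b attributes match.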
bool ParallelMatMulConcatenater::CanOpsBeCombined(const AnfNodePtr a, const AnfNodePtr b) {
  auto matmul1 = a->cast<CNodePtr>();
  auto matmul2 = b->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(matmul1);
  MS_EXCEPTION_IF_NULL(matmul2);
  auto [trans_a1, trans_b1] = GetMatMulTransposeAttr(matmul1);
  auto [trans_a2, trans_b2] = GetMatMulTransposeAttr(matmul2);
  return trans_a1 == trans_a2 && trans_b1 == trans_b2;
}

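// A node may join a branch only if it is a CNode whose primitive is not listed in unsupported_ops_.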
bool ParallelMatMulConcatenater::IsSupportedOp(const AnfNodePtr n) {
  if (n == nullptr || n->cast<CNodePtr>() == nullptr) {
    return false;
  }
  auto prim = GetCNodePrimitive(n);
  if (prim == nullptr || unsupported_ops_.count(prim->name())) {
    return false;
  }
  return true;
}

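// Builds the single combined MatMul for a branch group: records the concat/split plan,
// gathers the reloaded inputs of all branches, and materializes one MatMul with the
// enlarged output shape.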
AnfNodePtr ParallelMatMulConcatenater::MakeCombinedOp(const Group &branches) {
  Branch b1 = branches[0];
  AnfNodePtr shared_input = b1.GetRootData();
  auto matmul_op = b1.GetTargetOp()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(matmul_op);
  auto plan = Analyse(branches);
  plans_.push_back(plan);
  auto overall_inputs = ReloadInputs(branches, b1.target_op_pos, shared_input);
  auto matmul = NewMatMulNode(main_graph_, overall_inputs, matmul_op, plan.out_shape);
  MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(matmul), "AutoUpdateInfo fail");
  return matmul;
}

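// No extra argument-compatibility constraint beyond CanOpsBeCombined; all pairs are accepted.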
bool ParallelMatMulConcatenater::IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) { return true; }

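// Entry point of the pass. Only the NCHW layout is handled; other layouts are returned
// unchanged with a warning. Illustrative sketch (hypothetical graph): parallel branches
// MatMul(x, w0) and MatMul(x, w1) that share x are rewritten into one
// MatMul(x, Concat(w0, w1)) whose output is split back into the original results.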
AnfNodePtr ConcatParallelMatMul(AnfNodePtr root, uint64_t min_num_branches, const std::string &layout,
                                const FuncGraphPtr &func_graph) {
  if (layout == kOpFormat_NCHW) {
    auto res = ParallelMatMulConcatenater(min_num_branches, layout).Combine(root, func_graph);
    return res;
  }
  MS_LOG(WARNING) << "Not supported combine for layout " << layout;
  return root;
}
}  // namespace mindspore::graphkernel