1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/pass/split_matmul_comm_elementwise_fp.h"
18
19 #include <memory>
20
21 #include "mindspore/core/ops/other_ops.h"
22 #include "mindspore/core/ops/array_ops.h"
23 #include "mindspore/core/ops/math_ops.h"
24 #include "frontend/parallel/step_parallel.h"
25 #include "frontend/parallel/graph_util/graph_info.h"
26 #include "include/common/utils/utils.h"
27 #include "ir/pattern_matcher.h"
28
29 namespace mindspore {
30 namespace parallel {
namespace {
// Integer literals used when building the interleaved subgraph:
// 0 and 1 are the TupleGetItem output indices of the Split node,
// 2 is the interleaving factor (the input is split into two halves).
constexpr int64_t kInt64Num0 = 0;
constexpr int64_t kInt64Num1 = 1;
constexpr int64_t kInt64Num2 = 2;
}  // namespace
36
IsForwardCNode(const CNodePtr & cnode)37 static bool IsForwardCNode(const CNodePtr &cnode) {
38 MS_EXCEPTION_IF_NULL(cnode);
39 return !(cnode->HasPrimalAttr(kPrimalAttrForwardUniqueId) || cnode->HasAttr(kAttrDuplicated));
40 }
41
42 // MatMul -> AllReduce -> Add
PatternFilter(const AnfNodePtr & node)43 static bool PatternFilter(const AnfNodePtr &node) {
44 auto cnode = node->cast<CNodePtr>();
45 if (cnode == nullptr || !IsForwardCNode(cnode)) {
46 return true;
47 }
48
49 static PrimitiveSet expect_prim_type = {prim::kPrimAllReduce};
50 if (!IsOneOfPrimitiveCNode(cnode, expect_prim_type)) {
51 return true;
52 }
53 auto input_node = cnode->input(kIndex1);
54 if (input_node == nullptr || !IsPrimitiveCNode(input_node, prim::kPrimMatMul)) {
55 return true;
56 }
57 const auto &input_node_set = cnode->func_graph()->manager()->node_users()[input_node];
58 if (input_node_set.size() != kSizeOne) {
59 return true;
60 }
61 const auto &output_node_set = cnode->func_graph()->manager()->node_users()[cnode];
62 if (output_node_set.size() != kSizeOne) {
63 return true;
64 }
65
66 auto output_node = output_node_set.front().first;
67 auto index = output_node_set.front().second;
68 if (!IsPrimitiveCNode(output_node, prim::kPrimAdd) || index != kIndex1) {
69 return true;
70 }
71 return false;
72 }
73
CopyAllAttrs(const CNodePtr & dst_cnode,const CNodePtr & src_cnode)74 static void CopyAllAttrs(const CNodePtr &dst_cnode, const CNodePtr &src_cnode) {
75 MS_EXCEPTION_IF_NULL(dst_cnode);
76 MS_EXCEPTION_IF_NULL(src_cnode);
77 dst_cnode->set_attrs(src_cnode->attrs());
78 auto dst_prim_node = GetCNodePrimitive(dst_cnode);
79 auto src_prim_node = GetCNodePrimitive(src_cnode);
80 auto src_attrs = src_prim_node->attrs();
81 for (const auto &attr : src_attrs) {
82 dst_prim_node->set_attr(attr.first, attr.second);
83 }
84 }
85
// Rewrite one "MatMul -> AllReduce -> Add" chain (already validated by
// PatternFilter) into two interleaved half-sized branches, so the second
// MatMul can overlap with the first AllReduce:
//         +-> TupleGetItem(0) -> MatMul_a -> AllReduce_a -> Add_a --------------+
//   Split-+                                                                     +-> Concat
//         +-> TupleGetItem(1) -> MatMul_b -> Depend(AllReduce_a) -> AllReduce_b -> Add_b
// comm_node is the matched AllReduce node. Bails out (no rewrite) when the
// split axis of MatMul's first input is not evenly divisible by 2.
static void SplitIntoInterleaved(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
                                 const AnfNodePtr &comm_node) {
  // PatternFilter guarantees: input(1) is a MatMul cnode and the single user is an Add.
  auto comm_cnode = comm_node->cast<CNodePtr>();
  auto matmul_cnode = comm_cnode->input(kIndex1)->cast<CNodePtr>();
  auto add_cnode = manager->node_users()[comm_cnode].front().first->cast<CNodePtr>();

  // axis_a_0: the axis of MatMul input a that maps to the output-row dimension;
  // axis_b_1: the axis of input b that maps to the output-column dimension
  // (both flip when the corresponding transpose flag is set).
  bool transpose_a = GetValue<bool>(GetCNodePrimitive(matmul_cnode)->GetAttr("transpose_a"));
  bool transpose_b = GetValue<bool>(GetCNodePrimitive(matmul_cnode)->GetAttr("transpose_b"));
  const int64_t axis_a_0 = transpose_a ? 1 : 0;
  const int64_t axis_b_1 = transpose_b ? 0 : 1;

  auto comm_primtive = GetCNodePrimitive(comm_cnode);
  auto matmul_input1 = matmul_cnode->input(kIndex1);
  auto matmul_input1_shape = BaseShapeToShape(AnfAlgo::GetOutputDetailShape(matmul_input1, 0));
  // Only rewrite when the split axis divides evenly into two slices.
  if (matmul_input1_shape[axis_a_0] % kInt64Num2 != 0) {
    return;
  }
  auto matmul_input2 = matmul_cnode->input(kIndex2);
  auto matmul_input2_shape = BaseShapeToShape(AnfAlgo::GetOutputDetailShape(matmul_input2, 0));
  auto add_input2 = add_cnode->input(kIndex2);

  // Const value nodes reused as TupleGetItem indices.
  auto value0 = NewValueNode(MakeValue(kInt64Num0));
  auto value1 = NewValueNode(MakeValue(kInt64Num1));

  // New split(matmul_input1, axis_a_0, 2)
  auto split_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimSplit->Clone()), matmul_input1,
                                           NewValueNode<int64_t>(axis_a_0), NewValueNode<int64_t>(kInt64Num2)});
  int64_t slice_size = matmul_input1_shape[axis_a_0] / kInt64Num2;
  AddCNodePrimAttr(split_cnode, kAttrSizeSplits, MakeValue(ShapeVector{slice_size, slice_size}));
  AddCNodePrimAttr(split_cnode, kAttrNumSplit, MakeValue(kInt64Num2));

  // branch_a: split_cnode->TupleGetItem(0)->Matmul_a->AllReduce_a->Add_a
  auto tuple_get_item_a = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem->Clone()), split_cnode, value0});
  auto matmul_a = func_graph->NewCNode({NewValueNode(prim::kPrimMatMul->Clone()), tuple_get_item_a, matmul_input2});
  CopyAllAttrs(matmul_a, matmul_cnode);
  auto comm_a = func_graph->NewCNode({NewValueNode(comm_primtive->Clone()), matmul_a});
  CopyAllAttrs(comm_a, comm_cnode);
  CNodePtr add_a = func_graph->NewCNode({NewValueNode(prim::kPrimAdd->Clone()), comm_a, add_input2});

  // branch_b: split_cnode->TupleGetItem(1)->Matmul_b->AllReduce_b(depend AllReduce_a) ->Add_b
  // The Depend edge serializes AllReduce_b after AllReduce_a, which is what
  // lets MatMul_b run concurrently with AllReduce_a.
  auto tuple_get_item_b = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem->Clone()), split_cnode, value1});
  auto matmul_b = func_graph->NewCNode({NewValueNode(prim::kPrimMatMul->Clone()), tuple_get_item_b, matmul_input2});
  CopyAllAttrs(matmul_b, matmul_cnode);
  auto depend = func_graph->NewCNode({NewValueNode(prim::kPrimDepend->Clone()), matmul_b, comm_a});
  auto comm_b = func_graph->NewCNode({NewValueNode(comm_primtive->Clone()), depend});
  CopyAllAttrs(comm_b, comm_cnode);
  auto add_b = func_graph->NewCNode({NewValueNode(prim::kPrimAdd->Clone()), comm_b, add_input2});

  // New concat(MakeTuple(add_a, add_b)) restores the original output shape
  // along the same axis the input was split on.
  auto make_tuple_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple->Clone()), add_a, add_b});
  auto concat_cnode = func_graph->NewCNode(
    {NewValueNode(prim::kPrimConcat->Clone()), make_tuple_cnode, NewValueNode(MakeValue(axis_a_0))});

  // Infer path_a abstract: each split output has the input shape halved on the
  // split axis; each MatMul output is [half_rows, cols_of_b].
  auto dtype = common::AnfAlgo::GetOutputInferDataType(matmul_input1, 0);
  ShapeVector split_single_shape = matmul_input1_shape;
  split_single_shape[axis_a_0] /= kInt64Num2;
  auto split_shape_abstract = std::make_shared<abstract::Shape>(split_single_shape);
  common::AnfAlgo::SetOutputTypeAndDetailShape({dtype, dtype}, {split_shape_abstract, split_shape_abstract},
                                               split_cnode.get());
  common::AnfAlgo::SetOutputTypeAndDetailShape({dtype}, {split_shape_abstract}, tuple_get_item_a.get());
  ShapeVector matmul_ab_shape{split_single_shape[axis_a_0], matmul_input2_shape[axis_b_1]};
  auto matmul_ab_abstract = std::make_shared<abstract::Shape>(matmul_ab_shape);
  common::AnfAlgo::SetOutputTypeAndDetailShape({dtype}, {matmul_ab_abstract}, matmul_a.get());
  comm_a->set_abstract(matmul_a->abstract());
  add_a->set_abstract(matmul_a->abstract());

  // set path_b from path_a (both branches have identical shapes/dtypes)
  tuple_get_item_b->set_abstract(tuple_get_item_a->abstract());
  matmul_b->set_abstract(matmul_a->abstract());
  comm_b->set_abstract(matmul_b->abstract());
  depend->set_abstract(comm_b->abstract());
  add_b->set_abstract(add_a->abstract());

  // set abstract for make_tuple and concat; concat reuses the original Add's
  // abstract since the rewrite preserves the overall output shape.
  common::AnfAlgo::SetOutputTypeAndDetailShape({dtype, dtype}, {matmul_ab_abstract, matmul_ab_abstract},
                                               make_tuple_cnode.get());
  concat_cnode->set_abstract(add_cnode->abstract());

  // Replace graph: re-set the split's data edge through the manager
  // (prev_cnode is the same node passed at construction — presumably done to
  // register the edge with the manager; TODO confirm), then redirect every
  // user of the original Add to the new Concat.
  auto prev_cnode = matmul_cnode->input(kIndex1);
  manager->SetEdge(split_cnode, kIndex1, prev_cnode);
  auto next_cnode_users = manager->node_users()[add_cnode];
  for (const auto &next_cnode_pair : next_cnode_users) {
    manager->SetEdge(next_cnode_pair.first, next_cnode_pair.second, concat_cnode);
  }
}
174
175 // From:
176 // MatMul -> AllReduce -> Add
177 // To:
178 // --> TupleGetItem(0) -> MatMul_a -> AllReduce_a -> Add_a
179 // Split -> Concat
180 // --> TupleGetItem(1) -> MatMul_b -> Depend(AllReduce_a) -> AllReduce_b -> Add_b
SplitMatmulCommElementwiseFp(const FuncGraphPtr & func_graph)181 void SplitMatmulCommElementwiseFp(const FuncGraphPtr &func_graph) {
182 if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel &&
183 parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kAutoParallel) {
184 MS_LOG(INFO) << "SplitMatmulCommElementwiseFp is only support under [semi_]auto_parallel, skip it.";
185 return;
186 }
187
188 auto ms_context = MsContext::GetInstance();
189 MS_EXCEPTION_IF_NULL(ms_context);
190 auto is_enable = ms_context->get_param<bool>(MS_CTX_INTERLEAVED_MATMUL_COMM);
191 if (!is_enable) {
192 return;
193 }
194
195 MS_EXCEPTION_IF_NULL(func_graph);
196 auto manager = func_graph->manager();
197 MS_EXCEPTION_IF_NULL(manager);
198 auto todo = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, PatternFilter);
199 for (const auto &node : todo) {
200 SplitIntoInterleaved(func_graph, manager, node);
201 }
202 }
203 } // namespace parallel
204 } // namespace mindspore
205