/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <utility>
#include <vector>
#include <memory>

#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "frontend/parallel/pass/split_layernorm_comm_fp.h"
#include "frontend/parallel/step_parallel.h"
#include "include/common/utils/utils.h"
#include "ir/pattern_matcher.h"

namespace mindspore {
namespace parallel {
namespace {
constexpr int64_t kLongZero = 0;
constexpr int64_t kLongOne = 1;
constexpr int64_t kLongTwo = 2;
using PrimitiveIndex = std::pair<PrimitivePtr, int>;
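
// Returns true if the MatMul cnode transposes either of its inputs, based on its
// transpose_a / transpose_b primitive attrs.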
38
bool IsAnyMatMulInputTranspose(const CNodePtr &matmul_cnode) {
  MS_EXCEPTION_IF_NULL(matmul_cnode);
  if (!IsPrimitiveCNode(matmul_cnode)) {
    return false;
  }
  auto matmul_prim = GetCNodePrimitive(matmul_cnode);
  MS_EXCEPTION_IF_NULL(matmul_prim);
  return GetValue<bool>(matmul_prim->GetAttr("transpose_a")) || GetValue<bool>(matmul_prim->GetAttr("transpose_b"));
}

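// Copies both the cnode attrs and the primitive attrs from src_cnode to dst_cnode.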
void CopyAllAttrs(const CNodePtr &dst_cnode, const CNodePtr &src_cnode) {
  MS_EXCEPTION_IF_NULL(dst_cnode);
  MS_EXCEPTION_IF_NULL(src_cnode);
  dst_cnode->set_attrs(src_cnode->attrs());
  auto dst_prim_node = GetCNodePrimitive(dst_cnode);
  auto src_prim_node = GetCNodePrimitive(src_cnode);
  MS_EXCEPTION_IF_NULL(dst_prim_node);
  MS_EXCEPTION_IF_NULL(src_prim_node);
  auto src_attrs = src_prim_node->attrs();
  for (const auto &attr : src_attrs) {
    dst_prim_node->set_attr(attr.first, attr.second);
  }
}

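// Returns origin_shape with the slice_axis dimension divided by slice_size,
// e.g. GetSliceShape({8, 1024}, 0, 2) -> {4, 1024}.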
ShapeVector GetSliceShape(const ShapeVector &origin_shape, size_t slice_axis = 0, int64_t slice_size = 2) {
  if (slice_axis >= origin_shape.size()) {
    MS_LOG(EXCEPTION) << "The slice_axis must be less than origin_shape.size(), but got " << slice_axis << " and "
                      << origin_shape.size();
  }
  if (slice_size <= 0) {
    MS_LOG(EXCEPTION) << "The input 'slice_size' must be a positive integer, but got " << slice_size;
  }
  if (origin_shape[slice_axis] % slice_size != 0) {
    MS_LOG(EXCEPTION) << "The origin_shape[" << slice_axis << "] must be divisible by slice_size, but got "
                      << origin_shape[slice_axis] << " and " << slice_size;
  }
  auto slice_shape = origin_shape;
  slice_shape[slice_axis] /= slice_size;
  return slice_shape;
}

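// Creates a new cnode from 'inputs', clones all attrs from src_cnode, and sets its abstract to
// src_cnode's output abstract with the shape halved along slice_axis.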
// Only for cnodes with a single output.
CNodePtr NewCNodeAndCloneAttrsSetSliceAbstract(const FuncGraphPtr &func_graph, const CNodePtr &src_cnode,
                                               std::vector<AnfNodePtr> &&inputs, size_t slice_axis = 0) {
  MS_EXCEPTION_IF_NULL(src_cnode);
  auto cnode = func_graph->NewCNode(inputs);
  CopyAllAttrs(cnode, src_cnode);

  // Set the abstract: the source output shape sliced along slice_axis.
  ShapeVector src_cnode_shape = common::AnfAlgo::GetOutputInferShape(src_cnode, kIndex0);
  ShapeVector slice_shape = GetSliceShape(src_cnode_shape, slice_axis);
  common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetOutputInferDataType(src_cnode, kIndex0)},
                                               {std::make_shared<abstract::Shape>(slice_shape)}, cnode.get());
  return cnode;
}

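// Creates TupleGetItem(input_node, index) and copies the abstract of the selected output.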
CNodePtr NewTupleGetItemCNodeAndSetAbstract(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                                            const int64_t index) {
  MS_EXCEPTION_IF_NULL(input_node);
  auto tuple_get_item_cnode =
    func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem->Clone()), input_node, NewValueNode(MakeValue(index))});

  auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, LongToSize(index));
  auto slice_shape = common::AnfAlgo::GetOutputInferShape(input_node, LongToSize(index));
  auto slice_shape_abstract = std::make_shared<abstract::Shape>(slice_shape);
  common::AnfAlgo::SetOutputTypeAndDetailShape({dtype}, {slice_shape_abstract}, tuple_get_item_cnode.get());
  return tuple_get_item_cnode;
}

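// Creates a LayerNorm cnode for one slice: output 0 gets the sliced shape, outputs 1 and 2 (mean and
// variance) get shape {sliced_batch, 1}.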
CNodePtr NewLayerNormCNodeAndCloneAttrsSetSliceAbstract(const FuncGraphPtr &func_graph, const CNodePtr &src_cnode,
                                                        std::vector<AnfNodePtr> &&inputs) {
  MS_EXCEPTION_IF_NULL(src_cnode);
  auto layernorm_cnode = func_graph->NewCNode(inputs);
  CopyAllAttrs(layernorm_cnode, src_cnode);

  // set abstract
  auto slice_shape = GetSliceShape(common::AnfAlgo::GetOutputInferShape(src_cnode, kIndex0));
  auto slice_shape_abstract = std::make_shared<abstract::Shape>(slice_shape);
  common::AnfAlgo::SetOutputTypeAndDetailShape(
    {common::AnfAlgo::GetOutputInferDataType(src_cnode, kIndex0),
     common::AnfAlgo::GetOutputInferDataType(src_cnode, kIndex1),
     common::AnfAlgo::GetOutputInferDataType(src_cnode, kIndex2)},
    {slice_shape_abstract, std::make_shared<abstract::Shape>(ShapeVector{slice_shape[kIndex0], kLongOne}),
     std::make_shared<abstract::Shape>(ShapeVector{slice_shape[kIndex0], kLongOne})},
    layernorm_cnode.get());

  return layernorm_cnode;
}

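// Creates Split(input_node, axis, output_num) with its sliced two-output abstract; returns nullptr if
// the axis is out of range or the dimension is not divisible by output_num.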
CNodePtr NewSplitCNodeAndSetAbstract(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, int64_t axis,
                                     int64_t output_num) {
  MS_EXCEPTION_IF_NULL(input_node);
  auto input_node_shape = BaseShapeToShape(AnfAlgo::GetOutputDetailShape(input_node, 0));
  if (output_num <= 0) {
    MS_LOG(EXCEPTION) << "The input 'output_num' must be a positive integer, but got " << output_num;
  }
  if (LongToSize(axis) >= input_node_shape.size() || input_node_shape[LongToSize(axis)] % output_num != 0) {
    return nullptr;
  }
  auto split_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimSplit->Clone()), input_node,
                                           NewValueNode<int64_t>(axis), NewValueNode<int64_t>(output_num)});
  auto input_shape = common::AnfAlgo::GetOutputInferShape(input_node, kIndex0);
  int64_t slice_size = input_shape[LongToSize(axis)] / output_num;
  AddCNodePrimAttr(split_cnode, kAttrSizeSplits, MakeValue(ShapeVector{slice_size, slice_size}));
  AddCNodePrimAttr(split_cnode, kAttrNumSplit, MakeValue(output_num));

  auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, kIndex0);
  ShapeVector slice_shape = GetSliceShape(input_shape);
  auto slice_shape_abstract = std::make_shared<abstract::Shape>(slice_shape);
  common::AnfAlgo::SetOutputTypeAndDetailShape({dtype, dtype}, {slice_shape_abstract, slice_shape_abstract},
                                               split_cnode.get());
  return split_cnode;
}

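// Creates Concat(MakeTuple(input_node0, input_node1), axis) and sets the abstracts of both new cnodes.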
CNodePtr NewConcatCNodeAndSetAbstract(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node0,
                                      const AnfNodePtr &input_node1, int64_t axis, int64_t input_num) {
  auto make_tuple_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple->Clone()), input_node0, input_node1});
  auto concat_cnode =
    func_graph->NewCNode({NewValueNode(prim::kPrimConcat->Clone()), make_tuple_cnode, NewValueNode(MakeValue(axis))});

  auto input_node0_dtype = common::AnfAlgo::GetOutputInferDataType(input_node0, kIndex0);
  auto input_node0_shape = common::AnfAlgo::GetOutputInferShape(input_node0, kIndex0);
  auto input_node1_dtype = common::AnfAlgo::GetOutputInferDataType(input_node1, kIndex0);
  auto input_node1_shape = common::AnfAlgo::GetOutputInferShape(input_node1, kIndex0);
  common::AnfAlgo::SetOutputTypeAndDetailShape(
    {input_node0_dtype, input_node1_dtype},
    {std::make_shared<abstract::Shape>(input_node0_shape), std::make_shared<abstract::Shape>(input_node1_shape)},
    make_tuple_cnode.get());
  auto concat_shape = input_node0_shape;
  concat_shape[LongToSize(axis)] += input_node1_shape[LongToSize(axis)];
  common::AnfAlgo::SetOutputTypeAndDetailShape({input_node0_dtype}, {std::make_shared<abstract::Shape>(concat_shape)},
                                               concat_cnode.get());
  return concat_cnode;
}

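// Replaces post_node's first input with Depend(first_input, prior_node) so that prior_node is
// executed before post_node.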
void InsertDependAndSetAbstract(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
                                const AnfNodePtr &prior_node, const AnfNodePtr &post_node) {
  MS_EXCEPTION_IF_NULL(prior_node);
  MS_EXCEPTION_IF_NULL(post_node);
  auto post_cnode = post_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(post_cnode);
  std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), post_cnode->input(kIndex1), prior_node};
  auto depend_cnode = func_graph->NewCNode(depend_inputs);
  manager->SetEdge(post_node, kIndex1, depend_cnode);
  depend_cnode->set_abstract(depend_cnode->input(kIndex1)->abstract());
}
}  // namespace

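// A cnode is treated as a forward node if it carries neither the forward_unique_id primal attr
// nor the duplicated attr.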
static bool IsForwardCNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  return !(cnode->HasPrimalAttr(kPrimalAttrForwardUniqueId) || cnode->HasAttr(kAttrDuplicated));
}

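// Nodes the pass may absorb when expanding the slice range: Cast, GeLU or FastGeLU cnodes with
// exactly one user.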
static bool IsCareNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!IsOneOfPrimitiveCNode(cnode, {prim::kPrimCast, prim::kPrimGeLU, prim::kPrimFastGeLU})) {
    return false;
  }
  const auto &node_user = cnode->func_graph()->manager()->node_users()[cnode];
  return node_user.size() == kSizeOne;
}

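// Returns true (i.e. the node is filtered out) unless it is a forward LayerNorm whose chain of
// single users matches the pattern below, with each node feeding input 1 of the next.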
// LayerNorm -> TupleGetItem -> Cast -> AllGather -> MatMul -> Add -> FastGeLU -> MatMul -> ReduceScatter
static bool PatternFilter(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr || !IsForwardCNode(cnode)) {
    return true;
  }
  if (!IsPrimitiveCNode(cnode, prim::kPrimLayerNorm)) {
    return true;
  }

  static std::vector<PrimitiveIndex> expect_primitive_list = {
    {prim::kPrimTupleGetItem, 1}, {prim::kPrimCast, 1},     {prim::kPrimAllGather, 1}, {prim::kPrimMatMul, 1},
    {prim::kPrimAdd, 1},          {prim::kPrimFastGeLU, 1}, {prim::kPrimMatMul, 1},    {prim::kPrimReduceScatter, 1}};
  AnfNodePtr cur_node = node;
  for (const auto &expect_prim : expect_primitive_list) {
    auto cur_cnode = cur_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cur_cnode);
    auto output_node_set = cur_cnode->func_graph()->manager()->node_users()[cur_cnode];
    if (output_node_set.size() != kSizeOne) {
      return true;
    }
    auto index = output_node_set.front().second;
    if (index != expect_prim.second) {
      return true;
    }
    auto next_node = output_node_set.front().first;
    auto next_cnode = next_node->cast<CNodePtr>();
    if (next_cnode == nullptr || !IsForwardCNode(next_cnode)) {
      return true;
    }
    if (!IsPrimitiveCNode(next_cnode, expect_prim.first)) {
      return true;
    }
    cur_node = next_node;
  }
  return false;
}

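// Moves single-user Cast/GeLU/FastGeLU nodes in front of the Split to after it: each such node is
// cloned onto every slice branch, the Split consumes its input instead, and then the abstracts of
// the Split and its TupleGetItem users are refreshed.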
static void ExpandSliceRangeToLeft(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
                                   const CNodePtr &split_cnode) {
  MS_EXCEPTION_IF_NULL(split_cnode);
  auto pre_node = split_cnode->input(kIndex1);
  std::shared_ptr<abstract::Shape> slice_pre_cnode_shape_abstract;
  while (IsCareNode(pre_node)) {
    auto pre_cnode = pre_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(pre_cnode);
    auto pre_cnode_dtype = common::AnfAlgo::GetOutputInferDataType(pre_cnode, kIndex0);
    auto pre_cnode_shape = common::AnfAlgo::GetOutputInferShape(pre_cnode, kIndex0);
    ShapeVector slice_pre_cnode_shape = {pre_cnode_shape[kIndex0] / kLongTwo, pre_cnode_shape[kIndex1]};
    slice_pre_cnode_shape_abstract = std::make_shared<abstract::Shape>(slice_pre_cnode_shape);
    auto pre_node_prim = GetCNodePrimitive(pre_cnode);
    MS_EXCEPTION_IF_NULL(pre_node_prim);
    auto node_users = manager->node_users()[split_cnode];
    for (const auto &node_user : node_users) {
      auto tuple_get_item_node = node_user.first;
      auto tuple_get_item_node_users = manager->node_users()[tuple_get_item_node];
      auto pre_cnode_sub = func_graph->NewCNode({NewValueNode(pre_node_prim->Clone()), tuple_get_item_node});
      common::AnfAlgo::SetOutputTypeAndDetailShape({pre_cnode_dtype}, {slice_pre_cnode_shape_abstract},
                                                   pre_cnode_sub.get());
      for (const auto &tuple_get_item_node_user : tuple_get_item_node_users) {
        manager->SetEdge(tuple_get_item_node_user.first, tuple_get_item_node_user.second, pre_cnode_sub);
      }
    }

    manager->SetEdge(split_cnode, kIndex1, pre_cnode->input(kIndex1));
    pre_node = split_cnode->input(kIndex1);
  }
  if (slice_pre_cnode_shape_abstract == nullptr) {
    return;
  }
  // Refresh abstract for split and tuple_get_item
  auto new_split_cnode_dtype = common::AnfAlgo::GetOutputInferDataType(split_cnode->input(kIndex1), kIndex0);
  common::AnfAlgo::SetOutputTypeAndDetailShape({new_split_cnode_dtype, new_split_cnode_dtype},
                                               {slice_pre_cnode_shape_abstract, slice_pre_cnode_shape_abstract},
                                               split_cnode.get());
  auto node_users = manager->node_users()[split_cnode];
  for (const auto &node_user : node_users) {
    common::AnfAlgo::SetOutputTypeAndDetailShape({new_split_cnode_dtype}, {slice_pre_cnode_shape_abstract},
                                                 node_user.first.get());
  }
}

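// Moves single-user Cast/GeLU/FastGeLU nodes behind the Concat to before it: each such node is cloned
// onto every MakeTuple input and the original node's users are reconnected to the Concat.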
static void ExpandSliceRangeToRight(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
                                    const CNodePtr &concat_cnode) {
  while (true) {
    const auto &node_users = manager->node_users()[concat_cnode];
    if (node_users.size() != kSizeOne) {
      return;
    }
    auto next_cnode = node_users.front().first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(next_cnode);
    auto next_cnode_prim = GetCNodePrimitive(next_cnode);
    if (!IsOneOfPrimitiveCNode(next_cnode, {prim::kPrimCast, prim::kPrimGeLU, prim::kPrimFastGeLU})) {
      return;
    }

    auto make_tuple_cnode = concat_cnode->input(kIndex1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple_cnode);

    for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
      auto input_node = make_tuple_cnode->input(i);
      auto next_cnode_sub = func_graph->NewCNode({NewValueNode(next_cnode_prim->Clone()), input_node});
      next_cnode_sub->set_abstract(input_node->abstract());
      manager->SetEdge(make_tuple_cnode, SizeToInt(i), next_cnode_sub);
    }
    auto next_cnode_users = manager->node_users()[next_cnode];
    for (const auto &next_cnode_pair : next_cnode_users) {
      manager->SetEdge(next_cnode_pair.first, next_cnode_pair.second, concat_cnode);
    }
  }
}

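// Rewrites the matched chain into two interleaved halves so that the communication of one half can
// overlap with the computation of the other:
//   Before: X -> LayerNorm -> TupleGetItem -> Cast -> AllGather -> MatMul -> Add -> FastGeLU
//                -> MatMul -> ReduceScatter -> users
//   After : X -> Split(axis 0, 2) -> TupleGetItem(0)/TupleGetItem(1) -> two copies of the chain
//                -> MakeTuple -> Concat(axis 0) -> users
// Depend edges order the collectives: allgather_a before allgather_b, allgather_b before
// reduce_scatter_a, and reduce_scatter_a before reduce_scatter_b. The rewrite is skipped when either
// MatMul transposes an input or the LayerNorm input cannot be split evenly.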
static void SplitIntoInterleaved(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
                                 const AnfNodePtr &layernorm_node) {
  auto layernorm_cnode = layernorm_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(layernorm_cnode);
  auto tuple_get_item_cnode = manager->node_users()[layernorm_cnode].front().first->cast<CNodePtr>();
  auto cast_cnode = manager->node_users()[tuple_get_item_cnode].front().first->cast<CNodePtr>();
  auto allgather_cnode = manager->node_users()[cast_cnode].front().first->cast<CNodePtr>();
  auto matmul1_cnode = manager->node_users()[allgather_cnode].front().first->cast<CNodePtr>();
  if (IsAnyMatMulInputTranspose(matmul1_cnode)) {
    return;
  }
  auto add_cnode = manager->node_users()[matmul1_cnode].front().first->cast<CNodePtr>();
  auto fast_gelu_cnode = manager->node_users()[add_cnode].front().first->cast<CNodePtr>();
  auto matmul2_cnode = manager->node_users()[fast_gelu_cnode].front().first->cast<CNodePtr>();
  if (IsAnyMatMulInputTranspose(matmul2_cnode)) {
    return;
  }
  auto reduce_scatter_cnode = manager->node_users()[matmul2_cnode].front().first->cast<CNodePtr>();

  // New split(layernorm_input1, 0, 2)
  auto split_cnode = NewSplitCNodeAndSetAbstract(func_graph, layernorm_cnode->input(kIndex1), kLongZero, kLongTwo);
  if (split_cnode == nullptr) {
    return;
  }

  // branch_a: split_cnode->TupleGetItem(0)->LayerNorm->...->ReduceScatter->MakeTuple->Concat
  auto get_slice_a = NewTupleGetItemCNodeAndSetAbstract(func_graph, split_cnode, 0);
  auto layernorm_a =
    NewLayerNormCNodeAndCloneAttrsSetSliceAbstract(func_graph, layernorm_cnode,
                                                   {NewValueNode(prim::kPrimLayerNorm->Clone()), get_slice_a,
                                                    layernorm_cnode->input(kIndex2), layernorm_cnode->input(kIndex3)});
  auto tuple_get_item_a = NewCNodeAndCloneAttrsSetSliceAbstract(
    func_graph, tuple_get_item_cnode,
    {NewValueNode(prim::kPrimTupleGetItem->Clone()), layernorm_a, tuple_get_item_cnode->input(kIndex2)});
  auto cast_a = NewCNodeAndCloneAttrsSetSliceAbstract(
    func_graph, cast_cnode, {NewValueNode(prim::kPrimCast->Clone()), tuple_get_item_a, cast_cnode->input(kIndex2)});
  auto allgather_a = NewCNodeAndCloneAttrsSetSliceAbstract(func_graph, allgather_cnode,
                                                           {NewValueNode(prim::kPrimAllGather->Clone()), cast_a});
  auto matmul1_a = NewCNodeAndCloneAttrsSetSliceAbstract(
    func_graph, matmul1_cnode, {NewValueNode(prim::kPrimMatMul->Clone()), allgather_a, matmul1_cnode->input(kIndex2)});
  auto add_a = NewCNodeAndCloneAttrsSetSliceAbstract(
    func_graph, add_cnode, {NewValueNode(prim::kPrimAdd->Clone()), matmul1_a, add_cnode->input(kIndex2)});
  auto fast_gelu_a = NewCNodeAndCloneAttrsSetSliceAbstract(func_graph, fast_gelu_cnode,
                                                           {NewValueNode(prim::kPrimFastGeLU->Clone()), add_a});
  auto matmul2_a = NewCNodeAndCloneAttrsSetSliceAbstract(
    func_graph, matmul2_cnode, {NewValueNode(prim::kPrimMatMul->Clone()), fast_gelu_a, matmul2_cnode->input(kIndex2)});
  auto reduce_scatter_a = NewCNodeAndCloneAttrsSetSliceAbstract(
    func_graph, reduce_scatter_cnode, {NewValueNode(prim::kPrimReduceScatter->Clone()), matmul2_a});

  // branch_b: split_cnode->TupleGetItem(1)->LayerNorm->...->ReduceScatter->MakeTuple->Concat
  auto get_slice_b = NewTupleGetItemCNodeAndSetAbstract(func_graph, split_cnode, 1);
  auto layernorm_b =
    NewLayerNormCNodeAndCloneAttrsSetSliceAbstract(func_graph, layernorm_cnode,
                                                   {NewValueNode(prim::kPrimLayerNorm->Clone()), get_slice_b,
                                                    layernorm_cnode->input(kIndex2), layernorm_cnode->input(kIndex3)});
  auto tuple_get_item_b = NewCNodeAndCloneAttrsSetSliceAbstract(
    func_graph, tuple_get_item_cnode,
    {NewValueNode(prim::kPrimTupleGetItem->Clone()), layernorm_b, tuple_get_item_cnode->input(kIndex2)});
  auto cast_b = NewCNodeAndCloneAttrsSetSliceAbstract(
    func_graph, cast_cnode, {NewValueNode(prim::kPrimCast->Clone()), tuple_get_item_b, cast_cnode->input(kIndex2)});
  auto allgather_b = NewCNodeAndCloneAttrsSetSliceAbstract(func_graph, allgather_cnode,
                                                           {NewValueNode(prim::kPrimAllGather->Clone()), cast_b});
  auto matmul1_b = NewCNodeAndCloneAttrsSetSliceAbstract(
    func_graph, matmul1_cnode, {NewValueNode(prim::kPrimMatMul->Clone()), allgather_b, matmul1_cnode->input(kIndex2)});
  auto add_b = NewCNodeAndCloneAttrsSetSliceAbstract(
    func_graph, add_cnode, {NewValueNode(prim::kPrimAdd->Clone()), matmul1_b, add_cnode->input(kIndex2)});
  auto fast_gelu_b = NewCNodeAndCloneAttrsSetSliceAbstract(func_graph, fast_gelu_cnode,
                                                           {NewValueNode(prim::kPrimFastGeLU->Clone()), add_b});
  auto matmul2_b = NewCNodeAndCloneAttrsSetSliceAbstract(
    func_graph, matmul2_cnode, {NewValueNode(prim::kPrimMatMul->Clone()), fast_gelu_b, matmul2_cnode->input(kIndex2)});
  auto reduce_scatter_b = NewCNodeAndCloneAttrsSetSliceAbstract(
    func_graph, reduce_scatter_cnode, {NewValueNode(prim::kPrimReduceScatter->Clone()), matmul2_b});

  // Insert depend nodes to interleave the communication of the two branches
  InsertDependAndSetAbstract(func_graph, manager, allgather_a, allgather_b);
  InsertDependAndSetAbstract(func_graph, manager, reduce_scatter_a, reduce_scatter_b);
  InsertDependAndSetAbstract(func_graph, manager, allgather_b, reduce_scatter_a);

  // New concat(MakeTuple(reduce_scatter_a, reduce_scatter_b))
  auto concat_cnode = NewConcatCNodeAndSetAbstract(func_graph, reduce_scatter_a, reduce_scatter_b, kLongZero, kLongTwo);

  // Replace graph
  auto prev_cnode = layernorm_cnode->input(kIndex1);
  manager->SetEdge(split_cnode, kIndex1, prev_cnode);
  auto next_cnode_users = manager->node_users()[reduce_scatter_cnode];
  for (const auto &next_cnode_pair : next_cnode_users) {
    manager->SetEdge(next_cnode_pair.first, next_cnode_pair.second, concat_cnode);
  }

  // Expand slice range by white list
  ExpandSliceRangeToLeft(func_graph, manager, split_cnode);
  ExpandSliceRangeToRight(func_graph, manager, concat_cnode);
}

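// Entry point of the pass: enabled only under semi_auto_parallel / auto_parallel and when the
// interleaved LayerNorm-communication context flag is set.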
void SplitLayerNormCommFp(const FuncGraphPtr &func_graph) {
  if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel &&
      parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kAutoParallel) {
    return;
  }

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto is_enable = ms_context->get_param<bool>(MS_CTX_INTERLEAVED_LAYERNORM_COMM);
  if (!is_enable) {
    return;
  }

  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto todo = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, PatternFilter);
  for (const auto &node : todo) {
    SplitIntoInterleaved(func_graph, manager, node);
  }
}
}  // namespace parallel
}  // namespace mindspore