• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <utility>
18 #include <vector>
19 #include <memory>
20 
21 #include "mindspore/core/ops/other_ops.h"
22 #include "mindspore/core/ops/array_ops.h"
23 #include "mindspore/core/ops/math_ops.h"
24 #include "mindspore/core/ops/nn_optimizer_ops.h"
25 #include "mindspore/core/ops/nn_ops.h"
26 #include "frontend/parallel/pass/split_layernorm_comm_fp.h"
27 #include "frontend/parallel/step_parallel.h"
28 #include "include/common/utils/utils.h"
29 #include "ir/pattern_matcher.h"
30 
31 namespace mindspore {
32 namespace parallel {
33 namespace {
// Named int64_t literals used in this file for axis indices, slice counts and shape dimensions.
constexpr int64_t kLongZero = 0;
constexpr int64_t kLongOne = 1;
constexpr int64_t kLongTwo = 2;
// A primitive paired with an integer. PatternFilter compares the integer with the input index at which the
// next node in the expected chain consumes the previous one.
using PrimitiveIndex = std::pair<PrimitivePtr, int>;
38 
IsAnyMatMulInputTranspose(const CNodePtr & matmul_cnode)39 bool IsAnyMatMulInputTranspose(const CNodePtr &matmul_cnode) {
40   MS_EXCEPTION_IF_NULL(matmul_cnode);
41   if (!IsPrimitiveCNode(matmul_cnode)) {
42     return false;
43   }
44   return GetValue<bool>(GetCNodePrimitive(matmul_cnode)->GetAttr("transpose_a")) ||
45          GetValue<bool>(GetCNodePrimitive(matmul_cnode)->GetAttr("transpose_b"));
46 }
47 
CopyAllAttrs(const CNodePtr & dst_cnode,const CNodePtr & src_cnode)48 void CopyAllAttrs(const CNodePtr &dst_cnode, const CNodePtr &src_cnode) {
49   MS_EXCEPTION_IF_NULL(dst_cnode);
50   MS_EXCEPTION_IF_NULL(src_cnode);
51   dst_cnode->set_attrs(src_cnode->attrs());
52   auto dst_prim_node = GetCNodePrimitive(dst_cnode);
53   auto src_prim_node = GetCNodePrimitive(src_cnode);
54   auto src_attrs = src_prim_node->attrs();
55   for (const auto &attr : src_attrs) {
56     dst_prim_node->set_attr(attr.first, attr.second);
57   }
58 }
59 
// Returns a copy of `origin_shape` whose dimension `slice_axis` is divided by `slice_size`.
//
// Raises (MS_LOG(EXCEPTION)) when:
//   - `slice_axis` is not a valid index into `origin_shape`;
//   - `slice_size` is not a positive integer;
//   - origin_shape[slice_axis] is not an exact multiple of `slice_size`.
ShapeVector GetSliceShape(const ShapeVector &origin_shape, size_t slice_axis = 0, int64_t slice_size = 2) {
  if (slice_axis >= origin_shape.size()) {
    MS_LOG(EXCEPTION) << "The slice_axis must be less than origin_shape.size(), but got " << slice_axis << " and "
                      << origin_shape.size();
  }
  // Reject negative values as well: the message has always promised "positive", and a negative divisor would
  // produce a negative dimension below.
  if (slice_size <= 0) {
    MS_LOG(EXCEPTION) << "The input 'slice_size' must be a positive integer, but got " << slice_size;
  }
  if (origin_shape[slice_axis] % slice_size != 0) {
    MS_LOG(EXCEPTION) << "The origin_shape[" << slice_axis << "] must be divisible by slice_size, but got "
                      << origin_shape[slice_axis] << " and " << slice_size;
  }
  auto slice_shape = origin_shape;
  slice_shape[slice_axis] /= slice_size;
  return slice_shape;
}
76 
77 // Only for single outputs cnode
NewCNodeAndCloneAttrsSetSliceAbstract(const FuncGraphPtr & func_graph,const CNodePtr & src_cnode,std::vector<AnfNodePtr> && inputs,size_t slice_axis=0)78 CNodePtr NewCNodeAndCloneAttrsSetSliceAbstract(const FuncGraphPtr &func_graph, const CNodePtr &src_cnode,
79                                                std::vector<AnfNodePtr> &&inputs, size_t slice_axis = 0) {
80   auto cnode = func_graph->NewCNode(inputs);
81   CopyAllAttrs(cnode, src_cnode);
82 
83   // set abstract
84   ShapeVector src_cnode_shape = common::AnfAlgo::GetOutputInferShape(src_cnode, slice_axis);
85   ShapeVector slice_shape = GetSliceShape(src_cnode_shape, slice_axis);
86   common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetOutputInferDataType(src_cnode, slice_axis)},
87                                                {std::make_shared<abstract::Shape>(slice_shape)}, cnode.get());
88   return cnode;
89 }
90 
NewTupleGetItemCNodeAndSetAbstract(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,const int64_t index)91 CNodePtr NewTupleGetItemCNodeAndSetAbstract(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
92                                             const int64_t index) {
93   MS_EXCEPTION_IF_NULL(input_node);
94   auto tuple_get_item_cnode =
95     func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem->Clone()), input_node, NewValueNode(MakeValue(index))});
96 
97   auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, LongToSize(index));
98   auto slice_shape = common::AnfAlgo::GetOutputInferShape(input_node, LongToSize(index));
99   auto slice_shape_abstract = std::make_shared<abstract::Shape>(slice_shape);
100   common::AnfAlgo::SetOutputTypeAndDetailShape({dtype}, {slice_shape_abstract}, tuple_get_item_cnode.get());
101   return tuple_get_item_cnode;
102 }
103 
// Creates a cnode from `inputs`, copies the attrs of `src_cnode` onto it and sets a three-output abstract:
//   output 0: dtype of src output 0, shape of src output 0 sliced by GetSliceShape with its defaults;
//   output 1: dtype of src output 1, shape {sliced_dim0, 1};
//   output 2: dtype of src output 2, shape {sliced_dim0, 1}.
CNodePtr NewLayerNormCNodeAndCloneAttrsSetSliceAbstract(const FuncGraphPtr &func_graph, const CNodePtr &src_cnode,
                                                        std::vector<AnfNodePtr> &&inputs) {
  MS_EXCEPTION_IF_NULL(src_cnode);
  auto new_layernorm = func_graph->NewCNode(inputs);
  CopyAllAttrs(new_layernorm, src_cnode);

  // Shapes of the three outputs.
  const auto sliced_out_shape = GetSliceShape(common::AnfAlgo::GetOutputInferShape(src_cnode, kIndex0));
  const auto out0_shape_abs = std::make_shared<abstract::Shape>(sliced_out_shape);
  const auto out1_shape_abs = std::make_shared<abstract::Shape>(ShapeVector{sliced_out_shape[kIndex0], kLongOne});
  const auto out2_shape_abs = std::make_shared<abstract::Shape>(ShapeVector{sliced_out_shape[kIndex0], kLongOne});

  // Data types are taken unchanged from the source node.
  const auto out0_dtype = common::AnfAlgo::GetOutputInferDataType(src_cnode, kIndex0);
  const auto out1_dtype = common::AnfAlgo::GetOutputInferDataType(src_cnode, kIndex1);
  const auto out2_dtype = common::AnfAlgo::GetOutputInferDataType(src_cnode, kIndex2);

  common::AnfAlgo::SetOutputTypeAndDetailShape({out0_dtype, out1_dtype, out2_dtype},
                                               {out0_shape_abs, out1_shape_abs, out2_shape_abs},
                                               new_layernorm.get());
  return new_layernorm;
}
123 
NewSplitCNodeAndSetAbstract(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,int64_t axis,int64_t output_num)124 CNodePtr NewSplitCNodeAndSetAbstract(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, int64_t axis,
125                                      int64_t output_num) {
126   MS_EXCEPTION_IF_NULL(input_node);
127   auto input_node_shape = BaseShapeToShape(AnfAlgo::GetOutputDetailShape(input_node, 0));
128   if (output_num == 0) {
129     MS_LOG(EXCEPTION) << "The input 'output_num' must be a positive integer, but got " << output_num;
130   }
131   if (LongToSize(axis) >= input_node_shape.size() || input_node_shape[axis] % output_num != 0) {
132     return nullptr;
133   }
134   auto split_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimSplit->Clone()), input_node,
135                                            NewValueNode<int64_t>(axis), NewValueNode<int64_t>(output_num)});
136   auto input_shape = common::AnfAlgo::GetOutputInferShape(input_node, kIndex0);
137   int64_t slice_size = input_shape[kIndex0] / kLongTwo;
138   AddCNodePrimAttr(split_cnode, kAttrSizeSplits, MakeValue(ShapeVector{slice_size, slice_size}));
139   AddCNodePrimAttr(split_cnode, kAttrNumSplit, MakeValue(output_num));
140 
141   auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, kIndex0);
142   ShapeVector slice_shape = GetSliceShape(input_shape);
143   auto slice_shape_abstract = std::make_shared<abstract::Shape>(slice_shape);
144   common::AnfAlgo::SetOutputTypeAndDetailShape({dtype, dtype}, {slice_shape_abstract, slice_shape_abstract},
145                                                split_cnode.get());
146   return split_cnode;
147 }
148 
NewConcatCNodeAndSetAbstract(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node0,const AnfNodePtr & input_node1,int64_t axis,int64_t input_num)149 CNodePtr NewConcatCNodeAndSetAbstract(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node0,
150                                       const AnfNodePtr &input_node1, int64_t axis, int64_t input_num) {
151   auto make_tuple_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple->Clone()), input_node0, input_node1});
152   auto concat_cnode =
153     func_graph->NewCNode({NewValueNode(prim::kPrimConcat->Clone()), make_tuple_cnode, NewValueNode(MakeValue(axis))});
154 
155   auto input_node0_dtype = common::AnfAlgo::GetOutputInferDataType(input_node0, kIndex0);
156   auto input_node0_shape = common::AnfAlgo::GetOutputInferShape(input_node0, kIndex0);
157   auto input_node1_dtype = common::AnfAlgo::GetOutputInferDataType(input_node1, kIndex0);
158   auto input_node1_shape = common::AnfAlgo::GetOutputInferShape(input_node1, kIndex0);
159   common::AnfAlgo::SetOutputTypeAndDetailShape(
160     {input_node0_dtype, input_node1_dtype},
161     {std::make_shared<abstract::Shape>(input_node0_shape), std::make_shared<abstract::Shape>(input_node1_shape)},
162     make_tuple_cnode.get());
163   auto concat_shape = input_node0_shape;
164   concat_shape[axis] += input_node1_shape[axis];
165   common::AnfAlgo::SetOutputTypeAndDetailShape({input_node0_dtype}, {std::make_shared<abstract::Shape>(concat_shape)},
166                                                concat_cnode.get());
167   return concat_cnode;
168 }
169 
InsertDependAndSetAbstract(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager,const AnfNodePtr & prior_node,const AnfNodePtr & post_node)170 void InsertDependAndSetAbstract(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
171                                 const AnfNodePtr &prior_node, const AnfNodePtr &post_node) {
172   MS_EXCEPTION_IF_NULL(prior_node);
173   MS_EXCEPTION_IF_NULL(post_node);
174   auto post_cnode = post_node->cast<CNodePtr>();
175   MS_EXCEPTION_IF_NULL(post_cnode);
176   std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), post_cnode->input(kIndex1), prior_node};
177   auto depend_cnode = func_graph->NewCNode(depend_inputs);
178   manager->SetEdge(post_node, kIndex1, depend_cnode);
179   depend_cnode->set_abstract(depend_cnode->input(kIndex1)->abstract());
180 }
181 }  // namespace
182 
IsForwardCNode(const CNodePtr & cnode)183 static bool IsForwardCNode(const CNodePtr &cnode) {
184   MS_EXCEPTION_IF_NULL(cnode);
185   return !(cnode->HasPrimalAttr(kPrimalAttrForwardUniqueId) || cnode->HasAttr(kAttrDuplicated));
186 }
187 
IsCareNode(const AnfNodePtr & node)188 static bool IsCareNode(const AnfNodePtr &node) {
189   MS_EXCEPTION_IF_NULL(node);
190   auto cnode = node->cast<CNodePtr>();
191   MS_EXCEPTION_IF_NULL(cnode);
192   if (!IsOneOfPrimitiveCNode(cnode, {prim::kPrimCast, prim::kPrimGeLU, prim::kPrimFastGeLU})) {
193     return false;
194   }
195   const auto &node_user = cnode->func_graph()->manager()->node_users()[cnode];
196   return node_user.size() == kSizeOne;
197 }
198 
199 // LayerNorm->TupleGetItem->Cast->AllGather->Matmul->Add->Activation->MatMul->ReduceScatter
PatternFilter(const AnfNodePtr & node)200 static bool PatternFilter(const AnfNodePtr &node) {
201   auto cnode = node->cast<CNodePtr>();
202   if (cnode == nullptr || !IsForwardCNode(cnode)) {
203     return true;
204   }
205   if (!IsPrimitiveCNode(cnode, prim::kPrimLayerNorm)) {
206     return true;
207   }
208 
209   static std::vector<PrimitiveIndex> expect_primitive_list = {
210     {prim::kPrimTupleGetItem, 1}, {prim::kPrimCast, 1},     {prim::kPrimAllGather, 1}, {prim::kPrimMatMul, 1},
211     {prim::kPrimAdd, 1},          {prim::kPrimFastGeLU, 1}, {prim::kPrimMatMul, 1},    {prim::kPrimReduceScatter, 1}};
212   AnfNodePtr cur_node = node;
213   for (const auto &expect_prim : expect_primitive_list) {
214     auto cur_cnode = cur_node->cast<CNodePtr>();
215     auto output_node_set = cur_cnode->func_graph()->manager()->node_users()[cur_cnode];
216     if (output_node_set.size() != kSizeOne) {
217       return true;
218     }
219     auto index = output_node_set.front().second;
220     if (index != expect_prim.second) {
221       return true;
222     }
223     auto next_node = output_node_set.front().first;
224     auto next_cnode = next_node->cast<CNodePtr>();
225     if (next_cnode == nullptr || !IsForwardCNode(next_cnode)) {
226       return true;
227     }
228     if (!IsPrimitiveCNode(next_cnode, expect_prim.first)) {
229       return true;
230     }
231     cur_node = next_node;
232   }
233   return false;
234 }
235 
ExpandSliceRangeToLeft(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager,const CNodePtr & split_cnode)236 static void ExpandSliceRangeToLeft(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
237                                    const CNodePtr &split_cnode) {
238   MS_EXCEPTION_IF_NULL(split_cnode);
239   auto pre_node = split_cnode->input(kIndex1);
240   std::shared_ptr<abstract::Shape> slice_pre_cnode_shape_abstract;
241   while (IsCareNode(pre_node)) {
242     auto pre_cnode = pre_node->cast<CNodePtr>();
243     MS_EXCEPTION_IF_NULL(pre_cnode);
244     auto pre_cnode_dtype = common::AnfAlgo::GetOutputInferDataType(pre_cnode, kIndex0);
245     auto pre_cnode_shape = common::AnfAlgo::GetOutputInferShape(pre_cnode, kIndex0);
246     auto slice_pre_cnode_shape = {pre_cnode_shape[kIndex0] / kLongTwo, pre_cnode_shape[kIndex1]};
247     slice_pre_cnode_shape_abstract = std::make_shared<abstract::Shape>(slice_pre_cnode_shape);
248     auto pre_node_prim = GetCNodePrimitive(pre_cnode);
249     auto node_users = manager->node_users()[split_cnode];
250     for (const auto &node_user : node_users) {
251       auto tuple_get_item_node = node_user.first;
252       auto tuple_get_item_node_users = manager->node_users()[tuple_get_item_node];
253       auto pre_cnode_sub = func_graph->NewCNode({NewValueNode(pre_node_prim->Clone()), tuple_get_item_node});
254       common::AnfAlgo::SetOutputTypeAndDetailShape({pre_cnode_dtype}, {slice_pre_cnode_shape_abstract},
255                                                    pre_cnode_sub.get());
256       for (const auto &tuple_get_item_node_user : tuple_get_item_node_users) {
257         manager->SetEdge(tuple_get_item_node_user.first, tuple_get_item_node_user.second, pre_cnode_sub);
258       }
259     }
260 
261     manager->SetEdge(split_cnode, kIndex1, pre_cnode->input(kIndex1));
262     pre_node = split_cnode->input(kIndex1);
263   }
264   if (slice_pre_cnode_shape_abstract == nullptr) {
265     return;
266   }
267   // Refresh abstract for split and tuple_get_item
268   auto new_split_cnode_dtype = common::AnfAlgo::GetOutputInferDataType(split_cnode->input(kIndex1), kIndex0);
269   common::AnfAlgo::SetOutputTypeAndDetailShape({new_split_cnode_dtype, new_split_cnode_dtype},
270                                                {slice_pre_cnode_shape_abstract, slice_pre_cnode_shape_abstract},
271                                                split_cnode.get());
272   auto node_users = manager->node_users()[split_cnode];
273   for (const auto &node_user : node_users) {
274     common::AnfAlgo::SetOutputTypeAndDetailShape({new_split_cnode_dtype}, {slice_pre_cnode_shape_abstract},
275                                                  node_user.first.get());
276   }
277 }
278 
ExpandSliceRangeToRight(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager,const CNodePtr & concat_cnode)279 static void ExpandSliceRangeToRight(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
280                                     const CNodePtr &concat_cnode) {
281   while (true) {
282     const auto &node_users = manager->node_users()[concat_cnode];
283     if (node_users.size() != kSizeOne) {
284       return;
285     }
286     auto next_cnode = node_users.front().first->cast<CNodePtr>();
287     MS_EXCEPTION_IF_NULL(next_cnode);
288     auto next_cnode_prim = GetCNodePrimitive(next_cnode);
289     if (!IsOneOfPrimitiveCNode(next_cnode, {prim::kPrimCast, prim::kPrimGeLU, prim::kPrimFastGeLU})) {
290       return;
291     }
292 
293     auto make_tuple_cnode = concat_cnode->input(kIndex1)->cast<CNodePtr>();
294     MS_EXCEPTION_IF_NULL(make_tuple_cnode);
295 
296     for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
297       auto input_node = make_tuple_cnode->input(i);
298       auto next_cnode_sub = func_graph->NewCNode({NewValueNode(next_cnode_prim->Clone()), input_node});
299       next_cnode_sub->set_abstract(input_node->abstract());
300       manager->SetEdge(make_tuple_cnode, SizeToInt(i), next_cnode_sub);
301     }
302     auto next_cnode_users = manager->node_users()[next_cnode];
303     for (const auto &next_cnode_pair : next_cnode_users) {
304       manager->SetEdge(next_cnode_pair.first, next_cnode_pair.second, concat_cnode);
305     }
306   }
307 }
308 
SplitIntoInterleaved(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager,const AnfNodePtr & layernorm_node)309 static void SplitIntoInterleaved(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
310                                  const AnfNodePtr &layernorm_node) {
311   auto layernorm_cnode = layernorm_node->cast<CNodePtr>();
312   auto tuple_get_item_cnode = manager->node_users()[layernorm_cnode].front().first->cast<CNodePtr>();
313   auto cast_cnode = manager->node_users()[tuple_get_item_cnode].front().first->cast<CNodePtr>();
314   auto allgather_cnode = manager->node_users()[cast_cnode].front().first->cast<CNodePtr>();
315   auto matmul1_cnode = manager->node_users()[allgather_cnode].front().first->cast<CNodePtr>();
316   if (IsAnyMatMulInputTranspose(matmul1_cnode)) {
317     return;
318   }
319   auto add_cnode = manager->node_users()[matmul1_cnode].front().first->cast<CNodePtr>();
320   auto fast_gelu_cnode = manager->node_users()[add_cnode].front().first->cast<CNodePtr>();
321   auto matmul2_cnode = manager->node_users()[fast_gelu_cnode].front().first->cast<CNodePtr>();
322   if (IsAnyMatMulInputTranspose(matmul2_cnode)) {
323     return;
324   }
325   auto reduce_scatter_cnode = manager->node_users()[matmul2_cnode].front().first->cast<CNodePtr>();
326 
327   // New split(layernorm_input1, 0, 2)
328   auto split_cnode = NewSplitCNodeAndSetAbstract(func_graph, layernorm_cnode->input(kIndex1), kLongZero, kLongTwo);
329   if (split_cnode == nullptr) {
330     return;
331   }
332 
333   // branch_a: split_cnode->TupleGetItem(0)->LayerNorm->...->ReduceScatter->MakeTuple->Concat
334   auto get_slice_a = NewTupleGetItemCNodeAndSetAbstract(func_graph, split_cnode, 0);
335   auto layernorm_a =
336     NewLayerNormCNodeAndCloneAttrsSetSliceAbstract(func_graph, layernorm_cnode,
337                                                    {NewValueNode(prim::kPrimLayerNorm->Clone()), get_slice_a,
338                                                     layernorm_cnode->input(kIndex2), layernorm_cnode->input(kIndex3)});
339   auto tuple_get_item_a = NewCNodeAndCloneAttrsSetSliceAbstract(
340     func_graph, tuple_get_item_cnode,
341     {NewValueNode(prim::kPrimTupleGetItem->Clone()), layernorm_a, tuple_get_item_cnode->input(kIndex2)});
342   auto cast_a = NewCNodeAndCloneAttrsSetSliceAbstract(
343     func_graph, cast_cnode, {NewValueNode(prim::kPrimCast->Clone()), tuple_get_item_a, cast_cnode->input(kIndex2)});
344   auto allgather_a = NewCNodeAndCloneAttrsSetSliceAbstract(func_graph, allgather_cnode,
345                                                            {NewValueNode(prim::kPrimAllGather->Clone()), cast_a});
346   auto matmul1_a = NewCNodeAndCloneAttrsSetSliceAbstract(
347     func_graph, matmul1_cnode, {NewValueNode(prim::kPrimMatMul->Clone()), allgather_a, matmul1_cnode->input(kIndex2)});
348   auto add_a = NewCNodeAndCloneAttrsSetSliceAbstract(
349     func_graph, add_cnode, {NewValueNode(prim::kPrimAdd->Clone()), matmul1_a, add_cnode->input(kIndex2)});
350   auto fast_gelu_a = NewCNodeAndCloneAttrsSetSliceAbstract(func_graph, fast_gelu_cnode,
351                                                            {NewValueNode(prim::kPrimFastGeLU->Clone()), add_a});
352   auto matmul2_a = NewCNodeAndCloneAttrsSetSliceAbstract(
353     func_graph, matmul2_cnode, {NewValueNode(prim::kPrimMatMul->Clone()), fast_gelu_a, matmul2_cnode->input(kIndex2)});
354   auto reduce_scatter_a = NewCNodeAndCloneAttrsSetSliceAbstract(
355     func_graph, reduce_scatter_cnode, {NewValueNode(prim::kPrimReduceScatter->Clone()), matmul2_a});
356 
357   // branch_b: split_cnode->TupleGetItem(0)->LayerNorm->...->ReduceScatter->MakeTuple->Concat
358   auto get_slice_b = NewTupleGetItemCNodeAndSetAbstract(func_graph, split_cnode, 1);
359   auto layernorm_b =
360     NewLayerNormCNodeAndCloneAttrsSetSliceAbstract(func_graph, layernorm_cnode,
361                                                    {NewValueNode(prim::kPrimLayerNorm->Clone()), get_slice_b,
362                                                     layernorm_cnode->input(kIndex2), layernorm_cnode->input(kIndex3)});
363   auto tuple_get_item_b = NewCNodeAndCloneAttrsSetSliceAbstract(
364     func_graph, tuple_get_item_cnode,
365     {NewValueNode(prim::kPrimTupleGetItem->Clone()), layernorm_b, tuple_get_item_cnode->input(kIndex2)});
366   auto cast_b = NewCNodeAndCloneAttrsSetSliceAbstract(
367     func_graph, cast_cnode, {NewValueNode(prim::kPrimCast->Clone()), tuple_get_item_b, cast_cnode->input(kIndex2)});
368   auto allgather_b = NewCNodeAndCloneAttrsSetSliceAbstract(func_graph, allgather_cnode,
369                                                            {NewValueNode(prim::kPrimAllGather->Clone()), cast_b});
370   auto matmul1_b = NewCNodeAndCloneAttrsSetSliceAbstract(
371     func_graph, matmul1_cnode, {NewValueNode(prim::kPrimMatMul->Clone()), allgather_b, matmul1_cnode->input(kIndex2)});
372   auto add_b = NewCNodeAndCloneAttrsSetSliceAbstract(
373     func_graph, add_cnode, {NewValueNode(prim::kPrimAdd->Clone()), matmul1_b, add_cnode->input(kIndex2)});
374   auto fast_gelu_b = NewCNodeAndCloneAttrsSetSliceAbstract(func_graph, fast_gelu_cnode,
375                                                            {NewValueNode(prim::kPrimFastGeLU->Clone()), add_b});
376   auto matmul2_b = NewCNodeAndCloneAttrsSetSliceAbstract(
377     func_graph, matmul2_cnode, {NewValueNode(prim::kPrimMatMul->Clone()), fast_gelu_b, matmul2_cnode->input(kIndex2)});
378   auto reduce_scatter_b = NewCNodeAndCloneAttrsSetSliceAbstract(
379     func_graph, reduce_scatter_cnode, {NewValueNode(prim::kPrimReduceScatter->Clone()), matmul2_b});
380 
381   // Insert depend node
382   InsertDependAndSetAbstract(func_graph, manager, allgather_a, allgather_b);
383   InsertDependAndSetAbstract(func_graph, manager, reduce_scatter_a, reduce_scatter_b);
384   InsertDependAndSetAbstract(func_graph, manager, allgather_b, reduce_scatter_a);
385 
386   // New concat(MakeTuple(reduce_scatter_a, reduce_scatter_b))
387   auto concat_cnode = NewConcatCNodeAndSetAbstract(func_graph, reduce_scatter_a, reduce_scatter_b, kLongZero, kLongTwo);
388 
389   // Replace graph
390   auto prev_cnode = layernorm_cnode->input(kIndex1);
391   manager->SetEdge(split_cnode, kIndex1, prev_cnode);
392   auto next_cnode_users = manager->node_users()[reduce_scatter_cnode];
393   for (const auto &next_cnode_pair : next_cnode_users) {
394     manager->SetEdge(next_cnode_pair.first, next_cnode_pair.second, concat_cnode);
395   }
396 
397   // Expand slice range by white list
398   ExpandSliceRangeToLeft(func_graph, manager, split_cnode);
399   ExpandSliceRangeToRight(func_graph, manager, concat_cnode);
400 }
401 
SplitLayerNormCommFp(const FuncGraphPtr & func_graph)402 void SplitLayerNormCommFp(const FuncGraphPtr &func_graph) {
403   if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel &&
404       parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kAutoParallel) {
405     return;
406   }
407 
408   auto ms_context = MsContext::GetInstance();
409   MS_EXCEPTION_IF_NULL(ms_context);
410   auto is_enable = ms_context->get_param<bool>(MS_CTX_INTERLEAVED_LAYERNORM_COMM);
411   if (!is_enable) {
412     return;
413   }
414 
415   MS_EXCEPTION_IF_NULL(func_graph);
416   auto manager = func_graph->manager();
417   MS_EXCEPTION_IF_NULL(manager);
418   auto todo = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, PatternFilter);
419   for (const auto &node : todo) {
420     SplitIntoInterleaved(func_graph, manager, node);
421   }
422 }
423 }  // namespace parallel
424 }  // namespace mindspore
425