• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.h"
18 #include <vector>
19 #include <string>
20 #include "ops/other_ops.h"
21 #include "ops/array_ops.h"
22 #include "utils/trace_base.h"
23 #include "include/common/utils/anfalgo.h"
24 #include "include/common/utils/comm_manager.h"
25 #include "include/backend/optimizer/helper.h"
26 #include "frontend/parallel/ops_info/ops_utils.h"
27 #include "include/backend/anf_runtime_algorithm.h"
28 
29 namespace mindspore {
30 namespace opt {
31 namespace {
// Input index of the primitive value node inside a CNode; also compared against the cnode size to detect an
// input-less cnode.
32 constexpr size_t kCNodePrimitiveIdx = 0;
// Input index of the data input of an AllToAll cnode.
33 constexpr size_t kAllToAllInputIdx = 1;
// Attribute set to true on AllToAll nodes created by this file; Process() returns early when it is set.
34 constexpr auto kAttrIrUnified = "ir_unified";
// Primitive attribute that NeighborExchangeUnifyMindIR::Process() tests to pick its rewrite path.
35 constexpr auto kAttrFlashIndex = "FLASH_INDEX";
36 
ChangePrimitiveToAllToAllV(const AnfNodePtr & node)37 void ChangePrimitiveToAllToAllV(const AnfNodePtr &node) {
38   MS_EXCEPTION_IF_NULL(node);
39   auto neighbor_exchange = node->cast<CNodePtr>();
40   MS_EXCEPTION_IF_NULL(neighbor_exchange);
41 
42   if (neighbor_exchange->size() == kCNodePrimitiveIdx) {
43     MS_LOG(INTERNAL_EXCEPTION) << "Inputs should not be empty for cnode " << node->DebugString()
44                                << trace::DumpSourceLines(neighbor_exchange);
45   }
46 
47   auto prim = GetValueNode<PrimitivePtr>(neighbor_exchange->input(kCNodePrimitiveIdx));
48   MS_EXCEPTION_IF_NULL(prim);
49   prim->Named::operator=(Named(kAllToAllvOpName));
50 }
51 
GetRankSize(const std::string & group)52 uint32_t GetRankSize(const std::string &group) {
53   uint32_t rank_size;
54   if (!CommManager::GetInstance().GetRankSize(group, &rank_size)) {
55     MS_LOG(EXCEPTION) << "Get hccl rank size for group " << group << " failed.";
56   }
57   return rank_size;
58 }
59 }  // namespace
60 
// Builds a Split cnode that cuts output 0 of `input_node` into `split_count` pieces along `split_dim`.
//
// graph:       kernel graph that owns the newly created nodes.
// all_to_all:  node being rewritten; supplies the scope and the source-line trace used in error messages.
// input_node:  node whose output 0 is split.
// split_count: number of pieces; must be positive and, for a non-negative dimension, divide it evenly.
// split_dim:   axis to split, valid in [-rank, rank); negative values count from the back.
//
// Returns the Split cnode with inferred type/shape set for each of its `split_count` outputs and the
// "is_backend_insert" attribute set to true.
// Raises through MS_LOG(INTERNAL_EXCEPTION) on a non-positive split count, an out-of-range split dim, or a
// non-negative dimension that is not divisible by the split count.
CNodePtr AllToAllUnifyMindIR::CreateSplitNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                              const AnfNodePtr &input_node, int64_t split_count,
                                              int64_t split_dim) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(all_to_all);
  MS_EXCEPTION_IF_NULL(input_node);
  // Reject non-positive counts up front: a negative count would otherwise be converted to a huge unsigned size
  // when the per-output type/shape vectors are built below.
  if (split_count <= 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "Invalid split count " << split_count << ", it should be greater than 0."
                               << trace::DumpSourceLines(all_to_all);
  }

  std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
                                         input_node, graph->NewValueNode(MakeValue(split_dim)),
                                         graph->NewValueNode(MakeValue(split_count))};
  auto split = NewCNode(split_input, graph);
  MS_EXCEPTION_IF_NULL(split);
  split->set_scope(all_to_all->scope());

  auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
  auto shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
  auto shape_size = SizeToLong(shape.size());
  if (split_dim >= shape_size || split_dim < -shape_size) {
    MS_LOG(INTERNAL_EXCEPTION) << "Invalid split dim " << split_dim << " is over the shape size " << shape.size()
                               << trace::DumpSourceLines(all_to_all);
  }
  // Normalize a negative axis to its non-negative equivalent.
  size_t split_idx = split_dim < 0 ? LongToSize(split_dim + shape_size) : LongToSize(split_dim);

  // A negative entry is left as-is (presumably a dynamic dimension); only non-negative entries are checked
  // for divisibility and divided.
  const bool is_non_negative_dim = shape[split_idx] >= 0;
  if (is_non_negative_dim && shape[split_idx] % split_count != 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "Invalid split count " << split_count << " cannot be divisible by shape[" << split_idx
                               << "] = " << shape[split_idx] << trace::DumpSourceLines(all_to_all);
  }
  if (is_non_negative_dim) {
    shape[split_idx] /= split_count;
  }

  // Every output of the split shares the same element type and (reduced) shape.
  const size_t output_num = LongToSize(split_count);
  std::vector<TypeId> dtypes(output_num, dtype);
  std::vector<ShapeVector> shapes(output_num, shape);
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());

  common::AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split);
  return split;
}
93 
CreateAllToAllvNode(const FuncGraphPtr & graph,const CNodePtr & neighbor_exchange) const94 CNodePtr NeighborExchangeUnifyMindIR::CreateAllToAllvNode(const FuncGraphPtr &graph,
95                                                           const CNodePtr &neighbor_exchange) const {
96   std::string group = common::AnfAlgo::GetNodeAttr<std::string>(neighbor_exchange, kAttrGroup);
97   std::vector<uint32_t> group_rank_ids =
98     common::AnfAlgo::HasNodeAttr(kAttrGroupRankIds, neighbor_exchange)
99       ? common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(neighbor_exchange, kAttrGroupRankIds)
100       : std::vector<uint32_t>();
101   std::vector<int64_t> send_rank_ids =
102     common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange, kAttrSendRankIds);
103   std::vector<int64_t> recv_rank_ids =
104     common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange, kAttrRecvRankIds);
105 
106   int64_t send_count = send_rank_ids.size(), recv_count = recv_rank_ids.size();
107   auto tuple_input = neighbor_exchange->input(1);
108   std::vector<AnfNodePtr> split_outputs;
109   CreateMultipleOutputsOfAnfNode(graph, tuple_input, static_cast<size_t>(send_count), &split_outputs);
110   if (split_outputs.empty()) {
111     MS_LOG(INTERNAL_EXCEPTION) << "The node " << tuple_input->DebugString()
112                                << " should have at least one output, but got 0." << trace::DumpSourceLines(tuple_input);
113   }
114   std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllvOpName))};
115   (void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end());
116   auto all_to_all_v = NewCNode(all_to_all_v_input, graph);
117   MS_EXCEPTION_IF_NULL(all_to_all_v);
118 
119   auto single_shape = AnfAlgo::GetOutputDetailShape(split_outputs[0], 0UL);
120   auto single_type = common::AnfAlgo::GetOutputInferDataType(split_outputs[0], 0UL);
121   std::vector<TypeId> dtypes(recv_count, single_type);
122   std::vector<BaseShapePtr> shapes(recv_count, single_shape);
123   common::AnfAlgo::SetSingleOutputTypeAndDetailShape(dtypes, shapes, all_to_all_v.get());
124 
125   common::AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(send_rank_ids), all_to_all_v);
126   common::AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(recv_rank_ids), all_to_all_v);
127   common::AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>(group), all_to_all_v);
128   common::AnfAlgo::SetNodeAttr(kAttrGroupRankIds, MakeValue<std::vector<uint32_t>>(group_rank_ids), all_to_all_v);
129 
130   auto neighbor_exchange_prim = GetCNodePrimitive(neighbor_exchange);
131   MS_EXCEPTION_IF_NULL(neighbor_exchange_prim);
132   if (neighbor_exchange_prim->HasAttr(parallel::COMM_REUSE) &&
133       GetValue<bool>(neighbor_exchange_prim->GetAttr(parallel::COMM_REUSE))) {
134     auto all_to_all_v_prim = GetCNodePrimitive(all_to_all_v);
135     MS_EXCEPTION_IF_NULL(all_to_all_v_prim);
136     (void)all_to_all_v_prim->AddAttr(parallel::COMM_REUSE, MakeValue(true));
137   }
138 
139   if (neighbor_exchange_prim->HasAttr("FLASH_INDEX")) {
140     auto flash_index = GetValue<std::string>(neighbor_exchange_prim->GetAttr("FLASH_INDEX"));
141     auto all_to_all_v_prim = GetCNodePrimitive(all_to_all_v);
142     MS_EXCEPTION_IF_NULL(all_to_all_v_prim);
143     (void)all_to_all_v_prim->AddAttr("FLASH_INDEX", MakeValue<std::string>(flash_index));
144   }
145   return all_to_all_v;
146 }
147 
// Splits the data input of `all_to_all` using the node's own split-count and split-dim attributes.
// Returns the Split cnode produced by CreateSplitNode; raises when the cnode has no data input.
CNodePtr AllToAllUnifyMindIR::CreateSplitNodeWithSplitDim(const KernelGraphPtr &graph,
                                                          const CNodePtr &all_to_all) const {
  MS_EXCEPTION_IF_NULL(all_to_all);
  const int64_t count = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  const int64_t dim = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitDim);

  // The data input sits at kAllToAllInputIdx, so the cnode must be strictly larger than that index.
  const bool has_data_input = all_to_all->size() > kAllToAllInputIdx;
  if (!has_data_input) {
    MS_LOG(EXCEPTION) << "Inputs should not be empty for cnode " << all_to_all->DebugString()
                      << trace::DumpSourceLines(all_to_all);
  }
  auto data_input = all_to_all->input(kAllToAllInputIdx);
  return CreateSplitNode(graph, all_to_all, data_input, count, dim);
}
161 
// Splits `input_node` along axis 0 into as many pieces as the split-count attribute of `all_to_all`.
CNodePtr AllToAllUnifyMindIR::CreateSplitNodeWithDim0(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                                      const CNodePtr &input_node) const {
  MS_EXCEPTION_IF_NULL(all_to_all);
  constexpr int64_t kFirstAxis = 0;
  const int64_t pieces = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  return CreateSplitNode(graph, all_to_all, input_node, pieces, kFirstAxis);
}
168 
CreateAllToAllvNode(const KernelGraphPtr & graph,const CNodePtr & all_to_all,const CNodePtr & split) const169 CNodePtr AllToAllUnifyMindIR::CreateAllToAllvNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
170                                                   const CNodePtr &split) const {
171   MS_EXCEPTION_IF_NULL(graph);
172   MS_EXCEPTION_IF_NULL(all_to_all);
173   MS_EXCEPTION_IF_NULL(split);
174   int64_t split_count = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
175   std::string group = common::AnfAlgo::GetNodeAttr<std::string>(all_to_all, kAttrGroup);
176   std::vector<uint32_t> group_rank_ids =
177     common::AnfAlgo::HasNodeAttr(kAttrGroupRankIds, all_to_all)
178       ? common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(all_to_all, kAttrGroupRankIds)
179       : std::vector<uint32_t>();
180   std::vector<AnfNodePtr> split_outputs;
181   CreateMultipleOutputsOfAnfNode(graph, split, static_cast<size_t>(split_count), &split_outputs);
182   if (split_outputs.empty()) {
183     MS_LOG(INTERNAL_EXCEPTION) << "The node " << split->DebugString() << " should have at least one output, but got 0."
184                                << trace::DumpSourceLines(split);
185   }
186   std::vector<AnfNodePtr> new_ata_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllvOpName))};
187   (void)new_ata_input.insert(new_ata_input.end(), split_outputs.begin(), split_outputs.end());
188   auto new_ata = NewCNode(new_ata_input, graph);
189   MS_EXCEPTION_IF_NULL(new_ata);
190   new_ata->set_scope(all_to_all->scope());
191   auto single_shape = AnfAlgo::GetOutputDetailShape(split_outputs[0], 0UL);
192   auto single_type = common::AnfAlgo::GetOutputInferDataType(split_outputs[0], 0UL);
193   std::vector<TypeId> dtypes(split_count, single_type);
194   std::vector<BaseShapePtr> shapes(split_count, single_shape);
195   common::AnfAlgo::SetOutputTypeAndDetailShape(dtypes, shapes, new_ata.get());
196   uint32_t rank_size = GetRankSize(group);
197   std::vector<int64_t> rank_ids(rank_size, 0);
198   for (uint32_t i = 0; i < rank_size; ++i) {
199     rank_ids[i] = static_cast<int64_t>(i);
200   }
201 
202   common::AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(rank_ids), new_ata);
203   common::AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(rank_ids), new_ata);
204   common::AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>(group), new_ata);
205   common::AnfAlgo::SetNodeAttr(kAttrGroupRankIds, MakeValue<std::vector<uint32_t>>(group_rank_ids), new_ata);
206   auto all_to_all_prim = GetCNodePrimitive(all_to_all);
207   MS_EXCEPTION_IF_NULL(all_to_all_prim);
208   if (all_to_all_prim->HasAttr(parallel::COMM_REUSE) &&
209       GetValue<bool>(all_to_all_prim->GetAttr(parallel::COMM_REUSE))) {
210     auto new_ata_prim = GetCNodePrimitive(new_ata);
211     MS_EXCEPTION_IF_NULL(new_ata_prim);
212     (void)new_ata_prim->AddAttr(parallel::COMM_REUSE, MakeValue(true));
213   }
214   MS_LOG(INFO) << "Create AllToAllv success, split count " << split_count << ", rank size " << rank_size;
215   return new_ata;
216 }
217 
// Builds an AllToAll cnode that takes `concat` as its single data input.
//
// The new node shares the scope of `all_to_all` and the abstract of `concat`, copies the group attribute,
// carries COMM_REUSE over when it is true, and is tagged with kAttrIrUnified = true.
// The group's rank size is queried (and may raise) only to be reported in the log line.
CNodePtr AllToAllUnifyMindIR::CreateAllToAllNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                                 const CNodePtr &concat) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(all_to_all);
  MS_EXCEPTION_IF_NULL(concat);
  const int64_t split_count = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  const std::string group = common::AnfAlgo::GetNodeAttr<std::string>(all_to_all, kAttrGroup);

  // Inputs: the primitive followed by the concatenated tensor.
  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kAllToAllOpName)), concat};
  auto unified = NewCNode(inputs, graph);
  MS_EXCEPTION_IF_NULL(unified);
  unified->set_scope(all_to_all->scope());
  unified->set_abstract(concat->abstract());
  common::AnfAlgo::CopyNodeAttr(kAttrGroup, all_to_all, unified);

  auto origin_prim = GetCNodePrimitive(all_to_all);
  MS_EXCEPTION_IF_NULL(origin_prim);
  const bool comm_reuse =
    origin_prim->HasAttr(parallel::COMM_REUSE) && GetValue<bool>(origin_prim->GetAttr(parallel::COMM_REUSE));
  if (comm_reuse) {
    auto unified_prim = GetCNodePrimitive(unified);
    MS_EXCEPTION_IF_NULL(unified_prim);
    (void)unified_prim->AddAttr(parallel::COMM_REUSE, MakeValue(true));
  }
  // Process() returns early for nodes carrying this attribute.
  common::AnfAlgo::SetNodeAttr(kAttrIrUnified, MakeValue(true), unified);

  const uint32_t rank_size = GetRankSize(group);
  MS_LOG(INFO) << "Create AlltoAll success, split count " << split_count << ", rank size " << rank_size;
  return unified;
}
245 
// Builds a Concat cnode that joins the outputs of `input_node` along `concat_dim`.
//
// graph:       kernel graph that owns the newly created nodes.
// all_to_all:  node being rewritten; supplies the scope and the source-line trace used in error messages.
// input_node:  multi-output node passed to Concat as a whole; its first output provides type and base shape.
// split_count: number of outputs of `input_node`; a non-negative concat dimension is multiplied by it.
// concat_dim:  axis to concatenate, valid in [-rank, rank); negative values count from the back.
CNodePtr AllToAllUnifyMindIR::CreateConcatNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                               const CNodePtr &input_node, int64_t split_count,
                                               int64_t concat_dim) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(all_to_all);
  MS_EXCEPTION_IF_NULL(input_node);

  // The expanded outputs are only inspected for type/shape; Concat itself consumes `input_node`.
  std::vector<AnfNodePtr> expanded_outputs;
  CreateMultipleOutputsOfAnfNode(graph, input_node, static_cast<size_t>(split_count), &expanded_outputs);
  if (expanded_outputs.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The node " << input_node->DebugString()
                               << " should have at least one output, but got 0." << trace::DumpSourceLines(input_node);
  }

  std::vector<AnfNodePtr> concat_inputs;
  concat_inputs.push_back(NewValueNode(std::make_shared<Primitive>(kConcatOpName)));
  concat_inputs.push_back(input_node);
  concat_inputs.push_back(graph->NewValueNode(MakeValue(concat_dim)));
  auto concat = NewCNode(concat_inputs, graph);
  MS_EXCEPTION_IF_NULL(concat);
  concat->set_scope(all_to_all->scope());

  auto out_shape = common::AnfAlgo::GetOutputInferShape(expanded_outputs[0], 0);
  const auto rank = SizeToLong(out_shape.size());
  const bool dim_in_range = (concat_dim < rank) && (concat_dim >= -rank);
  if (!dim_in_range) {
    MS_LOG(INTERNAL_EXCEPTION) << "Invalid concat dim " << concat_dim << " is greater than shape size "
                               << out_shape.size() << trace::DumpSourceLines(all_to_all);
  }

  // Normalize a negative axis, then scale the dimension unless it is negative.
  size_t axis = 0;
  if (concat_dim < 0) {
    axis = LongToSize(concat_dim + rank);
  } else {
    axis = LongToSize(concat_dim);
  }
  if (out_shape[axis] >= 0) {
    out_shape[axis] = out_shape[axis] * split_count;
  }

  const auto out_type = common::AnfAlgo::GetOutputInferDataType(expanded_outputs[0], 0UL);
  common::AnfAlgo::SetOutputInferTypeAndShape({out_type}, {out_shape}, concat.get());
  return concat;
}
276 
// Concatenates the outputs of `input_node` using the split-count and concat-dim attributes of `all_to_all`.
CNodePtr AllToAllUnifyMindIR::CreateConcatNodeWithConcatDim(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                                            const CNodePtr &input_node) const {
  MS_EXCEPTION_IF_NULL(all_to_all);
  const int64_t pieces = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  const int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrConcatDim);
  return CreateConcatNode(graph, all_to_all, input_node, pieces, axis);
}
284 
// Concatenates the outputs of `input_node` along axis 0, using the split-count attribute of `all_to_all`.
CNodePtr AllToAllUnifyMindIR::CreateConcatNodeWithDim0(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                                       const CNodePtr &input_node) const {
  MS_EXCEPTION_IF_NULL(all_to_all);
  constexpr int64_t kFirstAxis = 0;
  const int64_t pieces = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  return CreateConcatNode(graph, all_to_all, input_node, pieces, kFirstAxis);
}
291 
MustExistPrimitiveName() const292 std::vector<std::string> NeighborExchangeUnifyMindIR::MustExistPrimitiveName() const {
293   std::vector<std::string> ret;
294   ret.emplace_back(prim::kPrimNeighborExchange->name());
295   return ret;
296 }
297 
DefinePattern() const298 const BaseRef NeighborExchangeUnifyMindIR::DefinePattern() const {
299   return VectorRef({prim::kPrimNeighborExchange, std::make_shared<SeqVar>()});
300 }
301 
// Rewrites a matched NeighborExchange node.
// With the kAttrFlashIndex primitive attribute present, a new AllToAllv cnode is built and returned;
// otherwise the node's primitive is renamed in place and the same node is returned.
const AnfNodePtr NeighborExchangeUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                      const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto primitive = GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);

  if (primitive->HasAttr(kAttrFlashIndex)) {
    return CreateAllToAllvNode(graph, cnode);
  }
  ChangePrimitiveToAllToAllV(node);
  return node;
}
317 
MustExistPrimitiveName() const318 std::vector<std::string> AllToAllUnifyMindIR::MustExistPrimitiveName() const {
319   std::vector<std::string> ret;
320   ret.emplace_back(prim::kPrimAlltoAll->name());
321   return ret;
322 }
323 
DefinePattern() const324 const BaseRef AllToAllUnifyMindIR::DefinePattern() const {
325   return VectorRef({prim::kPrimAlltoAll, std::make_shared<SeqVar>()});
326 }
327 
// Rewrites a matched AlltoAll node into an equivalent chain of nodes and returns the chain's last node.
//
// Returns nullptr (no rewrite) when the node already carries kAttrIrUnified, which CreateAllToAllNode sets on
// the AllToAll nodes it creates.
//
// Two expansions exist, selected from the global context:
//  - kernel-by-kernel executor mode, or task sink disabled:
//      Split(split_dim) -> Concat(dim 0) -> AllToAll -> Split(dim 0) -> Concat(concat_dim)
//  - otherwise:
//      Split(split_dim) -> AllToAllv -> Concat(concat_dim)
const AnfNodePtr AllToAllUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto all_to_all = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(all_to_all);
  if (GetBoolAttr(all_to_all, kAttrIrUnified)) {
    return nullptr;
  }
  auto kernel_graph = graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto ms_context = MsContext::GetInstance();
  // Guard the dereferences below, like every other pointer obtained in this function.
  MS_EXCEPTION_IF_NULL(ms_context);
  const bool is_kbk = ms_context->IsKByKExecutorMode() || !ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);

  // Both expansions start by splitting the input along the node's split dim.
  auto split = CreateSplitNodeWithSplitDim(kernel_graph, all_to_all);
  if (is_kbk) {
    auto concat_dim0 = CreateConcatNodeWithDim0(kernel_graph, all_to_all, split);
    auto new_ata = CreateAllToAllNode(kernel_graph, all_to_all, concat_dim0);
    auto split_dim0 = CreateSplitNodeWithDim0(kernel_graph, all_to_all, new_ata);
    return CreateConcatNodeWithConcatDim(kernel_graph, all_to_all, split_dim0);
  }
  auto new_ata = CreateAllToAllvNode(kernel_graph, all_to_all, split);
  return CreateConcatNodeWithConcatDim(kernel_graph, all_to_all, new_ata);
}
357 }  // namespace opt
358 }  // namespace mindspore
359