1 /**
2 * Copyright 2021-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.h"
18 #include <vector>
19 #include <string>
20 #include "ops/other_ops.h"
21 #include "ops/array_ops.h"
22 #include "utils/trace_base.h"
23 #include "include/common/utils/anfalgo.h"
24 #include "include/common/utils/comm_manager.h"
25 #include "include/backend/optimizer/helper.h"
26 #include "frontend/parallel/ops_info/ops_utils.h"
27 #include "include/backend/anf_runtime_algorithm.h"
28
29 namespace mindspore {
30 namespace opt {
31 namespace {
// Positions inside a CNode's input list: slot 0 holds the primitive value node, data inputs follow it.
constexpr size_t kCNodePrimitiveIdx = 0;
constexpr size_t kAllToAllInputIdx = kCNodePrimitiveIdx + 1;
// Attribute names read or written by the passes in this file.
constexpr auto kAttrFlashIndex = "FLASH_INDEX";
constexpr auto kAttrIrUnified = "ir_unified";
36
ChangePrimitiveToAllToAllV(const AnfNodePtr & node)37 void ChangePrimitiveToAllToAllV(const AnfNodePtr &node) {
38 MS_EXCEPTION_IF_NULL(node);
39 auto neighbor_exchange = node->cast<CNodePtr>();
40 MS_EXCEPTION_IF_NULL(neighbor_exchange);
41
42 if (neighbor_exchange->size() == kCNodePrimitiveIdx) {
43 MS_LOG(INTERNAL_EXCEPTION) << "Inputs should not be empty for cnode " << node->DebugString()
44 << trace::DumpSourceLines(neighbor_exchange);
45 }
46
47 auto prim = GetValueNode<PrimitivePtr>(neighbor_exchange->input(kCNodePrimitiveIdx));
48 MS_EXCEPTION_IF_NULL(prim);
49 prim->Named::operator=(Named(kAllToAllvOpName));
50 }
51
GetRankSize(const std::string & group)52 uint32_t GetRankSize(const std::string &group) {
53 uint32_t rank_size;
54 if (!CommManager::GetInstance().GetRankSize(group, &rank_size)) {
55 MS_LOG(EXCEPTION) << "Get hccl rank size for group " << group << " failed.";
56 }
57 return rank_size;
58 }
59 } // namespace
60
// Creates a Split CNode that cuts output 0 of `input_node` into `split_count` pieces along `split_dim`, and records
// the inferred type/shape of every piece on it. The new node inherits the scope of `all_to_all` and is tagged with the
// "is_backend_insert" attribute.
// @param graph       Kernel graph that owns the new node.
// @param all_to_all  Original AlltoAll node; provides the scope and the source location used in error messages.
// @param input_node  Node whose output 0 is split.
// @param split_count Number of pieces; must be > 0 and, when shape[split_dim] is non-negative, divide it exactly.
// @param split_dim   Axis to split; negative values count from the end, valid range is [-rank, rank).
// @return The new Split CNode.
CNodePtr AllToAllUnifyMindIR::CreateSplitNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                              const AnfNodePtr &input_node, int64_t split_count,
                                              int64_t split_dim) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(all_to_all);
  MS_EXCEPTION_IF_NULL(input_node);
  // split_count sizes the per-output type/shape vectors below: zero would produce a Split with no outputs and a
  // negative value would be converted to an enormous size_t. Reject both here, including for negative (unknown)
  // extents where the divisibility check further down is skipped.
  if (split_count <= 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "Invalid split count " << split_count << ", it should be greater than 0."
                               << trace::DumpSourceLines(all_to_all);
  }

  std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
                                         input_node, graph->NewValueNode(MakeValue(split_dim)),
                                         graph->NewValueNode(MakeValue(split_count))};
  auto split = NewCNode(split_input, graph);
  MS_EXCEPTION_IF_NULL(split);
  split->set_scope(all_to_all->scope());
  auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
  auto shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
  auto shape_size = SizeToLong(shape.size());
  if (split_dim >= shape_size || split_dim < -shape_size) {
    MS_LOG(INTERNAL_EXCEPTION) << "Invalid split dim " << split_dim << ", it should be in range [" << -shape_size
                               << ", " << shape_size << ") for shape size " << shape.size()
                               << trace::DumpSourceLines(all_to_all);
  }
  size_t split_idx = split_dim < 0 ? LongToSize(split_dim + shape_size) : LongToSize(split_dim);
  // Negative extents (presumably MindSpore's dynamic-dimension markers) are neither validated nor divided.
  if (shape[split_idx] >= 0) {
    if (shape[split_idx] % split_count != 0) {
      MS_LOG(INTERNAL_EXCEPTION) << "Invalid split count " << split_count << " cannot be divisible by shape["
                                 << split_idx << "] = " << shape[split_idx] << trace::DumpSourceLines(all_to_all);
    }
    shape[split_idx] /= split_count;
  }
  const size_t output_num = LongToSize(split_count);
  std::vector<TypeId> dtypes(output_num, dtype);
  std::vector<ShapeVector> shapes(output_num, shape);
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());

  common::AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split);
  return split;
}
93
// Lowers a NeighborExchange CNode into a newly created AllToAllv CNode.
// The tuple input of `neighbor_exchange` is unpacked into one element per send rank; those elements become the
// AllToAllv inputs. The AllToAllv gets one output per recv rank, each given the type and detailed shape of the first
// unpacked element. The group, group-rank-id, send-rank-id and recv-rank-id attributes are carried over, as are the
// COMM_REUSE (only when true) and FLASH_INDEX primitive attributes. The new node inherits the original's scope.
// Raises an internal exception when the node has no data input or the tuple input unpacks to zero elements.
// NOTE(review): every output reuses the first input's shape, which assumes all exchanged tensors share one shape —
// confirm against the NeighborExchange contract.
CNodePtr NeighborExchangeUnifyMindIR::CreateAllToAllvNode(const FuncGraphPtr &graph,
                                                          const CNodePtr &neighbor_exchange) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(neighbor_exchange);
  std::string group = common::AnfAlgo::GetNodeAttr<std::string>(neighbor_exchange, kAttrGroup);
  std::vector<uint32_t> group_rank_ids =
    common::AnfAlgo::HasNodeAttr(kAttrGroupRankIds, neighbor_exchange)
      ? common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(neighbor_exchange, kAttrGroupRankIds)
      : std::vector<uint32_t>();
  std::vector<int64_t> send_rank_ids =
    common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange, kAttrSendRankIds);
  std::vector<int64_t> recv_rank_ids =
    common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange, kAttrRecvRankIds);

  const size_t send_count = send_rank_ids.size();
  const size_t recv_count = recv_rank_ids.size();
  // input(0) is the primitive; input(1) is the tuple that is unpacked into send_count elements.
  constexpr size_t kNeighborExchangeInputIdx = 1;
  if (neighbor_exchange->size() <= kNeighborExchangeInputIdx) {
    MS_LOG(INTERNAL_EXCEPTION) << "Inputs should not be empty for cnode " << neighbor_exchange->DebugString()
                               << trace::DumpSourceLines(neighbor_exchange);
  }
  auto tuple_input = neighbor_exchange->input(kNeighborExchangeInputIdx);
  std::vector<AnfNodePtr> split_outputs;
  CreateMultipleOutputsOfAnfNode(graph, tuple_input, send_count, &split_outputs);
  if (split_outputs.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The node " << tuple_input->DebugString()
                               << " should have at least one output, but got 0." << trace::DumpSourceLines(tuple_input);
  }
  std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllvOpName))};
  (void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end());
  auto all_to_all_v = NewCNode(all_to_all_v_input, graph);
  MS_EXCEPTION_IF_NULL(all_to_all_v);
  // Keep the replacement in the same scope as the node it replaces, as the AlltoAll lowering does.
  all_to_all_v->set_scope(neighbor_exchange->scope());

  auto single_shape = AnfAlgo::GetOutputDetailShape(split_outputs[0], 0UL);
  auto single_type = common::AnfAlgo::GetOutputInferDataType(split_outputs[0], 0UL);
  std::vector<TypeId> dtypes(recv_count, single_type);
  std::vector<BaseShapePtr> shapes(recv_count, single_shape);
  common::AnfAlgo::SetSingleOutputTypeAndDetailShape(dtypes, shapes, all_to_all_v.get());

  common::AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(send_rank_ids), all_to_all_v);
  common::AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(recv_rank_ids), all_to_all_v);
  common::AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>(group), all_to_all_v);
  common::AnfAlgo::SetNodeAttr(kAttrGroupRankIds, MakeValue<std::vector<uint32_t>>(group_rank_ids), all_to_all_v);

  auto neighbor_exchange_prim = GetCNodePrimitive(neighbor_exchange);
  MS_EXCEPTION_IF_NULL(neighbor_exchange_prim);
  if (neighbor_exchange_prim->HasAttr(parallel::COMM_REUSE) &&
      GetValue<bool>(neighbor_exchange_prim->GetAttr(parallel::COMM_REUSE))) {
    auto all_to_all_v_prim = GetCNodePrimitive(all_to_all_v);
    MS_EXCEPTION_IF_NULL(all_to_all_v_prim);
    (void)all_to_all_v_prim->AddAttr(parallel::COMM_REUSE, MakeValue(true));
  }

  if (neighbor_exchange_prim->HasAttr(kAttrFlashIndex)) {
    auto flash_index = GetValue<std::string>(neighbor_exchange_prim->GetAttr(kAttrFlashIndex));
    auto all_to_all_v_prim = GetCNodePrimitive(all_to_all_v);
    MS_EXCEPTION_IF_NULL(all_to_all_v_prim);
    (void)all_to_all_v_prim->AddAttr(kAttrFlashIndex, MakeValue<std::string>(flash_index));
  }
  return all_to_all_v;
}
147
// Splits the first data input of `all_to_all` along the node's split_dim attribute into split_count pieces.
// Raises (MS_LOG(EXCEPTION)) when `all_to_all` has no data input.
CNodePtr AllToAllUnifyMindIR::CreateSplitNodeWithSplitDim(const KernelGraphPtr &graph,
                                                          const CNodePtr &all_to_all) const {
  MS_EXCEPTION_IF_NULL(all_to_all);
  const auto split_count = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  const auto split_dim = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitDim);

  const bool has_data_input = all_to_all->size() > kAllToAllInputIdx;
  if (!has_data_input) {
    MS_LOG(EXCEPTION) << "Inputs should not be empty for cnode " << all_to_all->DebugString()
                      << trace::DumpSourceLines(all_to_all);
  }
  const auto data_input = all_to_all->input(kAllToAllInputIdx);
  return CreateSplitNode(graph, all_to_all, data_input, split_count, split_dim);
}
161
// Splits output 0 of `input_node` along axis 0 into the split_count pieces recorded on `all_to_all`.
CNodePtr AllToAllUnifyMindIR::CreateSplitNodeWithDim0(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                                      const CNodePtr &input_node) const {
  MS_EXCEPTION_IF_NULL(all_to_all);
  constexpr int64_t kFirstAxis = 0;
  const auto pieces = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  return CreateSplitNode(graph, all_to_all, input_node, pieces, kFirstAxis);
}
168
// Builds the AllToAllv node used on the non-KBK path of AllToAllUnifyMindIR::Process.
// The split_count outputs of `split` become the AllToAllv inputs. The node gets split_count outputs, each with the
// type and detailed shape of the first split output. Send and recv rank ids are both set to 0..rank_size-1, where
// rank_size is the size of the node's communication group. Group and group-rank-id attributes are copied from
// `all_to_all`, and COMM_REUSE is propagated when it is true.
// Raises an internal exception when `split` unpacks to zero outputs.
CNodePtr AllToAllUnifyMindIR::CreateAllToAllvNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                                  const CNodePtr &split) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(all_to_all);
  MS_EXCEPTION_IF_NULL(split);
  const auto split_count = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  const auto group = common::AnfAlgo::GetNodeAttr<std::string>(all_to_all, kAttrGroup);
  std::vector<uint32_t> group_rank_ids;
  if (common::AnfAlgo::HasNodeAttr(kAttrGroupRankIds, all_to_all)) {
    group_rank_ids = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(all_to_all, kAttrGroupRankIds);
  }

  std::vector<AnfNodePtr> pieces;
  CreateMultipleOutputsOfAnfNode(graph, split, static_cast<size_t>(split_count), &pieces);
  if (pieces.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The node " << split->DebugString() << " should have at least one output, but got 0."
                               << trace::DumpSourceLines(split);
  }

  std::vector<AnfNodePtr> new_ata_input;
  new_ata_input.reserve(pieces.size() + 1);
  new_ata_input.push_back(NewValueNode(std::make_shared<Primitive>(kAllToAllvOpName)));
  for (const auto &piece : pieces) {
    new_ata_input.push_back(piece);
  }
  auto new_ata = NewCNode(new_ata_input, graph);
  MS_EXCEPTION_IF_NULL(new_ata);
  new_ata->set_scope(all_to_all->scope());

  const auto &first_piece = pieces.front();
  std::vector<TypeId> dtypes(split_count, common::AnfAlgo::GetOutputInferDataType(first_piece, 0UL));
  std::vector<BaseShapePtr> shapes(split_count, AnfAlgo::GetOutputDetailShape(first_piece, 0UL));
  common::AnfAlgo::SetOutputTypeAndDetailShape(dtypes, shapes, new_ata.get());

  const uint32_t rank_size = GetRankSize(group);
  std::vector<int64_t> rank_ids;
  rank_ids.reserve(rank_size);
  for (uint32_t rank = 0; rank < rank_size; ++rank) {
    rank_ids.push_back(static_cast<int64_t>(rank));
  }

  common::AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(rank_ids), new_ata);
  common::AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(rank_ids), new_ata);
  common::AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>(group), new_ata);
  common::AnfAlgo::SetNodeAttr(kAttrGroupRankIds, MakeValue<std::vector<uint32_t>>(group_rank_ids), new_ata);

  auto all_to_all_prim = GetCNodePrimitive(all_to_all);
  MS_EXCEPTION_IF_NULL(all_to_all_prim);
  const bool comm_reuse =
    all_to_all_prim->HasAttr(parallel::COMM_REUSE) && GetValue<bool>(all_to_all_prim->GetAttr(parallel::COMM_REUSE));
  if (comm_reuse) {
    auto new_ata_prim = GetCNodePrimitive(new_ata);
    MS_EXCEPTION_IF_NULL(new_ata_prim);
    (void)new_ata_prim->AddAttr(parallel::COMM_REUSE, MakeValue(true));
  }
  MS_LOG(INFO) << "Create AllToAllv success, split count " << split_count << ", rank size " << rank_size;
  return new_ata;
}
217
// Builds the AlltoAll node used on the KBK path of AllToAllUnifyMindIR::Process.
// Its only data input is `concat` and it reuses concat's abstract. Only the group attribute is copied from
// `all_to_all`, and COMM_REUSE is propagated when true. The node is tagged with kAttrIrUnified so that Process skips
// it on a later match. GetRankSize is called here only for the log line, but it still raises when the group's rank
// size cannot be queried.
CNodePtr AllToAllUnifyMindIR::CreateAllToAllNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                                 const CNodePtr &concat) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(all_to_all);
  MS_EXCEPTION_IF_NULL(concat);
  const auto split_count = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  const auto group = common::AnfAlgo::GetNodeAttr<std::string>(all_to_all, kAttrGroup);

  auto ata_prim_node = NewValueNode(std::make_shared<Primitive>(kAllToAllOpName));
  std::vector<AnfNodePtr> new_ata_input{ata_prim_node, concat};
  auto new_ata = NewCNode(new_ata_input, graph);
  MS_EXCEPTION_IF_NULL(new_ata);
  new_ata->set_scope(all_to_all->scope());
  new_ata->set_abstract(concat->abstract());
  common::AnfAlgo::CopyNodeAttr(kAttrGroup, all_to_all, new_ata);

  auto all_to_all_prim = GetCNodePrimitive(all_to_all);
  MS_EXCEPTION_IF_NULL(all_to_all_prim);
  const bool comm_reuse =
    all_to_all_prim->HasAttr(parallel::COMM_REUSE) && GetValue<bool>(all_to_all_prim->GetAttr(parallel::COMM_REUSE));
  if (comm_reuse) {
    auto new_ata_prim = GetCNodePrimitive(new_ata);
    MS_EXCEPTION_IF_NULL(new_ata_prim);
    (void)new_ata_prim->AddAttr(parallel::COMM_REUSE, MakeValue(true));
  }
  common::AnfAlgo::SetNodeAttr(kAttrIrUnified, MakeValue(true), new_ata);

  const uint32_t rank_size = GetRankSize(group);
  MS_LOG(INFO) << "Create AlltoAll success, split count " << split_count << ", rank size " << rank_size;
  return new_ata;
}
245
// Creates a Concat CNode over the tuple output of `input_node` along `concat_dim`.
// TupleGetItem nodes for the split_count tuple elements are created (CreateMultipleOutputsOfAnfNode). In this function
// they are used only to read the element type and shape; the Concat itself takes `input_node` directly. The output
// shape is element 0's shape with the concat axis multiplied by split_count, and a negative extent is kept unchanged.
// @param graph       Kernel graph that owns the new node.
// @param all_to_all  Original AlltoAll node; provides the scope and the source location used in error messages.
// @param input_node  Tuple-producing node to concatenate.
// @param split_count Number of tuple elements; must be > 0.
// @param concat_dim  Axis to concatenate on; negative values count from the end, valid range is [-rank, rank).
// @return The new Concat CNode.
CNodePtr AllToAllUnifyMindIR::CreateConcatNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                               const CNodePtr &input_node, int64_t split_count,
                                               int64_t concat_dim) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(all_to_all);
  MS_EXCEPTION_IF_NULL(input_node);
  // A non-positive split_count would otherwise be cast to size_t below (a negative value becomes an enormous element
  // count) or multiply the concat axis into a meaningless extent.
  if (split_count <= 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "Invalid split count " << split_count << ", it should be greater than 0."
                               << trace::DumpSourceLines(all_to_all);
  }
  std::vector<AnfNodePtr> input_node_outputs;
  CreateMultipleOutputsOfAnfNode(graph, input_node, LongToSize(split_count), &input_node_outputs);
  if (input_node_outputs.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The node " << input_node->DebugString()
                               << " should have at least one output, but got 0." << trace::DumpSourceLines(input_node);
  }
  std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName)), input_node,
                                          graph->NewValueNode(MakeValue(concat_dim))};
  auto concat = NewCNode(concat_input, graph);
  MS_EXCEPTION_IF_NULL(concat);
  concat->set_scope(all_to_all->scope());
  auto single_shape = common::AnfAlgo::GetOutputInferShape(input_node_outputs[0], 0);
  auto shape_size = SizeToLong(single_shape.size());
  if (concat_dim >= shape_size || concat_dim < -shape_size) {
    MS_LOG(INTERNAL_EXCEPTION) << "Invalid concat dim " << concat_dim << ", it should be in range [" << -shape_size
                               << ", " << shape_size << ") for shape size " << single_shape.size()
                               << trace::DumpSourceLines(all_to_all);
  }
  size_t concat_idx = concat_dim < 0 ? LongToSize(concat_dim + shape_size) : LongToSize(concat_dim);
  // Negative extents (presumably MindSpore's dynamic-dimension markers) are propagated unchanged.
  if (single_shape[concat_idx] >= 0) {
    single_shape[concat_idx] *= split_count;
  }
  common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(input_node_outputs[0], 0UL)},
                                              {single_shape}, concat.get());
  return concat;
}
276
// Concatenates the split_count tuple elements of `input_node` along the concat_dim attribute of `all_to_all`.
CNodePtr AllToAllUnifyMindIR::CreateConcatNodeWithConcatDim(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                                            const CNodePtr &input_node) const {
  MS_EXCEPTION_IF_NULL(all_to_all);
  const auto pieces = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  const auto axis = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrConcatDim);
  return CreateConcatNode(graph, all_to_all, input_node, pieces, axis);
}
284
// Concatenates the split_count tuple elements of `input_node` along axis 0.
CNodePtr AllToAllUnifyMindIR::CreateConcatNodeWithDim0(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
                                                       const CNodePtr &input_node) const {
  MS_EXCEPTION_IF_NULL(all_to_all);
  constexpr int64_t kFirstAxis = 0;
  const auto pieces = common::AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  return CreateConcatNode(graph, all_to_all, input_node, pieces, kFirstAxis);
}
291
MustExistPrimitiveName() const292 std::vector<std::string> NeighborExchangeUnifyMindIR::MustExistPrimitiveName() const {
293 std::vector<std::string> ret;
294 ret.emplace_back(prim::kPrimNeighborExchange->name());
295 return ret;
296 }
297
DefinePattern() const298 const BaseRef NeighborExchangeUnifyMindIR::DefinePattern() const {
299 return VectorRef({prim::kPrimNeighborExchange, std::make_shared<SeqVar>()});
300 }
301
// Rewrites a matched NeighborExchange node.
// If its primitive carries FLASH_INDEX, a new AllToAllv CNode is built (CreateAllToAllvNode) and returned.
// Otherwise the node's primitive is renamed to AllToAllv in place and the same node is returned.
const AnfNodePtr NeighborExchangeUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                      const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto neighbor_exchange = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(neighbor_exchange);
  auto neighbor_exchange_prim = GetCNodePrimitive(neighbor_exchange);
  MS_EXCEPTION_IF_NULL(neighbor_exchange_prim);
  if (neighbor_exchange_prim->HasAttr(kAttrFlashIndex)) {
    return CreateAllToAllvNode(graph, neighbor_exchange);
  }
  ChangePrimitiveToAllToAllV(node);
  return node;
}
317
MustExistPrimitiveName() const318 std::vector<std::string> AllToAllUnifyMindIR::MustExistPrimitiveName() const {
319 std::vector<std::string> ret;
320 ret.emplace_back(prim::kPrimAlltoAll->name());
321 return ret;
322 }
323
DefinePattern() const324 const BaseRef AllToAllUnifyMindIR::DefinePattern() const {
325 return VectorRef({prim::kPrimAlltoAll, std::make_shared<SeqVar>()});
326 }
327
// Lowers a matched AlltoAll node into a subgraph and returns the subgraph's final node.
// - Kernel-by-kernel execution, or task sink disabled:
//     Split(split_dim) -> Concat(dim 0) -> AlltoAll -> Split(dim 0) -> Concat(concat_dim)
// - Otherwise:
//     Split(split_dim) -> AllToAllv -> Concat(concat_dim)
// An AlltoAll carrying kAttrIrUnified (set by CreateAllToAllNode on the node it emits) gets nullptr here, presumably
// leaving it unchanged, so the pass does not rewrite its own output again.
const AnfNodePtr AllToAllUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto all_to_all = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(all_to_all);
  if (GetBoolAttr(all_to_all, kAttrIrUnified)) {
    return nullptr;
  }
  auto kernel_graph = graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const bool is_kbk = ms_context->IsKByKExecutorMode() || !ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);

  // Both lowerings begin by splitting the input along split_dim.
  auto split = CreateSplitNodeWithSplitDim(kernel_graph, all_to_all);
  if (is_kbk) {
    auto concat_dim0 = CreateConcatNodeWithDim0(kernel_graph, all_to_all, split);
    auto new_ata = CreateAllToAllNode(kernel_graph, all_to_all, concat_dim0);
    auto split_dim0 = CreateSplitNodeWithDim0(kernel_graph, all_to_all, new_ata);
    return CreateConcatNodeWithConcatDim(kernel_graph, all_to_all, split_dim0);
  }
  auto new_ata = CreateAllToAllvNode(kernel_graph, all_to_all, split);
  return CreateConcatNodeWithConcatDim(kernel_graph, all_to_all, new_ata);
}
357 } // namespace opt
358 } // namespace mindspore
359