• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
18 
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
23 
24 #include "ir/value.h"
25 #include "frontend/parallel/auto_parallel/rec_core/rec_graph.h"
26 #include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h"
27 #include "frontend/parallel/ops_info/operator_info.h"
28 
29 namespace mindspore {
30 namespace parallel {
MakeTensor(int64_t n,int64_t c,int64_t h,int64_t w)31 const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w) {
32   TensorParam new_tensor;
33   new_tensor.tensor_type = kFloat32;
34   new_tensor.tensor_shape.shape_n = n;
35   new_tensor.tensor_shape.shape_c = c;
36   new_tensor.tensor_shape.shape_h = h;
37   new_tensor.tensor_shape.shape_w = w;
38   const TensorParam &tensor = new_tensor;
39   return tensor;
40 }
41 
MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> & ops,size_t iter_ops)42 Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops) {
43   Graph::NodeType NewOp;
44   NewOp.name = ops[iter_ops]->name();
45   NewOp.info = InfoType::kApplication;
46 
47   auto pos = ops[iter_ops]->name().find("Info");
48   auto name = ops[iter_ops]->name().substr(0, pos);
49   auto op_type = ops[iter_ops]->type();
50   auto idx = DictOpType.find(op_type);
51   if (idx != DictOpType.end()) {
52     NewOp.apply.op_type = DictOpType.at(op_type);
53   } else if (name == STAND_ALONE) {
54     MS_LOG(INFO) << ops[iter_ops]->type() << ": standalone operator.";
55     NewOp.apply.op_type = OperatorType::kRecStandAlone;
56   } else if (name == BATCH_PARALLEL) {
57     MS_LOG(INFO) << ops[iter_ops]->type() << ": batch parallel operator.";
58     NewOp.apply.op_type = OperatorType::kRecBatchParallel;
59   } else {
60     NewOp.apply.op_type = OperatorType::kRecUnknownType;
61     MS_LOG(INFO) << ops[iter_ops]->name() << ": Unknown operator type " << op_type;
62   }
63 
64   if (ops[iter_ops]->outputs_shape().size() == SIZE_ZERO) {
65     MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " outputs shape is empty.";
66   }
67 
68   if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_FOUR) {
69     NewOp.tensor_parm = MakeTensor(ops[iter_ops]->outputs_shape()[0][0], ops[iter_ops]->outputs_shape()[0][1],
70                                    ops[iter_ops]->outputs_shape()[INDEX_ZERO][INDEX_TWO],
71                                    ops[iter_ops]->outputs_shape()[INDEX_ZERO][INDEX_THREE]);
72   } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_THREE) {
73     NewOp.tensor_parm = MakeTensor(1, ops[iter_ops]->outputs_shape()[0][0], ops[iter_ops]->outputs_shape()[0][1],
74                                    ops[iter_ops]->outputs_shape()[INDEX_ZERO][INDEX_TWO]);
75   } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_TWO) {
76     NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_shape()[0][0], ops[iter_ops]->outputs_shape()[0][1]);
77   } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_ONE) {
78     NewOp.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->outputs_shape()[0][0]);
79   } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_ZERO) {
80     NewOp.tensor_parm = MakeTensor(1, 1, 1, 1);
81   } else {
82     MS_LOG(WARNING) << ops[iter_ops]->name() << ": output tensor shape is unexpected.";
83   }
84 
85   CompleteOperatorInputs(ops, iter_ops, &NewOp);
86   MS_LOG(INFO) << "Node " << NewOp.name << "created successfully"
87                << " its input is " << ops[iter_ops]->inputs_shape() << " and its output is "
88                << ops[iter_ops]->outputs_shape() << ".";
89   return NewOp;
90 }
91 
CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,Graph::NodeType * NewTensor)92 void CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
93                             Graph::NodeType *NewTensor) {
94   size_t input_tensor_size = ops[iter_ops]->inputs_shape().size();
95   if (ops[iter_ops]->type() == STACK) {
96     input_tensor_size = 1;
97   }
98   if (input_tensor_size > MAX_INPUT_NUM) {
99     MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " input tensor " << input_tensor_size << " num exceeds limit("
100                       << MAX_INPUT_NUM << ").";
101   }
102 
103   for (size_t iter_input_tensors = 0; iter_input_tensors < input_tensor_size; iter_input_tensors++) {
104     if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_FOUR) {
105       Complete4DInputs(ops, iter_ops, iter_input_tensors, NewTensor);
106     } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_THREE) {
107       NewTensor->apply.arguments[iter_input_tensors] =
108         MakeTensor(1, ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO],
109                    ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE],
110                    ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO]);
111     } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_TWO) {
112       Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor);
113     } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_ONE) {
114       NewTensor->apply.arguments[iter_input_tensors] =
115         MakeTensor(1, 1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO]);
116     } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == 0) {
117       NewTensor->apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1);
118     } else {
119       MS_LOG(WARNING) << ops[iter_ops]->name() << ": input tensor shape is unexpected.";
120     }
121   }
122 }
123 
Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,const size_t iter_input_tensors,Graph::NodeType * NewTensor)124 void Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
125                       const size_t iter_input_tensors, Graph::NodeType *NewTensor) {
126   if (NewTensor->apply.op_type == OperatorType::kRecMatMul) {
127     auto input_value = ops[iter_ops]->input_value();
128     bool transpose_a = input_value[2]->cast<BoolImmPtr>()->value();
129     bool transpose_b = input_value[3]->cast<BoolImmPtr>()->value();
130     if (transpose_a && (iter_input_tensors == 0)) {
131       NewTensor->apply.arguments[iter_input_tensors] =
132         MakeTensor(1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][1],
133                    ops[iter_ops]->inputs_shape()[iter_input_tensors][0]);
134     } else if (transpose_b && (iter_input_tensors == 1)) {
135       NewTensor->apply.arguments[iter_input_tensors] =
136         MakeTensor(1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][1],
137                    ops[iter_ops]->inputs_shape()[iter_input_tensors][0]);
138     } else {
139       NewTensor->apply.arguments[iter_input_tensors] =
140         MakeTensor(1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][0],
141                    ops[iter_ops]->inputs_shape()[iter_input_tensors][1]);
142     }
143   } else {
144     NewTensor->apply.arguments[iter_input_tensors] = MakeTensor(
145       1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][0], ops[iter_ops]->inputs_shape()[iter_input_tensors][1]);
146   }
147 }
148 
Complete4DInputs(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,const size_t iter_input_tensors,Graph::NodeType * NewTensor)149 void Complete4DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
150                       const size_t iter_input_tensors, Graph::NodeType *NewTensor) {
151   if (NewTensor->apply.op_type == OperatorType::kRecBatchMatMul) {
152     auto input_value = ops[iter_ops]->input_value();
153     bool transpose_a = input_value[2]->cast<BoolImmPtr>()->value();
154     bool transpose_b = input_value[3]->cast<BoolImmPtr>()->value();
155     if (transpose_a && (iter_input_tensors == 0)) {
156       NewTensor->apply.arguments[iter_input_tensors] =
157         MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO],
158                    ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE],
159                    ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE],
160                    ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO]);
161     } else if (transpose_b && (iter_input_tensors == 1)) {
162       NewTensor->apply.arguments[iter_input_tensors] =
163         MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO],
164                    ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE],
165                    ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE],
166                    ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO]);
167     } else {
168       NewTensor->apply.arguments[iter_input_tensors] =
169         MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO],
170                    ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE],
171                    ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO],
172                    ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE]);
173     }
174   } else {
175     NewTensor->apply.arguments[iter_input_tensors] =
176       MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO],
177                  ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE],
178                  ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO],
179                  ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE]);
180   }
181 }
182 
ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const std::vector<std::vector<std::string>> & input_tensor_names)183 std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
184                                   const std::vector<std::vector<std::string>> &input_tensor_names) {
185   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
186   constexpr size_t MAX_OP_NUM = SIZE_MAX / 2;
187   if (ops.size() > MAX_OP_NUM) {
188     MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << MAX_OP_NUM;
189   }
190 
191   for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) {
192     Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops);
193     NewOp.param_name = ops[iter_ops]->get_involved_param_name();
194     graph->nodes.push_back(NewOp);
195   }
196   MakeEdge(input_tensor_names, graph);
197 
198   return graph;
199 }
200 
MakeEdge(const std::vector<std::vector<std::string>> & input_tensor_names,const std::shared_ptr<Graph> & graph)201 void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph) {
202   for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) {
203     for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) {
204       size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]);
205       if (head_node_index < SIZE_MAX / 2 && head_node_index != iter_i) {
206         graph->nodes[iter_i].node_in.push_back(head_node_index);
207         graph->nodes[head_node_index].node_out.push_back(iter_i);
208       }
209     }
210   }
211 }
212 
GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> & input_tensor_name,const std::string & input_name)213 size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_name,
214                                   const std::string &input_name) {
215   for (size_t index = 0; index < input_tensor_name.size(); index++) {
216     if (input_tensor_name[index][0] == input_name) {
217       return index;
218     }
219   }
220   MS_LOG(INFO) << "Get index failed, using SIZE_MAX instead";
221   return SIZE_MAX;
222 }
223 
Eliminate_Aux(size_t node_index,const std::shared_ptr<Graph> & graph,const std::shared_ptr<std::vector<std::vector<size_t>>> & eli_list)224 void Eliminate_Aux(size_t node_index, const std::shared_ptr<Graph> &graph,
225                    const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list) {
226   MS_EXCEPTION_IF_NULL(graph);
227   MS_EXCEPTION_IF_NULL(eli_list);
228   std::vector<size_t> eli;
229   eli.push_back(node_index);
230   for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) {
231     auto outgoing_node_idx = graph->nodes[node_index].node_out[i];
232     eli.push_back(outgoing_node_idx);
233     if (!graph->nodes[node_index].param_name.empty() &&
234         graph->nodes[node_index].apply.op_type == OperatorType::kRecCast &&
235         (graph->nodes[outgoing_node_idx].apply.op_type == OperatorType::kRecMatMul ||
236          graph->nodes[outgoing_node_idx].apply.op_type == OperatorType::kRecBatchMatMul)) {
237       graph->nodes[outgoing_node_idx].param_name = graph->nodes[node_index].param_name;
238     }
239   }
240   eli_list->push_back(eli);
241 
242   // Iterate over all input operators of the current node
243   for (size_t i = 0; i < graph->nodes[node_index].node_in.size(); i++) {
244     auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out;
245     auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index);
246     if (it != incoming_outputs->end()) {
247       it = incoming_outputs->erase(it);
248       for (auto outgoing_index : graph->nodes[node_index].node_out) {
249         it = find(incoming_outputs->begin(), incoming_outputs->end(), outgoing_index);
250         if (it == incoming_outputs->end()) {
251           incoming_outputs->push_back(outgoing_index);
252         }
253       }
254     }
255   }
256 
257   // Iterate over all aux_input operators of the current node
258   for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) {
259     auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out;
260     auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index);
261     if (it != aux_incoming_outputs->end()) {
262       it = aux_incoming_outputs->erase(it);
263       for (auto outgoing_index : graph->nodes[node_index].node_out) {
264         it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), outgoing_index);
265         if (it == aux_incoming_outputs->end()) {
266           aux_incoming_outputs->push_back(outgoing_index);
267         }
268       }
269     }
270   }
271 
272   // Iterate over all output operators of the current node
273   Eliminate_Aux_Outgoing(node_index, graph);
274 }
275 
EliminateAuxOutgoingInput(size_t node_index,const std::shared_ptr<Graph> & graph,size_t i)276 void EliminateAuxOutgoingInput(size_t node_index, const std::shared_ptr<Graph> &graph, size_t i) {
277   MS_EXCEPTION_IF_NULL(graph);
278   auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in;
279   MS_EXCEPTION_IF_NULL(outgoing_inputs);
280   // Check if the current node is the input operator of the current node's output operator
281   auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index);
282   if (it != outgoing_inputs->end()) {
283     if (graph->nodes[node_index].node_in.size() > 0) {
284       // If the current node has input operator, then add input[0] of the current node to the input of the current
285       // node's output operator (if input[0] is also in the aux_input of the current node's output operator, then remove
286       // it from the aux_input and keep it only in the input)
287       auto exist_in_outgoing_auxinputs =
288         find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(),
289              graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end(), graph->nodes[node_index].node_in[0]);
290       if (exist_in_outgoing_auxinputs != graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end()) {
291         size_t index_remove_node = LongToSize(std::distance(
292           graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(), exist_in_outgoing_auxinputs));
293         if (graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.size() > index_remove_node) {
294           (void)graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.erase(
295             graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.begin() + index_remove_node);
296         } else {
297           MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_remove_node << ", out of range!";
298         }
299         if (graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.size() > index_remove_node) {
300           (void)graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.erase(exist_in_outgoing_auxinputs);
301         } else {
302           MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_remove_node
303                         << ", which is out of range!";
304         }
305       }
306       size_t idx = LongToSize(std::distance(outgoing_inputs->begin(), it));
307       if (outgoing_inputs->size() > idx) {
308         outgoing_inputs->at(idx) = graph->nodes[node_index].node_in[0];
309       } else {
310         MS_LOG(DEBUG) << "Trying to index vector element at index " << idx << ", out of range!";
311       }
312       // Then add the other input operators of the current node to the aux_input of the current node's output operator
313       for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) {
314         exist_in_outgoing_auxinputs = find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(),
315                                            graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end(),
316                                            graph->nodes[node_index].node_in[j]);
317         if (exist_in_outgoing_auxinputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end()) {
318           size_t index_aux = LongToSize(std::distance(outgoing_inputs->begin(), it));
319           graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.push_back(index_aux);
320           graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(graph->nodes[node_index].node_in[j]);
321         }
322       }
323       // Then add all the operators in the aux_input of the current node to the aux_input of the output operator of the
324       // current node
325       for (size_t j = 0; j < graph->nodes[node_index].node_in_aux.size(); j++) {
326         exist_in_outgoing_auxinputs = find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(),
327                                            graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end(),
328                                            graph->nodes[node_index].node_in_aux[j]);
329         if (exist_in_outgoing_auxinputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end()) {
330           size_t index_aux = LongToSize(std::distance(outgoing_inputs->begin(), it));
331           graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.push_back(index_aux);
332           graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(
333             graph->nodes[node_index].node_in_aux[j]);
334         }
335       }
336     } else {
337       auto idx = LongToSize(std::distance(outgoing_inputs->begin(), it));
338       if (outgoing_inputs->size() > idx) {
339         (void)outgoing_inputs->erase(it);
340       } else {
341         MS_LOG(DEBUG) << "Trying to erase vector element at index " << idx << ", out of range!";
342       }
343     }
344   }
345 }
346 
EliminateAuxOutgoingAuxInput(size_t node_index,const std::shared_ptr<Graph> & graph,size_t i)347 void EliminateAuxOutgoingAuxInput(size_t node_index, const std::shared_ptr<Graph> &graph, size_t i) {
348   MS_EXCEPTION_IF_NULL(graph);
349   auto *outgoing_auxinputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux;
350   MS_EXCEPTION_IF_NULL(outgoing_auxinputs);
351   auto *outgoing_auxinputs_index = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx;
352   // Check if the current node is the aux_input operator of the current node's output operator
353   auto it = find(outgoing_auxinputs->begin(), outgoing_auxinputs->end(), node_index);
354   size_t index_entree = LongToSize(std::distance(outgoing_auxinputs->begin(), it));
355   if (it != outgoing_auxinputs->end()) {
356     if (graph->nodes[node_index].node_in.size() > 0) {
357       // If the current node has input operator, and if the input[0] of the current node is in
358       // the input of the output operator of the current node, then delete it
359       // from the aux_input of the output of the current node, otherwise add the input[0]
360       // to the auxinput of the output of the current node
361       auto exist_in_outgoing_inputs =
362         find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in.begin(),
363              graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end(), graph->nodes[node_index].node_in[0]);
364       if (exist_in_outgoing_inputs != graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end()) {
365         index_entree = LongToSize(std::distance(outgoing_auxinputs->begin(), it));
366         if (outgoing_auxinputs_index->size() > index_entree) {
367           (void)outgoing_auxinputs_index->erase(outgoing_auxinputs_index->begin() + index_entree);
368         } else {
369           MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", out of range!";
370         }
371         if (outgoing_auxinputs->size() > index_entree) {
372           (void)outgoing_auxinputs->erase(it);
373         } else {
374           MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", out of range!";
375         }
376       } else {
377         size_t idx = LongToSize(std::distance(outgoing_auxinputs->begin(), it));
378         if (outgoing_auxinputs->size() > idx) {
379           outgoing_auxinputs->at(idx) = graph->nodes[node_index].node_in[0];
380         } else {
381           MS_LOG(DEBUG) << "Trying to index vector element at index " << idx << ", out of range!";
382         }
383         index_entree = LongToSize(std::distance(
384           outgoing_auxinputs->begin(),
385           find(outgoing_auxinputs->begin(), outgoing_auxinputs->end(), graph->nodes[node_index].node_in[0])));
386       }
387       // Determine whether the other input operator of the current node is in the input of the output operator,
388       // and if not, add it to the aux_input of the output operator
389       for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) {
390         exist_in_outgoing_inputs =
391           find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in.begin(),
392                graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end(), graph->nodes[node_index].node_in[j]);
393         if (exist_in_outgoing_inputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end()) {
394           outgoing_auxinputs->push_back(graph->nodes[node_index].node_in[j]);
395           if (outgoing_auxinputs_index->size() > index_entree) {
396             outgoing_auxinputs_index->push_back(outgoing_auxinputs_index->at(index_entree));
397           } else {
398             MS_LOG(DEBUG) << "Trying to index vector element at index " << index_entree << ", out of range!";
399           }
400         }
401       }
402       // Determine if the aux_input operator of the current node is in the input of the output operator,
403       // and if not, add it to the aux_input of the output operator
404       for (size_t j = 0; j < graph->nodes[node_index].node_in_aux.size(); j++) {
405         exist_in_outgoing_inputs = find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in.begin(),
406                                         graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end(),
407                                         graph->nodes[node_index].node_in_aux[j]);
408         if (exist_in_outgoing_inputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end()) {
409           outgoing_auxinputs->push_back(graph->nodes[node_index].node_in_aux[j]);
410           outgoing_auxinputs_index->push_back(outgoing_auxinputs_index->at(index_entree));
411         }
412       }
413     } else {
414       if (outgoing_auxinputs_index->size() > index_entree) {
415         (void)outgoing_auxinputs_index->erase(outgoing_auxinputs_index->begin() + index_entree);
416       } else {
417         MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", out of range!";
418       }
419       if (outgoing_auxinputs->size() > index_entree) {
420         (void)outgoing_auxinputs->erase(it);
421       } else {
422         MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", which is out of range.";
423       }
424     }
425   }
426 }
427 
Eliminate_Aux_Outgoing(size_t node_index,const std::shared_ptr<Graph> & graph)428 void Eliminate_Aux_Outgoing(size_t node_index, const std::shared_ptr<Graph> &graph) {
429   for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) {
430     // Handle the output operator connected to the current node via main edge
431     EliminateAuxOutgoingInput(node_index, graph, i);
432     // Handle the output operator connected to the current node via auxiliary edge
433     EliminateAuxOutgoingAuxInput(node_index, graph, i);
434   }
435 }
436 
// Rewrites an adjacency list in place after elimination. Entries whose node was eliminated
// (index_list maps them to SIZE_MAX) are removed; the others are renumbered through
// index_list. Walks from the back so that erasing never shifts positions still to be
// visited.
static void EraseEliminatedNode(std::vector<size_t> *nodes, const std::shared_ptr<std::vector<size_t>> &index_list) {
  using Offset = std::vector<size_t>::difference_type;
  size_t pos = nodes->size();
  while (pos-- > 0) {
    const size_t mapped = index_list->at(nodes->at(pos));
    if (mapped != SIZE_MAX) {
      (*nodes)[pos] = mapped;
      continue;
    }
    (void)nodes->erase(nodes->begin() + static_cast<Offset>(pos));
  }
}
447 
EliminateGraph(const std::shared_ptr<Graph> & graph,const std::shared_ptr<std::vector<std::vector<size_t>>> & eli_list,const std::shared_ptr<std::vector<size_t>> & index_list,const bool dyn_shape_tmp_fix)448 std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
449                                       const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
450                                       const std::shared_ptr<std::vector<size_t>> &index_list,
451                                       const bool dyn_shape_tmp_fix) {
452   MS_EXCEPTION_IF_NULL(graph);
453   for (size_t node_index = 0; node_index < graph->nodes.size(); node_index++) {
454     auto type = graph->nodes[node_index].apply.op_type;
455     if (dyn_shape_tmp_fix && type == OperatorType::kRecBatchMatMul) {
456       continue;
457     } else if (EliminateOpType.find(type) != EliminateOpType.end()) {
458       Eliminate_Aux(node_index, graph, eli_list);
459     }
460   }
461   index_list->reserve(graph->nodes.size());
462   for (size_t i = 0; i < graph->nodes.size(); i++) {
463     index_list->push_back(i);
464   }
465   for (size_t i = 0; i < eli_list->size(); i++) {
466     if (eli_list->at(i)[0] >= index_list->size()) {
467       MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
468     }
469     index_list->at(eli_list->at(i)[0]) = SIZE_MAX;
470     for (size_t j = eli_list->at(i)[0] + 1; j < index_list->size(); j++) {
471       index_list->at(j)--;
472     }
473   }
474   std::shared_ptr<Graph> new_graph = std::make_shared<Graph>();
475   for (size_t i = 0; i < graph->nodes.size(); i++) {
476     if (index_list->at(i) > SIZE_MAX / 2) {
477       continue;
478     }
479     new_graph->nodes.push_back(graph->nodes[i]);
480     auto *node_in = &new_graph->nodes[index_list->at(i)].node_in;
481     EraseEliminatedNode(node_in, index_list);
482     auto *node_in_aux = &new_graph->nodes[index_list->at(i)].node_in_aux;
483     EraseEliminatedNode(node_in_aux, index_list);
484     auto *node_out = &new_graph->nodes[index_list->at(i)].node_out;
485     EraseEliminatedNode(node_out, index_list);
486   }
487   return new_graph;
488 }
489 }  // namespace parallel
490 }  // namespace mindspore
491