/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h"
#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h"
#include "frontend/parallel/ops_info/operator_info.h"

namespace mindspore {
namespace parallel {
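// Builds a TensorParam describing a 4-D (NCHW) kFloat32 tensor with the given dimensions.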
const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w) {
  TensorParam new_tensor;
  new_tensor.tensor_type = kFloat32;
  new_tensor.tensor_shape.shape_n = n;
  new_tensor.tensor_shape.shape_c = c;
  new_tensor.tensor_shape.shape_h = h;
  new_tensor.tensor_shape.shape_w = w;
  return new_tensor;
}

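// Creates a rec-graph node for ops[iter_ops]: resolves its OperatorType (via DictOpType, with stand-alone,
// batch-parallel, and unknown fallbacks) and records its output shape, padded to 4-D with leading 1s.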
Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops) {
  Graph::NodeType NewOp;
  NewOp.name = ops[iter_ops]->name();
  NewOp.info = InfoType::kApplication;

  auto pos = ops[iter_ops]->name().find("Info");
  auto name = ops[iter_ops]->name().substr(0, pos);
  auto op_type = ops[iter_ops]->type();
  auto idx = DictOpType.find(op_type);
  if (idx != DictOpType.end()) {
    NewOp.apply.op_type = DictOpType.at(op_type);
  } else if (name == STAND_ALONE) {
    MS_LOG(INFO) << ops[iter_ops]->type() << ": standalone operator.";
    NewOp.apply.op_type = OperatorType::kRecStandAlone;
  } else if (name == BATCH_PARALLEL) {
    MS_LOG(INFO) << ops[iter_ops]->type() << ": batch parallel operator.";
    NewOp.apply.op_type = OperatorType::kRecBatchParallel;
  } else {
    NewOp.apply.op_type = OperatorType::kRecUnknownType;
    MS_LOG(INFO) << ops[iter_ops]->name() << ": Unknown operator type " << op_type;
  }

  if (ops[iter_ops]->outputs_shape().size() == SIZE_ZERO) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " outputs shape is empty.";
  }

  if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_FOUR) {
    NewOp.tensor_parm = MakeTensor(ops[iter_ops]->outputs_shape()[0][0], ops[iter_ops]->outputs_shape()[0][1],
                                   ops[iter_ops]->outputs_shape()[INDEX_ZERO][INDEX_TWO],
                                   ops[iter_ops]->outputs_shape()[INDEX_ZERO][INDEX_THREE]);
  } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_THREE) {
    NewOp.tensor_parm = MakeTensor(1, ops[iter_ops]->outputs_shape()[0][0], ops[iter_ops]->outputs_shape()[0][1],
                                   ops[iter_ops]->outputs_shape()[INDEX_ZERO][INDEX_TWO]);
  } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_TWO) {
    NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_shape()[0][0], ops[iter_ops]->outputs_shape()[0][1]);
  } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_ONE) {
    NewOp.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->outputs_shape()[0][0]);
  } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_ZERO) {
    NewOp.tensor_parm = MakeTensor(1, 1, 1, 1);
  } else {
    MS_LOG(WARNING) << ops[iter_ops]->name() << ": output tensor shape is unexpected.";
  }

  CompleteOperatorInputs(ops, iter_ops, &NewOp);
  MS_LOG(INFO) << "Node " << NewOp.name << " created successfully;"
               << " its input is " << ops[iter_ops]->inputs_shape() << " and its output is "
               << ops[iter_ops]->outputs_shape() << ".";
  return NewOp;
}

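// Fills NewTensor->apply.arguments with the operator's input tensor shapes, each padded to 4-D with leading 1s.
// STACK is treated as having a single input tensor.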
void CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                            Graph::NodeType *NewTensor) {
  size_t input_tensor_size = ops[iter_ops]->inputs_shape().size();
  if (ops[iter_ops]->type() == STACK) {
    input_tensor_size = 1;
  }
  if (input_tensor_size > MAX_INPUT_NUM) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": input tensor num " << input_tensor_size << " exceeds the limit ("
                      << MAX_INPUT_NUM << ").";
  }

  for (size_t iter_input_tensors = 0; iter_input_tensors < input_tensor_size; iter_input_tensors++) {
    if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_FOUR) {
      Complete4DInputs(ops, iter_ops, iter_input_tensors, NewTensor);
    } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_THREE) {
      NewTensor->apply.arguments[iter_input_tensors] =
        MakeTensor(1, ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO]);
    } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_TWO) {
      Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor);
    } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_ONE) {
      NewTensor->apply.arguments[iter_input_tensors] =
        MakeTensor(1, 1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO]);
    } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_ZERO) {
      NewTensor->apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1);
    } else {
      MS_LOG(WARNING) << ops[iter_ops]->name() << ": input tensor shape is unexpected.";
    }
  }
}

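// Completes a 2-D input tensor. For MatMul, the transposed operand (transpose_a for input 0, transpose_b for
// input 1) is stored with its two dimensions swapped.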
void Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                      const size_t iter_input_tensors, Graph::NodeType *NewTensor) {
  if (NewTensor->apply.op_type == OperatorType::kRecMatMul) {
    auto input_value = ops[iter_ops]->input_value();
    bool transpose_a = input_value[2]->cast<BoolImmPtr>()->value();
    bool transpose_b = input_value[3]->cast<BoolImmPtr>()->value();
    if (transpose_a && (iter_input_tensors == 0)) {
      NewTensor->apply.arguments[iter_input_tensors] =
        MakeTensor(1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][1],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][0]);
    } else if (transpose_b && (iter_input_tensors == 1)) {
      NewTensor->apply.arguments[iter_input_tensors] =
        MakeTensor(1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][1],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][0]);
    } else {
      NewTensor->apply.arguments[iter_input_tensors] =
        MakeTensor(1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][0],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][1]);
    }
  } else {
    NewTensor->apply.arguments[iter_input_tensors] = MakeTensor(
      1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][0], ops[iter_ops]->inputs_shape()[iter_input_tensors][1]);
  }
}

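// Completes a 4-D input tensor. For BatchMatMul, the transposed operand is stored with its last two dimensions
// swapped.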
void Complete4DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                      const size_t iter_input_tensors, Graph::NodeType *NewTensor) {
  if (NewTensor->apply.op_type == OperatorType::kRecBatchMatMul) {
    auto input_value = ops[iter_ops]->input_value();
    bool transpose_a = input_value[2]->cast<BoolImmPtr>()->value();
    bool transpose_b = input_value[3]->cast<BoolImmPtr>()->value();
    if (transpose_a && (iter_input_tensors == 0)) {
      NewTensor->apply.arguments[iter_input_tensors] =
        MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO]);
    } else if (transpose_b && (iter_input_tensors == 1)) {
      NewTensor->apply.arguments[iter_input_tensors] =
        MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO]);
    } else {
      NewTensor->apply.arguments[iter_input_tensors] =
        MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO],
                   ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE]);
    }
  } else {
    NewTensor->apply.arguments[iter_input_tensors] =
      MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO],
                 ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE],
                 ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO],
                 ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE]);
  }
}

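// Builds the recursive-programming cost graph: one node per operator, with edges derived from the operators'
// input tensor names.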
std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                  const std::vector<std::vector<std::string>> &input_tensor_names) {
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  constexpr size_t MAX_OP_NUM = SIZE_MAX / 2;
  if (ops.size() > MAX_OP_NUM) {
    MS_LOG(EXCEPTION) << "Total number of operators exceeds the limit " << MAX_OP_NUM << ".";
  }

  for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) {
    Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops);
    NewOp.param_name = ops[iter_ops]->get_involved_param_name();
    graph->nodes.push_back(NewOp);
  }
  MakeEdge(input_tensor_names, graph);

  return graph;
}

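// Connects the nodes. In each row of input_tensor_names, entry 0 is the node's own output name and the remaining
// entries name its producers; an edge is added from each producer found to the consuming node.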
void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph) {
  for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) {
    for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) {
      size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]);
      if (head_node_index < SIZE_MAX / 2 && head_node_index != iter_i) {
        graph->nodes[iter_i].node_in.push_back(head_node_index);
        graph->nodes[head_node_index].node_out.push_back(iter_i);
      }
    }
  }
}

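// Returns the index of the node whose output name equals input_name, or SIZE_MAX if there is none.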
size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_name,
                                  const std::string &input_name) {
  for (size_t index = 0; index < input_tensor_name.size(); index++) {
    if (input_tensor_name[index][0] == input_name) {
      return index;
    }
  }
  MS_LOG(INFO) << "Failed to find the index of " << input_name << "; returning SIZE_MAX instead.";
  return SIZE_MAX;
}

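// Eliminates the node at node_index: records it together with its consumers in eli_list, propagates a Cast's
// param_name to MatMul/BatchMatMul consumers, rewires its producers directly to its consumers, and then detaches
// it from its consumers via Eliminate_Aux_Outgoing.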
void Eliminate_Aux(size_t node_index, const std::shared_ptr<Graph> &graph,
                   const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(eli_list);
  std::vector<size_t> eli;
  eli.push_back(node_index);
  for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) {
    auto outgoing_node_idx = graph->nodes[node_index].node_out[i];
    eli.push_back(outgoing_node_idx);
    if (!graph->nodes[node_index].param_name.empty() &&
        graph->nodes[node_index].apply.op_type == OperatorType::kRecCast &&
        (graph->nodes[outgoing_node_idx].apply.op_type == OperatorType::kRecMatMul ||
         graph->nodes[outgoing_node_idx].apply.op_type == OperatorType::kRecBatchMatMul)) {
      graph->nodes[outgoing_node_idx].param_name = graph->nodes[node_index].param_name;
    }
  }
  eli_list->push_back(eli);

  // Iterate over all input operators of the current node
  for (size_t i = 0; i < graph->nodes[node_index].node_in.size(); i++) {
    auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out;
    auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index);
    if (it != incoming_outputs->end()) {
      it = incoming_outputs->erase(it);
      for (auto outgoing_index : graph->nodes[node_index].node_out) {
        it = find(incoming_outputs->begin(), incoming_outputs->end(), outgoing_index);
        if (it == incoming_outputs->end()) {
          incoming_outputs->push_back(outgoing_index);
        }
      }
    }
  }

  // Iterate over all aux_input operators of the current node
  for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) {
    auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out;
    auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index);
    if (it != aux_incoming_outputs->end()) {
      it = aux_incoming_outputs->erase(it);
      for (auto outgoing_index : graph->nodes[node_index].node_out) {
        it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), outgoing_index);
        if (it == aux_incoming_outputs->end()) {
          aux_incoming_outputs->push_back(outgoing_index);
        }
      }
    }
  }

  // Iterate over all output operators of the current node
  Eliminate_Aux_Outgoing(node_index, graph);
}

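// Detaches the eliminated node from the i-th consumer's main input list, substituting the node's input[0] and
// pushing its remaining inputs and aux_inputs into the consumer's aux_input list without creating duplicates
// (or simply erasing the entry when the node has no inputs).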
void EliminateAuxOutgoingInput(size_t node_index, const std::shared_ptr<Graph> &graph, size_t i) {
  MS_EXCEPTION_IF_NULL(graph);
  auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in;
  MS_EXCEPTION_IF_NULL(outgoing_inputs);
  // Check whether the current node is an input operator of the current node's output operator
  auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index);
  if (it != outgoing_inputs->end()) {
    if (graph->nodes[node_index].node_in.size() > 0) {
      // If the current node has an input operator, add input[0] of the current node to the input of the current
      // node's output operator (if input[0] is also in the aux_input of the current node's output operator, remove
      // it from the aux_input and keep it only in the input)
      auto exist_in_outgoing_auxinputs =
        find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(),
             graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end(), graph->nodes[node_index].node_in[0]);
      if (exist_in_outgoing_auxinputs != graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end()) {
        size_t index_remove_node = LongToSize(std::distance(
          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(), exist_in_outgoing_auxinputs));
        if (graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.size() > index_remove_node) {
          (void)graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.erase(
            graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.begin() + index_remove_node);
        } else {
          MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_remove_node
                        << ", which is out of range.";
        }
        if (graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.size() > index_remove_node) {
          (void)graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.erase(exist_in_outgoing_auxinputs);
        } else {
          MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_remove_node
                        << ", which is out of range.";
        }
      }
      size_t idx = LongToSize(std::distance(outgoing_inputs->begin(), it));
      if (outgoing_inputs->size() > idx) {
        outgoing_inputs->at(idx) = graph->nodes[node_index].node_in[0];
      } else {
        MS_LOG(DEBUG) << "Trying to access vector element at index " << idx << ", which is out of range.";
      }
      // Then add the other input operators of the current node to the aux_input of the current node's output operator
      for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) {
        exist_in_outgoing_auxinputs = find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(),
                                           graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end(),
                                           graph->nodes[node_index].node_in[j]);
        if (exist_in_outgoing_auxinputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end()) {
          size_t index_aux = LongToSize(std::distance(outgoing_inputs->begin(), it));
          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.push_back(index_aux);
          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(graph->nodes[node_index].node_in[j]);
        }
      }
      // Then add all the operators in the aux_input of the current node to the aux_input of the current node's
      // output operator
      for (size_t j = 0; j < graph->nodes[node_index].node_in_aux.size(); j++) {
        exist_in_outgoing_auxinputs = find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(),
                                           graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end(),
                                           graph->nodes[node_index].node_in_aux[j]);
        if (exist_in_outgoing_auxinputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end()) {
          size_t index_aux = LongToSize(std::distance(outgoing_inputs->begin(), it));
          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.push_back(index_aux);
          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(
            graph->nodes[node_index].node_in_aux[j]);
        }
      }
    } else {
      auto idx = LongToSize(std::distance(outgoing_inputs->begin(), it));
      if (outgoing_inputs->size() > idx) {
        (void)outgoing_inputs->erase(it);
      } else {
        MS_LOG(DEBUG) << "Trying to erase vector element at index " << idx << ", which is out of range.";
      }
    }
  }
}

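// Detaches the eliminated node from the i-th consumer's aux_input list, substituting the node's input[0] (or
// dropping the entry if input[0] already feeds the consumer) while keeping node_in_aux and node_in_aux_idx in sync.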
void EliminateAuxOutgoingAuxInput(size_t node_index, const std::shared_ptr<Graph> &graph, size_t i) {
  MS_EXCEPTION_IF_NULL(graph);
  auto *outgoing_auxinputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux;
  MS_EXCEPTION_IF_NULL(outgoing_auxinputs);
  auto *outgoing_auxinputs_index = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx;
  // Check whether the current node is an aux_input operator of the current node's output operator
  auto it = find(outgoing_auxinputs->begin(), outgoing_auxinputs->end(), node_index);
  size_t index_entree = LongToSize(std::distance(outgoing_auxinputs->begin(), it));
  if (it != outgoing_auxinputs->end()) {
    if (graph->nodes[node_index].node_in.size() > 0) {
      // If the current node has an input operator: when input[0] of the current node is already in the input of
      // the current node's output operator, delete the current node from the aux_input of that output operator;
      // otherwise, replace the current node with input[0] in that aux_input
      auto exist_in_outgoing_inputs =
        find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in.begin(),
             graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end(), graph->nodes[node_index].node_in[0]);
      if (exist_in_outgoing_inputs != graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end()) {
        index_entree = LongToSize(std::distance(outgoing_auxinputs->begin(), it));
        if (outgoing_auxinputs_index->size() > index_entree) {
          (void)outgoing_auxinputs_index->erase(outgoing_auxinputs_index->begin() + index_entree);
        } else {
          MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", which is out of range.";
        }
        if (outgoing_auxinputs->size() > index_entree) {
          (void)outgoing_auxinputs->erase(it);
        } else {
          MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", which is out of range.";
        }
      } else {
        size_t idx = LongToSize(std::distance(outgoing_auxinputs->begin(), it));
        if (outgoing_auxinputs->size() > idx) {
          outgoing_auxinputs->at(idx) = graph->nodes[node_index].node_in[0];
        } else {
          MS_LOG(DEBUG) << "Trying to access vector element at index " << idx << ", which is out of range.";
        }
        index_entree = LongToSize(std::distance(
          outgoing_auxinputs->begin(),
          find(outgoing_auxinputs->begin(), outgoing_auxinputs->end(), graph->nodes[node_index].node_in[0])));
      }
      // Check whether the other input operators of the current node are in the input of the output operator;
      // if not, add them to the aux_input of the output operator
      for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) {
        exist_in_outgoing_inputs =
          find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in.begin(),
               graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end(), graph->nodes[node_index].node_in[j]);
        if (exist_in_outgoing_inputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end()) {
          outgoing_auxinputs->push_back(graph->nodes[node_index].node_in[j]);
          if (outgoing_auxinputs_index->size() > index_entree) {
            outgoing_auxinputs_index->push_back(outgoing_auxinputs_index->at(index_entree));
          } else {
            MS_LOG(DEBUG) << "Trying to access vector element at index " << index_entree << ", which is out of range.";
          }
        }
      }
      // Check whether the aux_input operators of the current node are in the input of the output operator;
      // if not, add them to the aux_input of the output operator
      for (size_t j = 0; j < graph->nodes[node_index].node_in_aux.size(); j++) {
        exist_in_outgoing_inputs = find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in.begin(),
                                        graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end(),
                                        graph->nodes[node_index].node_in_aux[j]);
        if (exist_in_outgoing_inputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end()) {
          outgoing_auxinputs->push_back(graph->nodes[node_index].node_in_aux[j]);
          outgoing_auxinputs_index->push_back(outgoing_auxinputs_index->at(index_entree));
        }
      }
    } else {
      if (outgoing_auxinputs_index->size() > index_entree) {
        (void)outgoing_auxinputs_index->erase(outgoing_auxinputs_index->begin() + index_entree);
      } else {
        MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", which is out of range.";
      }
      if (outgoing_auxinputs->size() > index_entree) {
        (void)outgoing_auxinputs->erase(it);
      } else {
        MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", which is out of range.";
      }
    }
  }
}

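// Detaches the eliminated node from every consumer, whether it is referenced through a main or an auxiliary edge.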
void Eliminate_Aux_Outgoing(size_t node_index, const std::shared_ptr<Graph> &graph) {
  for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) {
    // Handle the output operator connected to the current node via a main edge
    EliminateAuxOutgoingInput(node_index, graph, i);
    // Handle the output operator connected to the current node via an auxiliary edge
    EliminateAuxOutgoingAuxInput(node_index, graph, i);
  }
}

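// Remaps the node indices in *nodes to their post-elimination values; entries mapped to SIZE_MAX (eliminated
// nodes) are erased.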
static void EraseEliminatedNode(std::vector<size_t> *nodes, const std::shared_ptr<std::vector<size_t>> &index_list) {
  for (size_t j = nodes->size(); j > 0; j--) {
    bool IsEliminated = (index_list->at(nodes->at(j - 1)) == SIZE_MAX);
    if (IsEliminated) {
      (void)nodes->erase(nodes->begin() + SizeToLong(j) - 1);
    } else {
      nodes->at(j - 1) = index_list->at(nodes->at(j - 1));
    }
  }
}

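// Removes all operators of an eliminable type from the graph (BatchMatMul is kept when dyn_shape_tmp_fix is set).
// index_list maps each old node index to its new index (SIZE_MAX for eliminated nodes); the returned graph
// contains only the surviving, re-indexed nodes.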
std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
                                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                                      const std::shared_ptr<std::vector<size_t>> &index_list,
                                      const bool dyn_shape_tmp_fix) {
  MS_EXCEPTION_IF_NULL(graph);
  for (size_t node_index = 0; node_index < graph->nodes.size(); node_index++) {
    auto type = graph->nodes[node_index].apply.op_type;
    if (dyn_shape_tmp_fix && type == OperatorType::kRecBatchMatMul) {
      continue;
    } else if (EliminateOpType.find(type) != EliminateOpType.end()) {
      Eliminate_Aux(node_index, graph, eli_list);
    }
  }
  index_list->reserve(graph->nodes.size());
  for (size_t i = 0; i < graph->nodes.size(); i++) {
    index_list->push_back(i);
  }
  for (size_t i = 0; i < eli_list->size(); i++) {
    if (eli_list->at(i)[0] >= index_list->size()) {
      MS_LOG(EXCEPTION) << "Failure: eliminated operator index out of range.";
    }
    index_list->at(eli_list->at(i)[0]) = SIZE_MAX;
    for (size_t j = eli_list->at(i)[0] + 1; j < index_list->size(); j++) {
      index_list->at(j)--;
    }
  }
  std::shared_ptr<Graph> new_graph = std::make_shared<Graph>();
  for (size_t i = 0; i < graph->nodes.size(); i++) {
    if (index_list->at(i) > SIZE_MAX / 2) {
      continue;
    }
    new_graph->nodes.push_back(graph->nodes[i]);
    auto *node_in = &new_graph->nodes[index_list->at(i)].node_in;
    EraseEliminatedNode(node_in, index_list);
    auto *node_in_aux = &new_graph->nodes[index_list->at(i)].node_in_aux;
    EraseEliminatedNode(node_in_aux, index_list);
    auto *node_out = &new_graph->nodes[index_list->at(i)].node_out;
    EraseEliminatedNode(node_out, index_list);
  }
  return new_graph;
}
}  // namespace parallel
}  // namespace mindspore