/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h"
#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h"
#include "frontend/parallel/ops_info/operator_info.h"

namespace mindspore {
namespace parallel {
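// Packs an (n, c, h, w) shape into a TensorParam with the default kFloat32 element type.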
const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w) {
  TensorParam new_tensor;
  new_tensor.tensor_type = kFloat32;
  new_tensor.tensor_shape.shape_n = n;
  new_tensor.tensor_shape.shape_c = c;
  new_tensor.tensor_shape.shape_h = h;
  new_tensor.tensor_shape.shape_w = w;
  return new_tensor;
}

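// Builds a recursive-programming graph node from ops[iter_ops]: maps the operator type through
// DictOpType and normalizes the first output shape to 4 dimensions, padding leading dims with 1.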
Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops) {
  Graph::NodeType NewOp;
  NewOp.name = ops[iter_ops]->name();
  NewOp.info = InfoType::kApplication;

  auto op_type = ops[iter_ops]->type();
  auto idx = DictOpType.find(op_type);
  if (idx == DictOpType.end()) {
    NewOp.apply.op_type = OperatorType::kRecUnkownType;
    MS_LOG(INFO) << ops[iter_ops]->name() << ": Unknown operator type " << op_type;
  } else {
    NewOp.apply.op_type = DictOpType.at(op_type);
  }

  if (ops[iter_ops]->outputs_tensor_info().empty()) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " output tensor info is empty.";
  }

  // Pad the first output shape with leading 1s so every node carries a 4-dimensional tensor.
  const auto &output_shape = ops[iter_ops]->outputs_tensor_info()[0].shape();
  if (output_shape.size() == 4) {
    NewOp.tensor_parm = MakeTensor(output_shape[0], output_shape[1], output_shape[2], output_shape[3]);
  } else if (output_shape.size() == 3) {
    NewOp.tensor_parm = MakeTensor(1, output_shape[0], output_shape[1], output_shape[2]);
  } else if (output_shape.size() == 2) {
    NewOp.tensor_parm = MakeTensor(1, 1, output_shape[0], output_shape[1]);
  } else if (output_shape.size() == 1) {
    NewOp.tensor_parm = MakeTensor(1, 1, 1, output_shape[0]);
  } else if (output_shape.empty()) {
    NewOp.tensor_parm = MakeTensor(1, 1, 1, 1);
  } else {
    MS_LOG(ERROR) << ops[iter_ops]->name() << ": output tensor shape is unexpected.";
  }

  NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp);
  return NewOp;
}

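// Fills NewTensor.apply.arguments with the input shapes of ops[iter_ops], each padded to
// 4 dimensions the same way as the output shape. For STACK, only the first input is recorded.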
OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                   Graph::NodeType NewTensor) {
  size_t input_tensor_size = ops[iter_ops]->inputs_tensor_info().size();
  if (ops[iter_ops]->type() == STACK) {
    input_tensor_size = 1;
  }
  if (input_tensor_size > MAX_INPUT_NUM) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " input tensor num exceeds limit.";
  }

  for (size_t iter_input_tensors = 0; iter_input_tensors < input_tensor_size; iter_input_tensors++) {
    const auto &input_shape = ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape();
    if (input_shape.size() == 4) {
      NewTensor.apply.arguments[iter_input_tensors] =
        MakeTensor(input_shape[0], input_shape[1], input_shape[2], input_shape[3]);
    } else if (input_shape.size() == 3) {
      NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, input_shape[0], input_shape[1], input_shape[2]);
    } else if (input_shape.size() == 2) {
      // 2-D inputs need extra care: MatMul may transpose them (see Complete2DInputs).
      NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor);
    } else if (input_shape.size() == 1) {
      NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, input_shape[0]);
    } else if (input_shape.empty()) {
      NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1);
    } else {
      MS_LOG(ERROR) << ops[iter_ops]->name() << ": input tensor shape is unexpected.";
    }
  }
  return NewTensor.apply;
}

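// Converts one 2-D input of ops[iter_ops] into a 4-D TensorParam. For MatMul, the transpose_a /
// transpose_b attributes decide whether the two matrix dimensions are swapped before padding.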
TensorParam Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                             const size_t iter_input_tensors, Graph::NodeType NewTensor) {
  const auto &input_shape = ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape();
  if (NewTensor.apply.op_type == OperatorType::kRecMatMul) {
    auto attrs = ops[iter_ops]->attrs();
    bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
    bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
    // A transposed operand enters MatMul with its two dimensions swapped.
    if ((transpose_a && iter_input_tensors == 0) || (transpose_b && iter_input_tensors == 1)) {
      NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, input_shape[1], input_shape[0]);
    } else {
      NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, input_shape[0], input_shape[1]);
    }
  } else {
    NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, input_shape[0], input_shape[1]);
  }
  return NewTensor.apply.arguments[iter_input_tensors];
}

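// Entry point of the parser: converts the list of OperatorInfo objects into a recursive-programming
// Graph, one node per operator, then connects the nodes from the input tensor names.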
std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                  const std::vector<std::vector<std::string>> &input_tensor_names) {
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  if (ops.size() > SIZE_MAX / 2) {
    MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << SIZE_MAX / 2;
  }

  for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) {
    Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops);
    graph->nodes.push_back(NewOp);
  }
  MakeEdge(input_tensor_names, graph);

  return graph;
}

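// Adds the graph edges. Row iter_i of input_tensor_names holds the node's own name at index 0 and
// the names of its inputs afterwards; each resolvable input name yields a head_node -> iter_i edge.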
void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph) {
  for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) {
    for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) {
      size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]);
      // Skip inputs that are not produced by another operator (SIZE_MAX) as well as self-loops.
      if (head_node_index < SIZE_MAX / 2 && head_node_index != iter_i) {
        graph->nodes[iter_i].node_in.push_back(head_node_index);
        graph->nodes[head_node_index].node_out.push_back(iter_i);
      }
    }
  }
}

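// Returns the row whose first entry (the operator's own name) matches input_name,
// or SIZE_MAX when the name does not belong to any operator in the list.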
size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_name,
                                  const std::string &input_name) {
  for (size_t index = 0; index < input_tensor_name.size(); index++) {
    if (input_tensor_name[index][0] == input_name) {
      return index;
    }
  }
  MS_LOG(INFO) << "Get index failed, using SIZE_MAX instead";
  return SIZE_MAX;
}

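// Eliminates node_index from the graph by splicing it out of the edge lists: its predecessors
// inherit its outgoing edges, its successors inherit its first input (further inputs are carried
// over as auxiliary inputs), and [node_index, its successors...] is appended to eli_list.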
void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
                   const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list) {
  std::vector<size_t> eli;
  eli.push_back(node_index);
  for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) {
    eli.push_back(graph->nodes[node_index].node_out[i]);
  }
  eli_list->push_back(eli);

  // Replace node_index in each predecessor's output list with node_index's own outputs.
  for (size_t i = 0; i < graph->nodes[node_index].node_in.size(); i++) {
    auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out;
    auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index);
    if (it != incoming_outputs->end()) {
      it = incoming_outputs->erase(it);
      incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), graph->nodes[node_index].node_out.end());
    }
  }

  // Do the same for predecessors reached through auxiliary inputs.
  for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) {
    auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out;
    auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index);
    if (it != aux_incoming_outputs->end()) {
      it = aux_incoming_outputs->erase(it);
      aux_incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(),
                                   graph->nodes[node_index].node_out.end());
    }
  }

  // In each successor's input list, substitute node_index with its first input and push the
  // remaining inputs into the successor's auxiliary input list.
  for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) {
    auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in;
    auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index);
    if (it != outgoing_inputs->end()) {
      if (graph->nodes[node_index].node_in.size() > 0) {
        outgoing_inputs->at(LongToSize(std::distance(outgoing_inputs->begin(), it))) =
          graph->nodes[node_index].node_in[0];
        for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) {
          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(
            graph->nodes[node_index].node_in[j]);
        }
        for (size_t j = 1; j < graph->nodes[node_index].node_in_aux.size(); j++) {
          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(
            graph->nodes[node_index].node_in_aux[j]);
        }
      } else {
        outgoing_inputs->erase(it);
      }
    }
  }
}

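// Shrinks the graph by eliminating every element-wise node (via Eliminate_Aux), then rebuilds a
// compact graph. index_list maps each old node index to its new index, with SIZE_MAX marking
// eliminated nodes; the edge lists of the surviving nodes are renumbered accordingly.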
std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
                                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                                      const std::shared_ptr<std::vector<size_t>> &index_list) {
  MS_EXCEPTION_IF_NULL(graph);
  for (size_t node_index = 0; node_index < graph->nodes.size(); node_index++) {
    auto type = graph->nodes[node_index].apply.op_type;
    if (ElementWiseOpType.find(type) != ElementWiseOpType.end()) {
      Eliminate_Aux(node_index, graph, eli_list);
    }
  }

  // Build the old-index -> new-index mapping: each eliminated node gets SIZE_MAX and every
  // node behind it shifts down by one.
  index_list->reserve(graph->nodes.size());
  for (size_t i = 0; i < graph->nodes.size(); i++) {
    index_list->push_back(i);
  }
  for (size_t i = 0; i < eli_list->size(); i++) {
    if (eli_list->at(i)[0] >= index_list->size()) {
      MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
    }
    index_list->at(eli_list->at(i)[0]) = SIZE_MAX;
    for (size_t j = eli_list->at(i)[0] + 1; j < index_list->size(); j++) {
      index_list->at(j)--;
    }
  }

  // Copy the surviving nodes into the new graph and renumber their edge lists.
  std::shared_ptr<Graph> new_graph = std::make_shared<Graph>();
  for (size_t i = 0; i < graph->nodes.size(); i++) {
    if (index_list->at(i) > SIZE_MAX / 2) {
      continue;  // node i was eliminated
    }
    new_graph->nodes.push_back(graph->nodes[i]);
    auto *node_in = &new_graph->nodes[index_list->at(i)].node_in;
    for (size_t j = node_in->size(); j > 0; j--) {
      bool IsEliminated = (index_list->at(node_in->at(j - 1)) == SIZE_MAX);
      if (IsEliminated) {
        (void)node_in->erase(node_in->begin() + SizeToLong(j) - 1);
      } else {
        node_in->at(j - 1) = index_list->at(node_in->at(j - 1));
      }
    }
    auto *node_out = &new_graph->nodes[index_list->at(i)].node_out;
    for (size_t j = node_out->size(); j > 0; j--) {
      bool IsEliminated = (index_list->at(node_out->at(j - 1)) == SIZE_MAX);
      if (IsEliminated) {
        (void)node_out->erase(node_out->begin() + SizeToLong(j) - 1);
      } else {
        node_out->at(j - 1) = index_list->at(node_out->at(j - 1));
      }
    }
  }
  return new_graph;
}
}  // namespace parallel
}  // namespace mindspore