/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h"
#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h"
#include "frontend/parallel/ops_info/operator_info.h"

namespace mindspore {
namespace parallel {
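// Builds a TensorParam describing a 4-D (N, C, H, W) float32 tensor shape;
// callers pad lower-rank shapes with leading 1s so every tensor is handled
// uniformly as 4-D.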
const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w) {
  TensorParam new_tensor;
  new_tensor.tensor_type = kFloat32;
  new_tensor.tensor_shape.shape_n = n;
  new_tensor.tensor_shape.shape_c = c;
  new_tensor.tensor_shape.shape_h = h;
  new_tensor.tensor_shape.shape_w = w;
  return new_tensor;
}

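// Creates a graph node for the operator at ops[iter_ops]: resolves its operator
// type through DictOpType (falling back to kRecUnkownType for unknown types) and
// records the shape of its first output, padded to 4-D, as the node's tensor_parm.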
Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops) {
  Graph::NodeType NewOp;
  NewOp.name = ops[iter_ops]->name();
  NewOp.info = InfoType::kApplication;

  auto op_type = ops[iter_ops]->type();
  auto idx = DictOpType.find(op_type);
  if (idx == DictOpType.end()) {
    NewOp.apply.op_type = OperatorType::kRecUnkownType;
    MS_LOG(INFO) << ops[iter_ops]->name() << ": Unknown operator type " << op_type;
  } else {
    NewOp.apply.op_type = DictOpType.at(op_type);
  }

  if (ops[iter_ops]->outputs_tensor_info().empty()) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " output tensor info is empty.";
  }

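  // Normalize the first output's shape to 4-D by padding missing leading dimensions with 1.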
  auto output_shape = ops[iter_ops]->outputs_tensor_info()[0].shape();
  if (output_shape.size() == 4) {
    NewOp.tensor_parm = MakeTensor(output_shape[0], output_shape[1], output_shape[2], output_shape[3]);
  } else if (output_shape.size() == 3) {
    NewOp.tensor_parm = MakeTensor(1, output_shape[0], output_shape[1], output_shape[2]);
  } else if (output_shape.size() == 2) {
    NewOp.tensor_parm = MakeTensor(1, 1, output_shape[0], output_shape[1]);
  } else if (output_shape.size() == 1) {
    NewOp.tensor_parm = MakeTensor(1, 1, 1, output_shape[0]);
  } else if (output_shape.size() == 0) {
    NewOp.tensor_parm = MakeTensor(1, 1, 1, 1);
  } else {
    MS_LOG(ERROR) << ops[iter_ops]->name() << ": output tensor shape is unexpected.";
  }

  NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp);
  return NewOp;
}

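// Fills the node's apply.arguments with one TensorParam per operator input.
// STACK is treated as single-input; exceeding MAX_INPUT_NUM raises an exception.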
OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                   Graph::NodeType NewTensor) {
  size_t input_tensor_size = ops[iter_ops]->inputs_tensor_info().size();
  if (ops[iter_ops]->type() == STACK) {
    input_tensor_size = 1;
  }
  if (input_tensor_size > MAX_INPUT_NUM) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " input tensor num exceeds limit.";
  }

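  // Normalize each input shape to 4-D by padding missing leading dimensions with 1;
  // 2-D inputs go through Complete2DInputs to honor MatMul transpose attributes.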
  for (size_t iter_input_tensors = 0; iter_input_tensors < input_tensor_size; iter_input_tensors++) {
    auto input_shape = ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape();
    if (input_shape.size() == 4) {
      NewTensor.apply.arguments[iter_input_tensors] =
        MakeTensor(input_shape[0], input_shape[1], input_shape[2], input_shape[3]);
    } else if (input_shape.size() == 3) {
      NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, input_shape[0], input_shape[1], input_shape[2]);
    } else if (input_shape.size() == 2) {
      NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor);
    } else if (input_shape.size() == 1) {
      NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, input_shape[0]);
    } else if (input_shape.size() == 0) {
      NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1);
    } else {
      MS_LOG(ERROR) << ops[iter_ops]->name() << ": input tensor shape is unexpected.";
    }
  }
  return NewTensor.apply;
}

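// Converts a 2-D input tensor into a TensorParam. For MatMul, the two dimensions
// are swapped when the matching transpose attribute is set (transpose_a for the
// first input, transpose_b for the second).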
TensorParam Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                             const size_t iter_input_tensors, Graph::NodeType NewTensor) {
  auto input_shape = ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape();
  if (NewTensor.apply.op_type == OperatorType::kRecMatMul) {
    auto attrs = ops[iter_ops]->attrs();
    bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
    bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
    if ((transpose_a && (iter_input_tensors == 0)) || (transpose_b && (iter_input_tensors == 1))) {
      NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, input_shape[1], input_shape[0]);
      return NewTensor.apply.arguments[iter_input_tensors];
    }
  }
  NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, input_shape[0], input_shape[1]);
  return NewTensor.apply.arguments[iter_input_tensors];
}

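// Builds a Graph from the operator list: one node per operator, then edges
// derived from the operators' input tensor names.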
std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                  const std::vector<std::vector<std::string>> &input_tensor_names) {
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  if (ops.size() > SIZE_MAX / 2) {
    MS_LOG(EXCEPTION) << "Total number of operators exceeds " << SIZE_MAX / 2;
  }

  for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) {
    Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops);
    graph->nodes.push_back(NewOp);
  }
  MakeEdge(input_tensor_names, graph);

  return graph;
}

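// Adds edges between nodes. input_tensor_names[i][0] is node i's own output name;
// the remaining entries name the tensors it consumes, each of which is resolved
// to its producer node.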
void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph) {
  for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) {
    for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) {
      size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]);
      if (head_node_index < SIZE_MAX / 2 && head_node_index != iter_i) {
        graph->nodes[iter_i].node_in.push_back(head_node_index);
        graph->nodes[head_node_index].node_out.push_back(iter_i);
      }
    }
  }
}

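// Returns the index of the node whose output name (entry 0 of its name list)
// matches input_name, or SIZE_MAX when no producer is found.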
size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_name,
                                  const std::string &input_name) {
  for (size_t index = 0; index < input_tensor_name.size(); index++) {
    if (input_tensor_name[index][0] == input_name) {
      return index;
    }
  }
  MS_LOG(INFO) << "Could not find a producer for " << input_name << ", returning SIZE_MAX";
  return SIZE_MAX;
}

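// Splices the node at node_index out of the graph, reconnecting its incoming
// nodes directly to its outgoing nodes; the eliminated node and its outputs
// are recorded in eli_list.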
void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
                   const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list) {
  std::vector<size_t> eli;
  eli.push_back(node_index);
  for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) {
    eli.push_back(graph->nodes[node_index].node_out[i]);
  }
  eli_list->push_back(eli);

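  // In each incoming node's output list, replace this node with this node's outputs.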
  for (size_t i = 0; i < graph->nodes[node_index].node_in.size(); i++) {
    auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out;
    auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index);
    if (it != incoming_outputs->end()) {
      it = incoming_outputs->erase(it);
      incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), graph->nodes[node_index].node_out.end());
    }
  }

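  // Likewise for nodes that reach this node through auxiliary input edges.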
  for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) {
    auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out;
    auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index);
    if (it != aux_incoming_outputs->end()) {
      it = aux_incoming_outputs->erase(it);
      aux_incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(),
                                   graph->nodes[node_index].node_out.end());
    }
  }

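  // In each outgoing node's input list, replace this node with its first input;
  // the remaining inputs (and auxiliary inputs from index 1 on) are appended to
  // the outgoing node's auxiliary inputs.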
  for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) {
    auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in;
    auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index);
    if (it != outgoing_inputs->end()) {
      if (graph->nodes[node_index].node_in.size() > 0) {
        outgoing_inputs->at(LongToSize(std::distance(outgoing_inputs->begin(), it))) =
          graph->nodes[node_index].node_in[0];
        for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) {
          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(graph->nodes[node_index].node_in[j]);
        }
        for (size_t j = 1; j < graph->nodes[node_index].node_in_aux.size(); j++) {
          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(
            graph->nodes[node_index].node_in_aux[j]);
        }
      } else {
        outgoing_inputs->erase(it);
      }
    }
  }
}

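// Removes all element-wise operators from the graph and returns the compacted
// result; eli_list records the eliminations and index_list maps old node indices
// to new ones (SIZE_MAX marks an eliminated node).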
std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
                                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                                      const std::shared_ptr<std::vector<size_t>> &index_list) {
  MS_EXCEPTION_IF_NULL(graph);
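  // Splice every element-wise operator out of the graph, recording each elimination.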
  for (size_t node_index = 0; node_index < graph->nodes.size(); node_index++) {
    auto type = graph->nodes[node_index].apply.op_type;
    if (ElementWiseOpType.find(type) != ElementWiseOpType.end()) {
      Eliminate_Aux(node_index, graph, eli_list);
    }
  }
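  // Start from the identity mapping: old node index -> new node index.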
  index_list->reserve(graph->nodes.size());
  for (size_t i = 0; i < graph->nodes.size(); i++) {
    index_list->push_back(i);
  }
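  // Mark each eliminated slot with SIZE_MAX and shift all later indices down by
  // one; since eliminations were recorded in ascending node order, eliminated
  // slots remain at exactly SIZE_MAX.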
  for (size_t i = 0; i < eli_list->size(); i++) {
    if (eli_list->at(i)[0] >= index_list->size()) {
      MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
    }
    index_list->at(eli_list->at(i)[0]) = SIZE_MAX;
    for (size_t j = eli_list->at(i)[0] + 1; j < index_list->size(); j++) {
      index_list->at(j)--;
    }
  }
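  // Rebuild the graph with the surviving nodes only, remapping their edges to the
  // new indices and dropping edges that point at eliminated nodes.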
  std::shared_ptr<Graph> new_graph = std::make_shared<Graph>();
  for (size_t i = 0; i < graph->nodes.size(); i++) {
    if (index_list->at(i) > SIZE_MAX / 2) {
      continue;
    }
    new_graph->nodes.push_back(graph->nodes[i]);
    auto *node_in = &new_graph->nodes[index_list->at(i)].node_in;
    for (size_t j = node_in->size(); j > 0; j--) {
      bool IsEliminated = (index_list->at(node_in->at(j - 1)) == SIZE_MAX);
      if (IsEliminated) {
        (void)node_in->erase(node_in->begin() + SizeToLong(j) - 1);
      } else {
        node_in->at(j - 1) = index_list->at(node_in->at(j - 1));
      }
    }
    auto *node_out = &new_graph->nodes[index_list->at(i)].node_out;
    for (size_t j = node_out->size(); j > 0; j--) {
      bool IsEliminated = (index_list->at(node_out->at(j - 1)) == SIZE_MAX);
      if (IsEliminated) {
        (void)node_out->erase(node_out->begin() + SizeToLong(j) - 1);
      } else {
        node_out->at(j - 1) = index_list->at(node_out->at(j - 1));
      }
    }
  }
  return new_graph;
}
}  // namespace parallel
}  // namespace mindspore