/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/delegate/npu/pass/npu_insert_transform_pass.h"
#include <algorithm>
#include <set>
#include <string>
#include "src/delegate/npu/pass/npu_pass_utils.h"

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;

namespace mindspore {
enum InsertState { InsertNone, PreInsert, PostInsert, BothInsert };
std::set<mindspore::schema::PrimitiveType> insert_nodes = {
  schema::PrimitiveType_Concat,       schema::PrimitiveType_AddFusion, schema::PrimitiveType_Eltwise,
  schema::PrimitiveType_Activation,   schema::PrimitiveType_Split,     schema::PrimitiveType_PadFusion,
  schema::PrimitiveType_StridedSlice, schema::PrimitiveType_MulFusion, schema::PrimitiveType_DivFusion};

// The goal of this pass is to minimize the number of generated subgraphs by inserting
// nchw2nhwc or nhwc2nchw ops before or after the target operators (e.g. Concat, AddFusion, etc.),
// working together with the fusion pass. If the transposes already present account for more than
// half of the op's inputs and outputs, we insert transposes on the remaining inputs and outputs
// so that the fusion pass can hopefully fold them all. Otherwise, we insert nothing.

// Typically Concat accepts an output from nchw2nhwc. We wrap the other inputs with an
// nh2nc/nc2nh pair so that all inputs to Concat share the same format, and the fusion pass can
// then fold all the nchw2nhwc ops.
// e.g.
// original     (conv->nchw2nhwc, add(format nhwc)) -> concat -> (nhwc2nchw->conv)
// current pass (conv->nchw2nhwc, add->nhwc2nchw->nchw2nhwc) -> concat -> (nhwc2nchw->conv)
// fusion pass  (conv, add->nhwc2nchw) -> concat -> conv
// Originally 2 CPU subgraphs; after the 2 passes, only 1 CPU subgraph remains.

// The target ops require all of their inputs to share one format, which could be NCHW, NHWC,
// or any other format. Their inputs and outputs may not be 4-D, or may already have the right
// format, so we do not insert nc2nh or nh2nc when neither the op's in ops nor its out ops
// contain an nc2nh or nh2nc. This pass should run after npu_transform_pass, which inserts
// transposes for NCHW-input-limited ops such as Conv2D. A worked example of the insert
// decision follows below.
int NPUInsertTransformPass::GetInsertState(NPUOp *op) {
  // filter out irrelevant ops
  if (insert_nodes.find(op->type()) == insert_nodes.end()) {
    return InsertNone;
  }

  // the current op is a target op;
  // use out ops to count how many out lines leave the current op
  std::vector<mindspore::MSTensor> inputs = NPUPassUtils::GetNonConstInputs(op);
  size_t in_out_tensor_num =
    inputs.size() + std::max(std::max(op->out_ops().size(), static_cast<size_t>(1)), op->outputs().size());
  size_t transpose_input_num = 0;
  size_t transpose_output_num = 0;
  size_t graph_input_num = 0;
  size_t graph_output_num = 0;
  bool need_pre_insert = false;
  bool need_post_insert = false;
  // count the input tensors coming from nc2nh ops and the output tensors feeding nh2nc ops
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto in_op = NPUPassUtils::OpInputFromOp(op, inputs.at(i));
    if (NPUPassUtils::IsNchw2Nhwc(in_op)) {
      transpose_input_num++;
    } else {
      need_pre_insert = true;
    }
    if (in_op == nullptr) {
      graph_input_num++;
    }
  }
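  // Note: a graph input (in_op == nullptr) is not a nchw2nhwc op, so it also sets
  // need_pre_insert above; it is counted in graph_input_num so that it can be excluded
  // from connected_in_out_tensor_num below.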
  if (op->out_ops().empty()) {
    need_post_insert = true;
  }
  if (op->outputs().size() > op->out_ops().size()) {
    graph_output_num = op->outputs().size() - op->out_ops().size();
  }
  for (const auto out_op : op->out_ops()) {
    if (NPUPassUtils::IsNhwc2Nchw(out_op)) {
      transpose_output_num++;
    } else {
      need_post_insert = true;
    }
  }

  // Insert nothing if the number of transpose tensors is smaller than half of the op's
  // connected inputs and outputs; graph input/output tensors are excluded from that count,
  // since we should avoid building a single-op subgraph in that case.
  // Also insert nothing if all inputs and outputs are already transpose tensors; the fusion
  // pass will handle that case.
  size_t transpose_tensor_num = transpose_input_num + transpose_output_num;
  size_t connected_in_out_tensor_num = in_out_tensor_num - graph_output_num - graph_input_num;
  if (transpose_tensor_num == 0 || transpose_tensor_num * 2 < connected_in_out_tensor_num ||
      transpose_tensor_num == in_out_tensor_num) {
    return InsertNone;
  }
  if (need_pre_insert && need_post_insert) {
    return BothInsert;
  }
  if (need_pre_insert) {
    return PreInsert;
  }
  if (need_post_insert) {
    return PostInsert;
  }
  return InsertNone;
}

int NPUInsertTransformPass::InsertNode(NPUOp *op, NPUOp *post_op, size_t post_input_index,
                                       std::vector<NPUOp *> *trans_ops) {
  // op and post_op can't both be nullptr at the same time.
  std::string op_name;
  std::vector<mindspore::MSTensor> in_tensors;
  std::vector<NPUOp *> out_ops;
  // If post_op is nullptr, op produces an output of the whole graph.
  if (post_op != nullptr) {
    out_ops.push_back(post_op);
    op_name = post_op->name() + "_pre";
    in_tensors.push_back(post_op->inputs().at(post_input_index));
  }
  std::vector<NPUOp *> in_ops;
  // If op is nullptr, post_op consumes an input of the whole graph.
  if (op != nullptr && !op->outputs().empty()) {
    in_ops.push_back(op);
    op_name = op->name() + "_post";
    in_tensors.resize(op->outputs().size());
    std::copy(op->outputs().begin(), op->outputs().end(), in_tensors.begin());
  }
  for (size_t i = 0; i < in_tensors.size(); ++i) {
    auto in_tensor = in_tensors[i];
    auto nhwc_shape = in_tensor.Shape();
    if (nhwc_shape.size() == 0) {
      continue;
    } else if (nhwc_shape.size() < 4) {
      MS_LOG(ERROR) << "nhwc_shape size < 4";
      return RET_ERROR;
    }
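    // NHWC -> NCHW axis permutation: {N, H, W, C} -> {N, C, H, W}, i.e. indices {0, 3, 1, 2}.
    // For example (illustrative), an NHWC shape {1, 16, 16, 3} becomes the NCHW shape
    // {1, 3, 16, 16}.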
    std::vector<int64_t> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};

    auto nh2nc_name = op_name + "_nh2nc_" + std::to_string(total++);
    auto nh2nc_tensor =
      mindspore::MSTensor::CreateTensor(nh2nc_name + "/output0", in_tensor.DataType(), nchw_shape, nullptr, 0);
    if (nh2nc_tensor == nullptr) {
      MS_LOG(ERROR) << "New nchw tensor failed when inserting nhwc2nchw op.";
      return RET_ERROR;
    }
    nh2nc_tensor->SetTensorName(nh2nc_name + "/output0");
    std::vector<mindspore::MSTensor> nh2nc_tensors = {*nh2nc_tensor};
    all_tensors_->push_back(nh2nc_tensor);

    auto nc2nh_name = op_name + "_nc2nh_" + std::to_string(total++);
    auto nc2nh_tensor =
      mindspore::MSTensor::CreateTensor(nc2nh_name + "/output0", in_tensor.DataType(), nhwc_shape, nullptr, 0);
    if (nc2nh_tensor == nullptr) {
      MS_LOG(ERROR) << "New nhwc tensor failed when inserting nchw2nhwc op.";
      return RET_ERROR;
    }
    std::vector<mindspore::MSTensor> nc2nh_tensors = {*nc2nh_tensor};
    all_tensors_->push_back(nc2nh_tensor);

    auto *nh2nc_op = NPUPassUtils::CreateNhwc2NchwOp({in_tensor}, nh2nc_tensors, nh2nc_name);
    trans_ops->push_back(nh2nc_op);

    auto *nc2nh_op = NPUPassUtils::CreateNchw2NhwcOp(nh2nc_tensors, nc2nh_tensors, nc2nh_name);
    trans_ops->push_back(nc2nh_op);

    NPUPassUtils::UpdateOp(nh2nc_op, in_ops, {nc2nh_op}, {in_tensor}, nh2nc_tensors);
    NPUPassUtils::UpdateOp(nc2nh_op, {nh2nc_op}, out_ops, {nh2nc_tensors[0]}, nc2nh_tensors);
    if (op != nullptr) {
      NPUPassUtils::UpdateNH2NCTransNodePreOp(op, nh2nc_op, post_op);
    }
    if (post_op != nullptr) {
      NPUPassUtils::UpdateNC2NHTransNodePostOp(op, nc2nh_op, post_op);
    } else {
      // A nullptr post_op means this is a graph output; keep the graph output tensor name unchanged.
      auto graph_output_name = in_tensor.Name();
      nc2nh_tensor->SetTensorName(graph_output_name + "_after_" + name_);
    }
  }
  return RET_OK;
}
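
// After a successful InsertNode call, the local topology around each handled tensor is
// (illustrative): op -> nh2nc -> nc2nh -> post_op, as wired by the UpdateOp calls above.
// The back-to-back nhwc2nchw/nchw2nhwc pair is expected to be folded by the later fusion pass.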

int NPUInsertTransformPass::InsertForInputTensor(NPUOp *op, size_t in_tensor_index, NPUOp *pre_op,
                                                 std::vector<NPUOp *> *trans_ops) {
  // insert transpose nodes before the target op
  return InsertNode(pre_op, op, in_tensor_index, trans_ops);
}

int NPUInsertTransformPass::InsertForOutputTensor(NPUOp *op, NPUOp *post_op, size_t post_in_tensor_index,
                                                  std::vector<NPUOp *> *trans_ops) {
  // insert transpose nodes after the target op
  return InsertNode(op, post_op, post_in_tensor_index, trans_ops);
}

int NPUInsertTransformPass::InsertPreNodes(NPUOp *op, std::vector<NPUOp *> *trans_ops) {
  int ret = RET_OK;
  auto inputs = NPUPassUtils::GetNonConstInputs(op);
  for (auto tensor : inputs) {
    auto pre_op = NPUPassUtils::OpInputFromOp(op, tensor);
    if (NPUPassUtils::IsNchw2Nhwc(pre_op)) {
      continue;
    }
    // if this tensor is an input of the graph, pre_op is nullptr.
    auto it = std::find(op->inputs().begin(), op->inputs().end(), tensor);
    if (it == op->inputs().end()) {
      MS_LOG(ERROR) << "Failed to find the index of the input tensor.";
      return RET_ERROR;
    }
    size_t index = it - op->inputs().begin();
    ret = InsertForInputTensor(op, index, pre_op, trans_ops);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Insert nhwc2nchw op and nchw2nhwc op before op " << op->name() << " failed.";
      return ret;
    }
  }
  return ret;
}

int NPUInsertTransformPass::InsertPostNodes(NPUOp *op, std::vector<NPUOp *> *trans_ops) {
  int ret = RET_OK;

  for (const auto post_op : op->out_ops()) {
    if (NPUPassUtils::IsNhwc2Nchw(post_op)) {
      continue;
    }
    auto post_op_in_tensors = post_op->inputs();
    // the op's out tensor must be one of post_op's input tensors
    auto it = std::find(post_op_in_tensors.begin(), post_op_in_tensors.end(), op->outputs().at(0));
    if (it == post_op_in_tensors.end()) {
      MS_LOG(ERROR) << "Failed to find the op's out tensor among post_op's inputs.";
      return RET_ERROR;
    }
    size_t input_index = it - post_op_in_tensors.begin();
    ret = InsertForOutputTensor(op, post_op, input_index, trans_ops);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Insert nhwc2nchw op and nchw2nhwc op after op " << op->name() << " failed.";
      return ret;
    }
  }
  if (op->outputs().size() > op->out_ops().size()) {
    // one of the op's outputs is a graph output
    ret = InsertForOutputTensor(op, nullptr, 0, trans_ops);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Insert nhwc2nchw op and nchw2nhwc op after op " << op->name() << " failed.";
      return ret;
    }
  }
  return ret;
}

int NPUInsertTransformPass::Run(NPUGraph *subgraph) {
  all_ops_ = subgraph->GetOps();
  all_tensors_ = subgraph->GetInsertTensors();
  std::vector<NPUOp *> insert_ops;
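  // Scan the op list twice: transposes inserted in the first round can change the insert
  // state of neighboring target ops, which the second round can then pick up (a presumed
  // rationale for the loop bound of 2).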
  for (int j = 0; j < 2; ++j) {
    for (size_t i = 0; i < all_ops_->size(); i++) {
      auto op = (*all_ops_)[i];
      auto insert_state = GetInsertState(op);
      insert_ops.clear();
      // After inserting ops into the vector, advance the loop index past them so the next
      // iteration continues with the next op of the original vector.
      switch (insert_state) {
        case PreInsert: {
          auto ret = InsertPreNodes(op, &insert_ops);
          if (ret != RET_OK) {
            MS_LOG(ERROR) << "Insert nhwc2nchw op and nchw2nhwc op before op " << op->name() << " failed.";
            return RET_ERROR;
          }
          all_ops_->insert(all_ops_->begin() + i, insert_ops.begin(), insert_ops.end());
          i += insert_ops.size();
          break;
        }
        case PostInsert: {
          auto ret = InsertPostNodes(op, &insert_ops);
          if (ret != RET_OK) {
            MS_LOG(ERROR) << "Insert nhwc2nchw op and nchw2nhwc op after op " << op->name() << " failed.";
            return RET_ERROR;
          }
          all_ops_->insert(all_ops_->begin() + i + 1, insert_ops.begin(), insert_ops.end());
          i += insert_ops.size();
          break;
        }
        case BothInsert: {
          auto ret = InsertPreNodes(op, &insert_ops);
          if (ret != RET_OK) {
            MS_LOG(ERROR) << "Insert nhwc2nchw op and nchw2nhwc op before op " << op->name() << " failed.";
            return RET_ERROR;
          }
          all_ops_->insert(all_ops_->begin() + i, insert_ops.begin(), insert_ops.end());
          i += insert_ops.size();

          insert_ops.clear();
          ret = InsertPostNodes(op, &insert_ops);
          if (ret != RET_OK) {
            MS_LOG(ERROR) << "Insert nhwc2nchw op and nchw2nhwc op after op " << op->name() << " failed.";
            return RET_ERROR;
          }
          all_ops_->insert(all_ops_->begin() + i + 1, insert_ops.begin(), insert_ops.end());
          i += insert_ops.size();
          break;
        }
        default:
          MS_LOG(DEBUG) << "Insert Nothing on op " << op->name();
      }
    }
  }
  return RET_OK;
}
}  // namespace mindspore
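
// A minimal invocation sketch (hypothetical wiring; in practice the NPU delegate's pass
// machinery is expected to drive this pass, after npu_transform_pass and before the fusion
// pass):
//   mindspore::NPUInsertTransformPass pass;
//   if (pass.Run(npu_graph) != mindspore::lite::RET_OK) {
//     MS_LOG(ERROR) << "npu_insert_transform_pass failed.";
//   }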