/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef RUNTIME_PASS_CLIP
#include "src/litert/runtime_shape_fusion_pass.h"
#include <set>
#include <queue>
#include <algorithm>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"

namespace mindspore::lite {
namespace {
constexpr size_t kInitialSize = 1024;
}  // namespace
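// Rewrite a Shape node in place as the custom "ShapeFusion" op: serialize a
// Custom primitive with FlatBuffers, record the initial shape-fusion matrix
// for the node's first output, and append that matrix as an extra const input.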
int ShapeFusionPass::ConvertToShapeFusion(LiteGraph::Node *node) {
  MS_ASSERT(node != nullptr);
  auto input_tensor = src_tensors_->at(node->input_indices_.front());
  MS_CHECK_TRUE_RET(input_tensor != nullptr, RET_ERROR);
  auto shape = input_tensor->shape();
  if (shape.empty() || std::find(shape.begin(), shape.end(), -1) != shape.end()) {
    MS_LOG(INFO) << "The input shape is invalid.";
    return RET_ERROR;
  }

  flatbuffers::FlatBufferBuilder fbb(kInitialSize);
  auto val_offset = schema::CreateCustomDirect(fbb, "ShapeFusion");
  auto prim_offset =
    schema::CreatePrimitive(fbb, static_cast<schema::PrimitiveType>(PrimType::PrimType_Custom), val_offset.o);
  fbb.Finish(prim_offset);
  void *prim = malloc(fbb.GetSize());
  if (prim == nullptr) {
    MS_LOG(ERROR) << "malloc primitive failed.";
    return RET_ERROR;
  }
  memcpy(prim, fbb.GetBufferPointer(), fbb.GetSize());
  fbb.Clear();

  auto shape_fusion_prim = flatbuffers::GetRoot<schema::Primitive>(prim);
  if (shape_fusion_prim == nullptr) {
    free(prim);
    MS_LOG(ERROR) << "shape_fusion_prim is nullptr";
    return RET_ERROR;
  }
  ShapeFusionMatrix shape_fusion_matrix(shape.size());
  if (node->output_indices_.empty()) {
    free(prim);
    MS_LOG(ERROR) << "node->output_indices_ is empty";
    return RET_ERROR;
  }
  shape_fusion_matrices_[node->output_indices_.front()] = shape_fusion_matrix;
  auto shape_fusion_matrix_tensor = BuildTensorFromShapeFusionMatrix(shape_fusion_matrix);
  if (shape_fusion_matrix_tensor == nullptr) {
    free(prim);
    MS_LOG(ERROR) << "shape_fusion_matrix_tensor is nullptr";
    return RET_ERROR;
  }
  // Register prim only after all checks pass, so the error paths above can free
  // it without leaving a dangling pointer in node_bufs_.
  lite_model_->node_bufs_.push_back(prim);

  node->name_ += "_fusion";
  node->primitive_ = shape_fusion_prim;
  node->node_type_ = PrimType::PrimType_Inner_ShapeFusion;
  node->input_indices_.push_back(src_tensors_->size());
  src_tensors_->push_back(shape_fusion_matrix_tensor);
  return RET_OK;
}

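// Serialize a ShapeFusionMatrix into a float32 const tensor: a scalar matrix
// is emitted as a 1-D tensor, anything else as a 2-D [rows, cols] tensor.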
Tensor *ShapeFusionPass::BuildTensorFromShapeFusionMatrix(const ShapeFusionMatrix &shape_fusion_matrix) {
  MS_CHECK_TRUE_RET(!shape_fusion_matrix.shape_matrix.empty(), nullptr);
  std::vector<int> matrix_shape;
  if (shape_fusion_matrix.shape_matrix.size() != 1 || !shape_fusion_matrix.scalar) {
    matrix_shape.push_back(static_cast<int>(shape_fusion_matrix.shape_matrix.size()));
  }
  matrix_shape.push_back(static_cast<int>(shape_fusion_matrix.shape_matrix.front().size()));
  auto tensor = new (std::nothrow) Tensor(kNumberTypeFloat32, matrix_shape, NUM_OF_FORMAT, Category::CONST_TENSOR);
  MS_CHECK_TRUE_RET(tensor != nullptr, nullptr);
  auto matrix_data = tensor->MutableData();
  if (matrix_data == nullptr) {
    MS_LOG(ERROR) << "Mutable data failed for tensor: " << tensor->tensor_name();
    delete tensor;
    return nullptr;
  }
  for (size_t row = 0; row < shape_fusion_matrix.shape_matrix.size(); row++) {
    auto dst_data = reinterpret_cast<float *>(matrix_data) + row * shape_fusion_matrix.shape_matrix.front().size();
    memcpy(dst_data, shape_fusion_matrix.shape_matrix.at(row).data(),
           shape_fusion_matrix.shape_matrix.front().size() * sizeof(float));
  }
  return tensor;
}

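// Breadth-first over the consumers of the ShapeFusion node: fold every fusible
// post node into it and drop any fused output that no longer has users.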
int ShapeFusionPass::FusePostNodes(LiteGraph::Node *node, size_t subgraph_index) {
  // fuse arithmetic/concat/gather/squeeze/unsqueeze/shape/cast
  MS_ASSERT(node != nullptr);
  std::queue<LiteGraph::Node *> candidate_nodes;
  auto output_index = node->output_indices_.front();
  MS_CHECK_TRUE_RET(used_nodes_.find(output_index) != used_nodes_.end(), RET_ERROR);
  std::vector<uint32_t> visited_outputs;
  visited_outputs.push_back(output_index);
  for (auto out_node : used_nodes_[output_index]) {
    if (CheckCanFused(node, out_node, output_index, subgraph_index)) {
      candidate_nodes.push(out_node);
    }
  }

  while (!candidate_nodes.empty()) {
    auto output_node = candidate_nodes.front();
    candidate_nodes.pop();
    std::vector<uint32_t> used_outputs;
    if (DoFuse(node, output_node, &used_outputs, subgraph_index) != RET_OK) {
      MS_LOG(WARNING) << "Fused to shape fusion failed: " << output_node->name_;
      continue;
    }
    // remove unused input and output
    for (auto original_output : used_outputs) {
      MS_CHECK_TRUE_RET(used_nodes_.find(original_output) != used_nodes_.end(), RET_ERROR);
      if (used_nodes_[original_output].empty()) {
        auto remove_itr = std::find(node->output_indices_.begin(), node->output_indices_.end(), original_output);
        if (remove_itr == node->output_indices_.end()) {
          MS_LOG(ERROR) << "can not find original output";
          return RET_ERROR;
        }
        // Compute the offset before erasing: erase() invalidates remove_itr.
        auto output_offset = remove_itr - node->output_indices_.begin();
        node->output_indices_.erase(remove_itr);
        node->input_indices_.erase(node->input_indices_.begin() + output_offset + 1);
      }
    }

    for (auto idx : node->output_indices_) {
      MS_CHECK_TRUE_RET(used_nodes_.find(idx) != used_nodes_.end(), RET_ERROR);
      if (std::find(visited_outputs.begin(), visited_outputs.end(), idx) != visited_outputs.end()) {
        continue;
      }
      visited_outputs.push_back(idx);
      for (auto out_node : used_nodes_[idx]) {
        if (CheckCanFused(node, out_node, idx, subgraph_index)) {
          candidate_nodes.push(out_node);
        }
      }
    }
  }
  return RET_OK;
}

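// An arithmetic post node is fusible when its second operand is constant, or
// is itself an output of this ShapeFusion and the op is add/sub. Once a div
// has been fused (is_div_), only further div ops may follow it.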
bool ShapeFusionPass::CheckArithmetic(const LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node,
                                      uint32_t input_idx) {
  MS_ASSERT(shape_fusion != nullptr && post_node != nullptr);
  auto type = post_node->node_type_;
  if (is_div_ && type != schema::PrimitiveType_DivFusion) {
    // Cannot fuse add/sub/mul after a div (or a div between them): integer
    // division truncates, so reordering may turn an indivisible value into a
    // divisible one and change the result.
    return false;
  }
  is_div_ |= (type == schema::PrimitiveType_DivFusion);
  MS_CHECK_TRUE_RET(post_node->input_indices_.size() == kInputSize1, false);
  auto input1_index =
    post_node->input_indices_.at(0) == input_idx ? post_node->input_indices_.at(1) : post_node->input_indices_.at(0);
  auto tensor = src_tensors_->at(input1_index);
  MS_CHECK_TRUE_RET(tensor != nullptr, false);
  if (tensor->IsConst()) {
    return true;
  }
  auto shape_fusion_outputs = shape_fusion->output_indices_;
  auto fused_output =
    std::find(shape_fusion_outputs.begin(), shape_fusion_outputs.end(), input1_index) != shape_fusion_outputs.end();
  return fused_output && (type == schema::PrimitiveType_AddFusion || type == schema::PrimitiveType_SubFusion);
}

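// A post node qualifies for fusion only if it belongs to the current subgraph
// and is a supported type: cast to int32, arithmetic (see CheckArithmetic),
// concat whose inputs are all constant or already-fused outputs, or
// gather/squeeze/unsqueeze/shape.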
bool ShapeFusionPass::CheckCanFused(const LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node,
                                    uint32_t input_idx, size_t subgraph_index) {
  MS_ASSERT(shape_fusion != nullptr && post_node != nullptr);
  MS_CHECK_TRUE_RET(subgraph_index < lite_model_->graph_.sub_graphs_.size(), false);
  auto subgraph = lite_model_->graph_.sub_graphs_.at(subgraph_index);
  MS_CHECK_TRUE_RET(subgraph != nullptr, false);
  auto &subgraph_node_indices = subgraph->node_indices_;
  bool belong_to_current_subgraph = std::any_of(subgraph_node_indices.begin(), subgraph_node_indices.end(),
                                                [&](uint32_t idx) { return all_nodes_->at(idx) == post_node; });
  if (!belong_to_current_subgraph) {
    return false;
  }
  auto shape_fusion_outputs = shape_fusion->output_indices_;
  switch (post_node->node_type_) {
    case schema::PrimitiveType_Cast: {
      MS_CHECK_TRUE_RET(post_node->input_indices_.size() == kInputSize1, false);
      auto dst_type_tensor = src_tensors_->at(post_node->input_indices_.at(1));
      MS_CHECK_TRUE_RET(dst_type_tensor != nullptr && dst_type_tensor->data() != nullptr, false);
      auto data_type = reinterpret_cast<int *>(dst_type_tensor->data())[0];
      return data_type == kNumberTypeInt || data_type == kNumberTypeInt32;
    }
    case schema::PrimitiveType_AddFusion:
    case schema::PrimitiveType_SubFusion:
    case schema::PrimitiveType_MulFusion:
    case schema::PrimitiveType_DivFusion:
      return CheckArithmetic(shape_fusion, post_node, input_idx);
    case schema::PrimitiveType_Concat: {
      bool is_supported =
        std::all_of(post_node->input_indices_.begin(), post_node->input_indices_.end(), [&](uint32_t idx) {
          auto tensor = src_tensors_->at(idx);
          return tensor->IsConst() ||
                 std::find(shape_fusion_outputs.begin(), shape_fusion_outputs.end(), idx) != shape_fusion_outputs.end();
        });
      return is_supported;
    }
    case schema::PrimitiveType_Gather:
    case schema::PrimitiveType_Squeeze:
    case schema::PrimitiveType_Unsqueeze:
    case schema::PrimitiveType_Shape:
      return true;
    default:
      break;
  }
  return false;
}

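// Fold post_node into shape_fusion: derive the post node's fusion matrix,
// materialize it as an extra const input tensor, adopt post_node's output,
// then detach post_node from its subgraph and from the used-node bookkeeping
// of every fused input it consumed.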
int ShapeFusionPass::DoFuse(LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node,
                            std::vector<uint32_t> *input_indices, size_t subgraph_index) {
  MS_ASSERT(shape_fusion != nullptr && post_node != nullptr && input_indices != nullptr);
  ShapeFusionMatrix shape_fusion_matrix;
  auto type = post_node->node_type_;
  if (type == schema::PrimitiveType_AddFusion || type == schema::PrimitiveType_SubFusion ||
      type == schema::PrimitiveType_MulFusion || type == schema::PrimitiveType_DivFusion ||
      type == schema::PrimitiveType_Concat) {
    if (GenerateFusedShapeFusionMatrix(shape_fusion, post_node, input_indices, &shape_fusion_matrix) != RET_OK) {
      MS_LOG(WARNING) << "GenerateFusedShapeFusionMatrix failed while fusing op: " << post_node->name_;
      return RET_ERROR;
    }
  } else {
    auto input_index = post_node->input_indices_.front();
    MS_CHECK_TRUE_RET(shape_fusion_matrices_.find(input_index) != shape_fusion_matrices_.end(), RET_ERROR);
    shape_fusion_matrix = shape_fusion_matrices_[input_index];
    input_indices->push_back(input_index);
    if (UpdateShapeFusionMatrix(post_node, &shape_fusion_matrix) != RET_OK) {
      MS_LOG(WARNING) << "UpdateShapeFusionMatrix failed while fusing op: " << post_node->name_;
      return RET_ERROR;
    }
  }

  // generate matrix_tensor, and update input_indices and output_indices
  auto tensor = BuildTensorFromShapeFusionMatrix(shape_fusion_matrix);
  MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
  shape_fusion->input_indices_.push_back(src_tensors_->size());
  src_tensors_->push_back(tensor);
  auto output_index = post_node->output_indices_.front();
  shape_fusion->output_indices_.push_back(output_index);
  shape_fusion_matrices_[output_index] = shape_fusion_matrix;

  MS_CHECK_TRUE_RET(subgraph_index < lite_model_->graph_.sub_graphs_.size(), RET_ERROR);
  auto subgraph = lite_model_->graph_.sub_graphs_.at(subgraph_index);
  MS_CHECK_TRUE_RET(subgraph != nullptr, RET_ERROR);
  auto &subgraph_node_indices = subgraph->node_indices_;
  size_t node_index = std::find(all_nodes_->begin(), all_nodes_->end(), post_node) - all_nodes_->begin();
  MS_CHECK_TRUE_RET(node_index != all_nodes_->size(), RET_ERROR);
  auto indice_itr = std::find(subgraph_node_indices.begin(), subgraph_node_indices.end(), node_index);
  MS_CHECK_TRUE_RET(indice_itr != subgraph_node_indices.end(), RET_ERROR);
  subgraph_node_indices.erase(indice_itr);
  for (auto idx : *input_indices) {
    MS_CHECK_TRUE_RET(used_nodes_.find(idx) != used_nodes_.end(), RET_ERROR);
    auto &used_nodes = used_nodes_[idx];
    auto itr = std::find(used_nodes.begin(), used_nodes.end(), post_node);
    MS_CHECK_TRUE_RET(itr != used_nodes.end(), RET_ERROR);
    used_nodes.erase(itr);
  }
  return RET_OK;
}

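// Build the fusion matrix for a multi-input op (arithmetic or concat): walk
// the inputs in order, taking the stored matrix for already-fused inputs and
// expanding constant operands into matrices, then combine them with Append
// (concat) or Arithmetic (add/sub/mul/div).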
int ShapeFusionPass::GenerateFusedShapeFusionMatrix(LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node,
                                                    std::vector<uint32_t> *input_indices,
                                                    ShapeFusionMatrix *shape_fusion_matrix) {
  MS_ASSERT(shape_fusion != nullptr && post_node != nullptr && shape_fusion_matrix != nullptr);
  std::vector<uint32_t> fused_inputs;
  std::set<uint32_t> shape_fusion_outputs(shape_fusion->output_indices_.begin(), shape_fusion->output_indices_.end());
  std::set<uint32_t> post_inputs(post_node->input_indices_.begin(), post_node->input_indices_.end());
  std::set_intersection(post_inputs.begin(), post_inputs.end(), shape_fusion_outputs.begin(),
                        shape_fusion_outputs.end(), std::inserter(fused_inputs, fused_inputs.begin()));
  MS_CHECK_TRUE_RET(!fused_inputs.empty(), RET_ERROR);
  MS_CHECK_TRUE_RET(shape_fusion_matrices_.find(fused_inputs.at(0)) != shape_fusion_matrices_.end(), RET_ERROR);

  *shape_fusion_matrix = shape_fusion_matrices_[fused_inputs.at(0)];
  for (size_t i = 0; i < post_node->input_indices_.size(); i++) {
    ShapeFusionMatrix const_matrix;
    auto input_index = post_node->input_indices_.at(i);
    if (std::find(shape_fusion->output_indices_.begin(), shape_fusion->output_indices_.end(), input_index) !=
        shape_fusion->output_indices_.end()) {
      MS_CHECK_TRUE_RET(shape_fusion_matrices_.find(input_index) != shape_fusion_matrices_.end(), RET_ERROR);
      const_matrix = shape_fusion_matrices_[input_index];
      input_indices->push_back(input_index);
    } else {
      std::vector<size_t> shape = {shape_fusion_matrix->shape_matrix.size(),
                                   shape_fusion_matrix->shape_matrix.front().size()};
      auto const_tensor = src_tensors_->at(input_index);
      MS_CHECK_TRUE_RET(const_tensor != nullptr && const_tensor->data() != nullptr, RET_ERROR);
      if (GetFusionMatrixFromConstantTensor(const_tensor, shape, post_node->node_type_, &const_matrix) != RET_OK) {
        MS_LOG(ERROR) << "GetFusionMatrixFromConstantTensor failed.";
        return RET_ERROR;
      }
    }
    if (i == 0) {
      *shape_fusion_matrix = const_matrix;
      continue;
    }
    if (post_node->node_type_ == schema::PrimitiveType_Concat) {
      shape_fusion_matrix->Append(const_matrix);
    } else {
      shape_fusion_matrix->Arithmetic(const_matrix, static_cast<schema::PrimitiveType>(post_node->node_type_));
    }
  }
  return RET_OK;
}

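// Apply a single-input op to the stored fusion matrix in place: gather selects
// matrix rows by the indices tensor, squeeze/unsqueeze toggle the scalar flag,
// shape replaces the matrix with a single row encoding the rank, and cast
// (int32 only) is a no-op.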
int ShapeFusionPass::UpdateShapeFusionMatrix(const LiteGraph::Node *post_node, ShapeFusionMatrix *shape_fusion_matrix) {
  MS_ASSERT(post_node != nullptr && shape_fusion_matrix != nullptr);
  switch (post_node->node_type_) {
    case schema::PrimitiveType_Cast:
      break;
    case schema::PrimitiveType_Gather: {
      auto indices_tensor = src_tensors_->at(post_node->input_indices_.at(1));
      MS_CHECK_TRUE_RET(indices_tensor != nullptr && indices_tensor->data() != nullptr, RET_ERROR);
      MS_CHECK_TRUE_RET(
        indices_tensor->data_type() == kNumberTypeInt || indices_tensor->data_type() == kNumberTypeInt32, RET_ERROR);
      std::vector<int> indices(indices_tensor->ElementsNum());
      memcpy(indices.data(), indices_tensor->data(), indices_tensor->Size());
      if (shape_fusion_matrix->Gather(indices) != RET_OK) {
        MS_LOG(ERROR) << "Fuse gather failed.";
        return RET_ERROR;
      }
      shape_fusion_matrix->scalar = (indices_tensor->category() == CONST_SCALAR);
    } break;
    case schema::PrimitiveType_Squeeze: {
      MS_CHECK_TRUE_RET(!shape_fusion_matrix->scalar, RET_ERROR);
      shape_fusion_matrix->scalar = true;
    } break;
    case schema::PrimitiveType_Unsqueeze: {
      MS_CHECK_TRUE_RET(shape_fusion_matrix->scalar, RET_ERROR);
      shape_fusion_matrix->scalar = false;
    } break;
    case schema::PrimitiveType_Shape: {
      std::vector<float> shape_vec(shape_fusion_matrix->shape_matrix.front().size(), 0);
      shape_vec.at(shape_vec.size() - 1) = static_cast<float>(shape_fusion_matrix->shape_matrix.size());
      shape_fusion_matrix->shape_matrix = {shape_vec};
      shape_fusion_matrix->scalar = true;
    } break;
    default:
      MS_LOG(WARNING) << "Unsupported to fuse op: " << post_node->node_type_;
      return RET_ERROR;
  }
  return RET_OK;
}

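// Expand a constant int tensor into a [rows, cols] matrix block matching the
// current fusion matrix: for add/sub each value lands in the last (bias)
// column of a row, for mul/div each value fills a whole row so it scales that
// row elementwise, and for concat each value becomes an appended bias-only row.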
int ShapeFusionPass::GetFusionMatrixFromConstantTensor(const lite::Tensor *tensor, const std::vector<size_t> &shape,
                                                       int node_type, ShapeFusionMatrix *constant_matrix) {
  MS_ASSERT(tensor != nullptr && tensor->data() != nullptr && constant_matrix != nullptr);
  MS_CHECK_TRUE_RET(tensor->data_type() == kNumberTypeInt || tensor->data_type() == kNumberTypeInt32, RET_ERROR);
  std::vector<int> value(tensor->ElementsNum());
  memcpy(value.data(), tensor->data(), tensor->Size());
  std::vector<std::vector<float>> shape_matrix;
  switch (node_type) {
    case schema::PrimitiveType_AddFusion:
    case schema::PrimitiveType_SubFusion: {
      std::vector<float> row_vec(shape.at(1));
      if (value.size() == shape.at(0)) {
        std::transform(value.begin(), value.end(), std::back_inserter(shape_matrix), [&row_vec](int ele) {
          row_vec.at(row_vec.size() - 1) = static_cast<float>(ele);
          return row_vec;
        });
      } else {
        MS_CHECK_TRUE_RET(value.size() == 1, RET_ERROR);
        row_vec.at(row_vec.size() - 1) = static_cast<float>(value.at(0));
        shape_matrix = std::vector<std::vector<float>>(shape.at(0), row_vec);
      }
    } break;
    case schema::PrimitiveType_MulFusion:
    case schema::PrimitiveType_DivFusion: {
      if (value.size() == shape.at(0)) {
        std::transform(value.begin(), value.end(), std::back_inserter(shape_matrix), [&shape](int ele) {
          std::vector<float> row_vec(shape.at(1), static_cast<float>(ele));
          return row_vec;
        });
      } else {
        MS_CHECK_TRUE_RET(value.size() == 1, RET_ERROR);
        std::vector<float> row_vec(shape.at(1), static_cast<float>(value.at(0)));
        shape_matrix = std::vector<std::vector<float>>(shape.at(0), row_vec);
      }
    } break;
    case schema::PrimitiveType_Concat: {
      std::vector<float> row_vec(shape.at(1));
      std::transform(value.begin(), value.end(), std::back_inserter(shape_matrix), [&row_vec](int ele) {
        row_vec.at(row_vec.size() - 1) = static_cast<float>(ele);
        return row_vec;
      });
    } break;
    default:
      MS_LOG(ERROR) << "Unsupported to generate constant shape matrix for node type: " << node_type;
      return RET_ERROR;
  }
  constant_matrix->shape_matrix = shape_matrix;
  return RET_OK;
}
}  // namespace mindspore::lite
#endif