/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef RUNTIME_PASS_CLIP
#include "src/litert/runtime_shape_fusion_pass.h"
#include <set>
#include <queue>
#include <algorithm>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"

namespace mindspore::lite {
namespace {
constexpr size_t kInitialSize = 1024;
}  // namespace
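// Rewrites a Shape node in place as a custom "ShapeFusion" node: builds a Custom primitive for it,
// records an initial shape-fusion matrix keyed by the node's first output, and appends that matrix
// to the graph as an extra constant input tensor.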
int ShapeFusionPass::ConvertToShapeFusion(LiteGraph::Node *node) {
  MS_ASSERT(node != nullptr);
  auto input_tensor = src_tensors_->at(node->input_indices_.front());
  MS_CHECK_TRUE_RET(input_tensor != nullptr, RET_ERROR);
  auto shape = input_tensor->shape();
  if (shape.empty() || std::find(shape.begin(), shape.end(), -1) != shape.end()) {
    MS_LOG(INFO) << "The input shape is invalid.";
    return RET_ERROR;
  }

  flatbuffers::FlatBufferBuilder fbb(kInitialSize);
  auto val_offset = schema::CreateCustomDirect(fbb, "ShapeFusion");
  auto prim_offset =
    schema::CreatePrimitive(fbb, static_cast<schema::PrimitiveType>(PrimType::PrimType_Custom), val_offset.o);
  fbb.Finish(prim_offset);
  void *prim = malloc(fbb.GetSize());
  if (prim == nullptr) {
    MS_LOG(ERROR) << "malloc primitive failed.";
    return RET_ERROR;
  }
  memcpy(prim, fbb.GetBufferPointer(), fbb.GetSize());
  fbb.Clear();

  auto shape_fusion_prim = flatbuffers::GetRoot<schema::Primitive>(prim);
  if (shape_fusion_prim == nullptr) {
    free(prim);
    MS_LOG(ERROR) << "shape_fusion_prim is nullptr";
    return RET_ERROR;
  }
  ShapeFusionMatrix shape_fusion_matrix(shape.size());
  if (node->output_indices_.empty()) {
    free(prim);
    MS_LOG(ERROR) << "node->output_indices_ is empty";
    return RET_ERROR;
  }
  shape_fusion_matrices_[node->output_indices_.front()] = shape_fusion_matrix;
  auto shape_fusion_matrix_tensor = BuildTensorFromShapeFusionMatrix(shape_fusion_matrix);
  if (shape_fusion_matrix_tensor == nullptr) {
    free(prim);
    MS_LOG(ERROR) << "shape_fusion_matrix_tensor is nullptr";
    return RET_ERROR;
  }
  // Hand ownership of the primitive buffer to the model only after all checks have passed,
  // so the error paths above can free it without leaving a dangling entry in node_bufs_.
  lite_model_->node_bufs_.push_back(prim);

  node->name_ += "_fusion";
  node->primitive_ = shape_fusion_prim;
  node->node_type_ = PrimType::PrimType_Inner_ShapeFusion;
  node->input_indices_.push_back(src_tensors_->size());
  src_tensors_->push_back(shape_fusion_matrix_tensor);
  return RET_OK;
}

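// Packs a ShapeFusionMatrix into a float32 constant tensor with shape [rows, cols]
// (just [cols] for a single scalar row), copying the matrix row by row.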
Tensor *ShapeFusionPass::BuildTensorFromShapeFusionMatrix(const ShapeFusionMatrix &shape_fusion_matrix) {
  MS_CHECK_TRUE_RET(!shape_fusion_matrix.shape_matrix.empty(), nullptr);
  std::vector<int> matrix_shape;
  if (shape_fusion_matrix.shape_matrix.size() != 1 || !shape_fusion_matrix.scalar) {
    matrix_shape.push_back(static_cast<int>(shape_fusion_matrix.shape_matrix.size()));
  }
  matrix_shape.push_back(static_cast<int>(shape_fusion_matrix.shape_matrix.front().size()));
  auto tensor = new (std::nothrow) Tensor(kNumberTypeFloat32, matrix_shape, NUM_OF_FORMAT, Category::CONST_TENSOR);
  MS_CHECK_TRUE_RET(tensor != nullptr, nullptr);
  auto matrix_data = tensor->MutableData();
  if (matrix_data == nullptr) {
    MS_LOG(ERROR) << "Mutable data failed for tensor: " << tensor->tensor_name();
    delete tensor;
    return nullptr;
  }
  for (size_t row = 0; row < shape_fusion_matrix.shape_matrix.size(); row++) {
    auto dst_data = reinterpret_cast<float *>(matrix_data) + row * shape_fusion_matrix.shape_matrix.front().size();
    memcpy(dst_data, shape_fusion_matrix.shape_matrix.at(row).data(),
           shape_fusion_matrix.shape_matrix.front().size() * sizeof(float));
  }
  return tensor;
}

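// Starting from a converted ShapeFusion node, walks its consumers breadth-first and folds every
// fusable post node into the fusion, pruning fusion outputs that become unused and re-scanning the
// newly added outputs for further fusion candidates.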
int ShapeFusionPass::FusePostNodes(LiteGraph::Node *node, size_t subgraph_index) {
  // fuse arithmetic/concat/gather/squeeze/unsqueeze/shape/cast
  MS_ASSERT(node != nullptr);
  std::queue<LiteGraph::Node *> candidate_nodes;
  auto output_index = node->output_indices_.front();
  MS_CHECK_TRUE_RET(used_nodes_.find(output_index) != used_nodes_.end(), RET_ERROR);
  std::vector<uint32_t> visited_outputs;
  for (auto out_node : used_nodes_[output_index]) {
    if (CheckCanFused(node, out_node, output_index, subgraph_index)) {
      candidate_nodes.push(out_node);
    }
    visited_outputs.push_back(output_index);
  }

  while (!candidate_nodes.empty()) {
    auto output_node = candidate_nodes.front();
    candidate_nodes.pop();
    std::vector<uint32_t> used_outputs;
    if (DoFuse(node, output_node, &used_outputs, subgraph_index) != RET_OK) {
      MS_LOG(WARNING) << "Fused to shape fusion failed: " << output_node->name_;
      continue;
    }
    // remove unused input and output
    for (auto original_output : used_outputs) {
      MS_CHECK_TRUE_RET(used_nodes_.find(original_output) != used_nodes_.end(), RET_ERROR);
      if (used_nodes_[original_output].empty()) {
        auto remove_itr = std::find(node->output_indices_.begin(), node->output_indices_.end(), original_output);
        if (remove_itr == node->output_indices_.end()) {
          MS_LOG(ERROR) << "can not find original output";
          return RET_ERROR;
        }
        // Record the offset before erasing: erase() invalidates remove_itr.
        auto output_offset = remove_itr - node->output_indices_.begin();
        node->output_indices_.erase(remove_itr);
        node->input_indices_.erase(node->input_indices_.begin() + output_offset + 1);
      }
    }

    for (auto idx : node->output_indices_) {
      MS_CHECK_TRUE_RET(used_nodes_.find(idx) != used_nodes_.end(), RET_ERROR);
      if (std::find(visited_outputs.begin(), visited_outputs.end(), idx) != visited_outputs.end()) {
        continue;
      }
      visited_outputs.push_back(idx);
      for (auto out_node : used_nodes_[idx]) {
        if (CheckCanFused(node, out_node, idx, subgraph_index)) {
          candidate_nodes.push(out_node);
        }
      }
    }
  }
  return RET_OK;
}

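// Decides whether an Add/Sub/Mul/Div consumer can be folded into the shape fusion. Once a Div has
// been fused, only further Div consumers are accepted; the other operand must be a constant tensor
// or, for Add/Sub, another output of the shape fusion.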
bool ShapeFusionPass::CheckArithmetic(const LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node,
                                      uint32_t input_idx) {
  MS_ASSERT(shape_fusion != nullptr && post_node != nullptr);
  auto type = post_node->node_type_;
  if (is_div_ && type != schema::PrimitiveType_DivFusion) {
    // Do not fuse add/sub/mul with div (in either order), because that may change an indivisible
    // value into a divisible one.
    return false;
  }
  is_div_ |= (type == schema::PrimitiveType_DivFusion);
  MS_CHECK_TRUE_RET(post_node->input_indices_.size() == kInputSize1, false);
  auto input1_index =
    post_node->input_indices_.at(0) == input_idx ? post_node->input_indices_.at(1) : post_node->input_indices_.at(0);
  auto tensor = src_tensors_->at(input1_index);
  MS_CHECK_TRUE_RET(tensor != nullptr, false);
  if (tensor->IsConst()) {
    return true;
  }
  auto shape_fusion_outputs = shape_fusion->output_indices_;
  auto fused_output =
    std::find(shape_fusion_outputs.begin(), shape_fusion_outputs.end(), input1_index) != shape_fusion_outputs.end();
  return fused_output && (type == schema::PrimitiveType_AddFusion || type == schema::PrimitiveType_SubFusion);
}

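// Checks whether post_node (a consumer of the shape fusion's output at input_idx) can be folded in:
// it must belong to the current subgraph, and its type must be one of the supported ops with
// fusable operands (Cast whose destination type is int32, arithmetic per CheckArithmetic, Concat
// over constant or already-fused inputs, or Gather/Squeeze/Unsqueeze/Shape).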
bool ShapeFusionPass::CheckCanFused(const LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node,
                                    uint32_t input_idx, size_t subgraph_index) {
  MS_ASSERT(shape_fusion != nullptr && post_node != nullptr);
  MS_CHECK_TRUE_RET(subgraph_index < lite_model_->graph_.sub_graphs_.size(), false);
  auto subgraph = lite_model_->graph_.sub_graphs_.at(subgraph_index);
  MS_CHECK_TRUE_RET(subgraph != nullptr, false);
  auto &subgraph_node_indices = subgraph->node_indices_;
  bool belong_to_current_subgraph = std::any_of(subgraph_node_indices.begin(), subgraph_node_indices.end(),
                                                [&](uint32_t idx) { return all_nodes_->at(idx) == post_node; });
  if (!belong_to_current_subgraph) {
    return false;
  }
  auto shape_fusion_outputs = shape_fusion->output_indices_;
  switch (post_node->node_type_) {
    case schema::PrimitiveType_Cast: {
      MS_CHECK_TRUE_RET(post_node->input_indices_.size() == kInputSize1, false);
      auto dst_type_tensor = src_tensors_->at(post_node->input_indices_.at(1));
      MS_CHECK_TRUE_RET(dst_type_tensor != nullptr && dst_type_tensor->data() != nullptr, false);
      auto data_type = reinterpret_cast<int *>(dst_type_tensor->data())[0];
      return data_type == kNumberTypeInt || data_type == kNumberTypeInt32;
    }
    case schema::PrimitiveType_AddFusion:
    case schema::PrimitiveType_SubFusion:
    case schema::PrimitiveType_MulFusion:
    case schema::PrimitiveType_DivFusion:
      return CheckArithmetic(shape_fusion, post_node, input_idx);
    case schema::PrimitiveType_Concat: {
      bool is_supported =
        std::all_of(post_node->input_indices_.begin(), post_node->input_indices_.end(), [&](uint32_t idx) {
          auto tensor = src_tensors_->at(idx);
          return tensor->IsConst() ||
                 std::find(shape_fusion_outputs.begin(), shape_fusion_outputs.end(), idx) != shape_fusion_outputs.end();
        });
      return is_supported;
    }
    case schema::PrimitiveType_Gather:
    case schema::PrimitiveType_Squeeze:
    case schema::PrimitiveType_Unsqueeze:
    case schema::PrimitiveType_Shape:
      return true;
    default:
      break;
  }
  return false;
}

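// Folds post_node into the shape fusion: derives the resulting shape-fusion matrix (combining
// per-input matrices for arithmetic/Concat, or transforming the single-input matrix otherwise),
// appends it as a new constant input, registers post_node's output as a fusion output, removes
// post_node from the subgraph, and detaches it from the used-node lists of the inputs it consumed.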
int ShapeFusionPass::DoFuse(LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node,
                            std::vector<uint32_t> *input_indices, size_t subgraph_index) {
  MS_ASSERT(shape_fusion != nullptr && post_node != nullptr && input_indices != nullptr);
  ShapeFusionMatrix shape_fusion_matrix;
  auto type = post_node->node_type_;
  if (type == schema::PrimitiveType_AddFusion || type == schema::PrimitiveType_SubFusion ||
      type == schema::PrimitiveType_MulFusion || type == schema::PrimitiveType_DivFusion ||
      type == schema::PrimitiveType_Concat) {
    if (GenerateFusedShapeFusionMatrix(shape_fusion, post_node, input_indices, &shape_fusion_matrix) != RET_OK) {
      MS_LOG(WARNING) << "GenerateFusedShapeFusionMatrix failed while fusing op: " << post_node->name_;
      return RET_ERROR;
    }
  } else {
    auto input_index = post_node->input_indices_.front();
    MS_CHECK_TRUE_RET(shape_fusion_matrices_.find(input_index) != shape_fusion_matrices_.end(), RET_ERROR);
    shape_fusion_matrix = shape_fusion_matrices_[input_index];
    input_indices->push_back(input_index);
    if (UpdateShapeFusionMatrix(post_node, &shape_fusion_matrix) != RET_OK) {
      MS_LOG(WARNING) << "UpdateShapeFusionMatrix failed while fusing op: " << post_node->name_;
      return RET_ERROR;
    }
  }

  // generate matrix_tensor, and update input_indices and output_indices
  auto tensor = BuildTensorFromShapeFusionMatrix(shape_fusion_matrix);
  MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
  shape_fusion->input_indices_.push_back(src_tensors_->size());
  src_tensors_->push_back(tensor);
  auto output_index = post_node->output_indices_.front();
  shape_fusion->output_indices_.push_back(output_index);
  shape_fusion_matrices_[output_index] = shape_fusion_matrix;

  MS_CHECK_TRUE_RET(subgraph_index < lite_model_->graph_.sub_graphs_.size(), RET_ERROR);
  auto subgraph = lite_model_->graph_.sub_graphs_.at(subgraph_index);
  MS_CHECK_TRUE_RET(subgraph != nullptr, RET_ERROR);
  auto &subgraph_node_indices = subgraph->node_indices_;
  size_t node_index = std::find(all_nodes_->begin(), all_nodes_->end(), post_node) - all_nodes_->begin();
  MS_CHECK_TRUE_RET(node_index != all_nodes_->size(), RET_ERROR);
  auto indice_itr = std::find(subgraph_node_indices.begin(), subgraph_node_indices.end(), node_index);
  MS_CHECK_TRUE_RET(indice_itr != subgraph_node_indices.end(), RET_ERROR);
  subgraph_node_indices.erase(indice_itr);
  for (auto idx : *input_indices) {
    MS_CHECK_TRUE_RET(used_nodes_.find(idx) != used_nodes_.end(), RET_ERROR);
    auto &used_nodes = used_nodes_[idx];
    auto itr = std::find(used_nodes.begin(), used_nodes.end(), post_node);
    MS_CHECK_TRUE_RET(itr != used_nodes.end(), RET_ERROR);
    used_nodes.erase(itr);
  }
  return RET_OK;
}

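// Builds the fused matrix for an arithmetic or Concat post node: each input contributes either the
// matrix already recorded for a fusion output or a constant matrix derived from its tensor, and the
// per-input matrices are combined via Append (Concat) or Arithmetic (Add/Sub/Mul/Div).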
int ShapeFusionPass::GenerateFusedShapeFusionMatrix(LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node,
                                                    std::vector<uint32_t> *input_indices,
                                                    ShapeFusionMatrix *shape_fusion_matrix) {
  MS_ASSERT(shape_fusion != nullptr && post_node != nullptr && shape_fusion_matrix != nullptr);
  std::vector<uint32_t> fused_inputs;
  std::set<uint32_t> shape_fusion_outputs(shape_fusion->output_indices_.begin(), shape_fusion->output_indices_.end());
  std::set<uint32_t> post_inputs(post_node->input_indices_.begin(), post_node->input_indices_.end());
  std::set_intersection(post_inputs.begin(), post_inputs.end(), shape_fusion_outputs.begin(),
                        shape_fusion_outputs.end(), std::inserter(fused_inputs, fused_inputs.begin()));
  MS_CHECK_TRUE_RET(!fused_inputs.empty(), RET_ERROR);
  MS_CHECK_TRUE_RET(shape_fusion_matrices_.find(fused_inputs.at(0)) != shape_fusion_matrices_.end(), RET_ERROR);

  *shape_fusion_matrix = shape_fusion_matrices_[fused_inputs.at(0)];
  for (size_t i = 0; i < post_node->input_indices_.size(); i++) {
    ShapeFusionMatrix const_matrix;
    auto input_index = post_node->input_indices_.at(i);
    if (std::find(shape_fusion->output_indices_.begin(), shape_fusion->output_indices_.end(), input_index) !=
        shape_fusion->output_indices_.end()) {
      MS_CHECK_TRUE_RET(shape_fusion_matrices_.find(input_index) != shape_fusion_matrices_.end(), RET_ERROR);
      const_matrix = shape_fusion_matrices_[input_index];
      input_indices->push_back(input_index);
    } else {
      std::vector<size_t> shape = {shape_fusion_matrix->shape_matrix.size(),
                                   shape_fusion_matrix->shape_matrix.front().size()};
      auto const_tensor = src_tensors_->at(input_index);
      MS_CHECK_TRUE_RET(const_tensor != nullptr && const_tensor->data() != nullptr, RET_ERROR);
      if (GetFusionMatrixFromConstantTensor(const_tensor, shape, post_node->node_type_, &const_matrix) != RET_OK) {
        MS_LOG(ERROR) << "GetFusionMatrixFromConstantTensor failed.";
        return RET_ERROR;
      }
    }
    if (i == 0) {
      *shape_fusion_matrix = const_matrix;
      continue;
    }
    if (post_node->node_type_ == schema::PrimitiveType_Concat) {
      shape_fusion_matrix->Append(const_matrix);
    } else {
      shape_fusion_matrix->Arithmetic(const_matrix, static_cast<schema::PrimitiveType>(post_node->node_type_));
    }
  }
  return RET_OK;
}

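// Transforms the matrix of a single-input post node: Cast is a no-op, Gather selects rows by the
// constant indices, Squeeze/Unsqueeze toggle the scalar flag, and Shape collapses the matrix to a
// single row whose constant term is the number of rows (the rank).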
int ShapeFusionPass::UpdateShapeFusionMatrix(const LiteGraph::Node *post_node, ShapeFusionMatrix *shape_fusion_matrix) {
  MS_ASSERT(post_node != nullptr && shape_fusion_matrix != nullptr);
  switch (post_node->node_type_) {
    case schema::PrimitiveType_Cast:
      break;
    case schema::PrimitiveType_Gather: {
      auto indices_tensor = src_tensors_->at(post_node->input_indices_.at(1));
      MS_CHECK_TRUE_RET(indices_tensor != nullptr && indices_tensor->data() != nullptr, RET_ERROR);
      MS_CHECK_TRUE_RET(
        indices_tensor->data_type() == kNumberTypeInt || indices_tensor->data_type() == kNumberTypeInt32, RET_ERROR);
      std::vector<int> indices(indices_tensor->ElementsNum());
      memcpy(indices.data(), indices_tensor->data(), indices_tensor->Size());
      if (shape_fusion_matrix->Gather(indices) != RET_OK) {
        MS_LOG(ERROR) << "Fuse gather failed.";
        return RET_ERROR;
      }
      shape_fusion_matrix->scalar = (indices_tensor->category() == CONST_SCALAR);
    } break;
    case schema::PrimitiveType_Squeeze: {
      MS_CHECK_TRUE_RET(shape_fusion_matrix->scalar == false, RET_ERROR);
      shape_fusion_matrix->scalar = true;
    } break;
    case schema::PrimitiveType_Unsqueeze: {
      MS_CHECK_TRUE_RET(shape_fusion_matrix->scalar == true, RET_ERROR);
      shape_fusion_matrix->scalar = false;
    } break;
    case schema::PrimitiveType_Shape: {
      std::vector<float> shape_vec(shape_fusion_matrix->shape_matrix.front().size(), 0);
      shape_vec.at(shape_vec.size() - 1) = static_cast<float>(shape_fusion_matrix->shape_matrix.size());
      shape_fusion_matrix->shape_matrix = {shape_vec};
      shape_fusion_matrix->scalar = true;
    } break;
    default:
      MS_LOG(WARNING) << "Unsupported op for shape fusion: " << post_node->node_type_;
      return RET_ERROR;
  }
  return RET_OK;
}

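// Expands a constant int32 tensor into a shape-fusion matrix with the given [rows, cols] shape.
// For Add/Sub and Concat, each constant lands in the last (constant-term) column of its row;
// for Mul/Div, each constant fills every column of its row.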
int ShapeFusionPass::GetFusionMatrixFromConstantTensor(const lite::Tensor *tensor, const std::vector<size_t> &shape,
                                                       int node_type, ShapeFusionMatrix *constant_matrix) {
  MS_ASSERT(tensor != nullptr && tensor->data() != nullptr && constant_matrix != nullptr);
  MS_CHECK_TRUE_RET(tensor->data_type() == kNumberTypeInt || tensor->data_type() == kNumberTypeInt32, RET_ERROR);
  std::vector<int> value(tensor->ElementsNum());
  memcpy(value.data(), tensor->data(), tensor->Size());
  std::vector<std::vector<float>> shape_matrix;
  switch (node_type) {
    case schema::PrimitiveType_AddFusion:
    case schema::PrimitiveType_SubFusion: {
      std::vector<float> row_vec(shape.at(1));
      if (value.size() == shape.at(0)) {
        std::transform(value.begin(), value.end(), std::back_inserter(shape_matrix), [&row_vec](int ele) {
          row_vec.at(row_vec.size() - 1) = static_cast<float>(ele);
          return row_vec;
        });
      } else {
        MS_CHECK_TRUE_RET(value.size() == 1, RET_ERROR);
        row_vec.at(row_vec.size() - 1) = static_cast<float>(value.at(0));
        shape_matrix = std::vector<std::vector<float>>(shape.at(0), row_vec);
      }
    } break;
    case schema::PrimitiveType_MulFusion:
    case schema::PrimitiveType_DivFusion: {
      if (value.size() == shape.at(0)) {
        std::transform(value.begin(), value.end(), std::back_inserter(shape_matrix), [&shape](int ele) {
          std::vector<float> row_vec(shape.at(1), static_cast<float>(ele));
          return row_vec;
        });
      } else {
        MS_CHECK_TRUE_RET(value.size() == 1, RET_ERROR);
        std::vector<float> row_vec(shape.at(1), static_cast<float>(value.at(0)));
        shape_matrix = std::vector<std::vector<float>>(shape.at(0), row_vec);
      }
    } break;
    case schema::PrimitiveType_Concat: {
      std::vector<float> row_vec(shape.at(1));
      std::transform(value.begin(), value.end(), std::back_inserter(shape_matrix), [&row_vec](int ele) {
        row_vec.at(row_vec.size() - 1) = static_cast<float>(ele);
        return row_vec;
      });
    } break;
    default:
      MS_LOG(ERROR) << "Unsupported node type for generating a constant shape matrix: " << node_type;
      return RET_ERROR;
  }
  constant_matrix->shape_matrix = shape_matrix;
  return RET_OK;
}
}  // namespace mindspore::lite
#endif