• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_SHAPE_FUSION_PASS_H_
18 #define MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_SHAPE_FUSION_PASS_H_
19 
20 #include <map>
21 #include <vector>
22 #include <algorithm>
23 #include "src/litert/lite_model.h"
24 #include "src/litert/inner_context.h"
25 #include "src/common/tensor_util.h"
26 #include "schema/ops_generated.h"
27 #include "schema/model_generated.h"
28 
29 namespace mindspore::lite {
30 #ifndef RUNTIME_PASS_CLIP
31 struct ShapeFusionMatrix {
ShapeFusionMatrixShapeFusionMatrix32   ShapeFusionMatrix() {}
ShapeFusionMatrixShapeFusionMatrix33   explicit ShapeFusionMatrix(size_t dim) {
34     for (size_t i = 0; i < dim; ++i) {
35       std::vector<float> row;
36       for (size_t j = 0; j < dim; ++j) {
37         row.push_back(i == j ? 1 : 0);
38       }
39       row.push_back(0);
40       shape_matrix.push_back(row);
41     }
42   }
43 
GatherShapeFusionMatrix44   int Gather(const std::vector<int> &indices) {
45     auto src_matrix = shape_matrix;
46     shape_matrix.clear();
47     for (auto idx : indices) {
48       idx = idx >= 0 ? idx : idx + static_cast<int>(src_matrix.size());
49       MS_CHECK_TRUE_RET(idx >= 0 && idx < static_cast<int>(src_matrix.size()), RET_ERROR);
50       shape_matrix.push_back(src_matrix.at(static_cast<size_t>(idx)));
51     }
52     return RET_OK;
53   }
54 
AppendShapeFusionMatrix55   void Append(const ShapeFusionMatrix &other) {
56     for (auto row : other.shape_matrix) {
57       shape_matrix.push_back(row);
58     }
59   }
60 
ArithmeticShapeFusionMatrix61   void Arithmetic(const ShapeFusionMatrix &other, schema::PrimitiveType type) {
62     for (size_t i = 0; i < shape_matrix.size(); i++) {
63       for (size_t j = 0; j < shape_matrix.front().size(); j++) {
64         switch (type) {
65           case schema::PrimitiveType_AddFusion:
66             shape_matrix[i][j] += other.shape_matrix[i][j];
67             break;
68           case schema::PrimitiveType_SubFusion:
69             shape_matrix[i][j] -= other.shape_matrix[i][j];
70             break;
71           case schema::PrimitiveType_MulFusion:
72             shape_matrix[i][j] *= other.shape_matrix[i][j];
73             break;
74           case schema::PrimitiveType_DivFusion:
75             shape_matrix[i][j] /= other.shape_matrix[i][j];
76             break;
77           default:
78             break;
79         }
80       }
81     }
82   }
83   std::vector<std::vector<float>> shape_matrix;
84   bool scalar = false;
85 };
86 #endif
87 
88 class ShapeFusionPass {
89  public:
ShapeFusionPass(InnerContext * ctx,LiteModel * model,std::vector<lite::Tensor * > * src_tensors)90   ShapeFusionPass(InnerContext *ctx, LiteModel *model, std::vector<lite::Tensor *> *src_tensors)
91       : context_(ctx), lite_model_(model), all_nodes_(&(model->graph_.all_nodes_)), src_tensors_(src_tensors) {
92     MS_ASSERT(model != nullptr && src_tensors != nullptr);
93     for (auto node : model->graph_.all_nodes_) {
94       for (auto input_idx : node->input_indices_) {
95         used_nodes_[input_idx].push_back(node);
96       }
97     }
98   }
99   ~ShapeFusionPass() = default;
100 
Run(LiteGraph::Node * node,size_t subgraph_index)101   void Run(LiteGraph::Node *node, size_t subgraph_index) {
102 #ifndef RUNTIME_PASS_CLIP
103     // gpu does not support to run fused shape op.
104     if (context_->IsDeviceTypeEnabled(DeviceType::DT_GPU)) {
105       return;
106     }
107     if (ConvertToShapeFusion(node) != RET_OK) {
108       MS_LOG(INFO) << "Convert to built-in shape failed: " << node->name_;
109     } else if (FusePostNodes(node, subgraph_index) != RET_OK) {
110       MS_LOG(INFO) << "Fused to built-in shape failed: " << node->name_;
111     }
112     std::transform(node->output_indices_.begin(), node->output_indices_.end(),
113                    std::back_inserter(shape_fusion_outputs_),
114                    [&](uint32_t idx) { return this->src_tensors_->at(idx); });
115 #endif
116   }
117 
StoreStateAndReset()118   void StoreStateAndReset() {
119 #ifndef RUNTIME_PASS_CLIP
120     std::vector<lite::Tensor *> shape_fusion_outputs = shape_fusion_outputs_;
121     shape_fusion_outputs_.clear();
122     for (auto output : shape_fusion_outputs) {
123       if (output->IsConst()) {
124         shape_fusion_outputs_.push_back(output);
125         datas_.push_back(output->data());
126         output->set_data(nullptr);
127         output->set_category(VAR);
128       }
129     }
130 #endif
131   }
132 
RestoreState()133   void RestoreState() {
134 #ifndef RUNTIME_PASS_CLIP
135     size_t count = std::min(shape_fusion_outputs_.size(), datas_.size());
136     for (size_t i = 0; i < count; ++i) {
137       shape_fusion_outputs_[i]->set_data(datas_[i]);
138       shape_fusion_outputs_[i]->set_category(CONST_TENSOR);
139     }
140 #endif
141   }
142 
143  private:
144 #ifndef RUNTIME_PASS_CLIP
145   int ConvertToShapeFusion(LiteGraph::Node *node);
146   int FusePostNodes(LiteGraph::Node *node, size_t subgraph_index);
147   Tensor *BuildTensorFromShapeFusionMatrix(const ShapeFusionMatrix &shape_fusion_matrix);
148   bool CheckArithmetic(const LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node, uint32_t input_idx);
149   bool CheckCanFused(const LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node, uint32_t input_idx,
150                      size_t subgraph_index);
151   int DoFuse(LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node, std::vector<uint32_t> *input_indices,
152              size_t subgraph_index);
153   int GenerateFusedShapeFusionMatrix(LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node,
154                                      std::vector<uint32_t> *input_indices, ShapeFusionMatrix *shape_fusion_matrix);
155   int UpdateShapeFusionMatrix(const LiteGraph::Node *post_node, ShapeFusionMatrix *shape_fusion_matrix);
156   int GetFusionMatrixFromConstantTensor(const lite::Tensor *tensor, const std::vector<size_t> &shape, int node_type,
157                                         ShapeFusionMatrix *constant_matrix);
158 
159  private:
160   std::map<uint32_t, ShapeFusionMatrix> shape_fusion_matrices_;
161   std::vector<lite::Tensor *> shape_fusion_outputs_;
162   std::vector<void *> datas_;
163   int is_div_ = 0;
164 #endif
165   InnerContext *context_ = nullptr;
166   LiteModel *lite_model_ = nullptr;
167   const std::vector<LiteGraph::Node *> *all_nodes_ = nullptr;
168   std::vector<lite::Tensor *> *src_tensors_ = nullptr;
169   std::map<uint32_t, std::vector<LiteGraph::Node *>> used_nodes_;
170 };
171 }  // namespace mindspore::lite
172 #endif  // MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_SHAPE_FUSION_PASS_H_
173