1 /** 2 * Copyright 2021-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_SHAPE_FUSION_PASS_H_ 18 #define MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_SHAPE_FUSION_PASS_H_ 19 20 #include <map> 21 #include <vector> 22 #include <algorithm> 23 #include "src/litert/lite_model.h" 24 #include "src/litert/inner_context.h" 25 #include "src/common/tensor_util.h" 26 #include "schema/ops_generated.h" 27 #include "schema/model_generated.h" 28 29 namespace mindspore::lite { 30 #ifndef RUNTIME_PASS_CLIP 31 struct ShapeFusionMatrix { ShapeFusionMatrixShapeFusionMatrix32 ShapeFusionMatrix() {} ShapeFusionMatrixShapeFusionMatrix33 explicit ShapeFusionMatrix(size_t dim) { 34 for (size_t i = 0; i < dim; ++i) { 35 std::vector<float> row; 36 for (size_t j = 0; j < dim; ++j) { 37 row.push_back(i == j ? 1 : 0); 38 } 39 row.push_back(0); 40 shape_matrix.push_back(row); 41 } 42 } 43 GatherShapeFusionMatrix44 int Gather(const std::vector<int> &indices) { 45 auto src_matrix = shape_matrix; 46 shape_matrix.clear(); 47 for (auto idx : indices) { 48 idx = idx >= 0 ? idx : idx + static_cast<int>(src_matrix.size()); 49 MS_CHECK_TRUE_RET(idx >= 0 && idx < static_cast<int>(src_matrix.size()), RET_ERROR); 50 shape_matrix.push_back(src_matrix.at(static_cast<size_t>(idx))); 51 } 52 return RET_OK; 53 } 54 AppendShapeFusionMatrix55 void Append(const ShapeFusionMatrix &other) { 56 for (auto row : other.shape_matrix) { 57 shape_matrix.push_back(row); 58 } 59 } 60 ArithmeticShapeFusionMatrix61 void Arithmetic(const ShapeFusionMatrix &other, schema::PrimitiveType type) { 62 for (size_t i = 0; i < shape_matrix.size(); i++) { 63 for (size_t j = 0; j < shape_matrix.front().size(); j++) { 64 switch (type) { 65 case schema::PrimitiveType_AddFusion: 66 shape_matrix[i][j] += other.shape_matrix[i][j]; 67 break; 68 case schema::PrimitiveType_SubFusion: 69 shape_matrix[i][j] -= other.shape_matrix[i][j]; 70 break; 71 case schema::PrimitiveType_MulFusion: 72 shape_matrix[i][j] *= other.shape_matrix[i][j]; 73 break; 74 case schema::PrimitiveType_DivFusion: 75 shape_matrix[i][j] /= other.shape_matrix[i][j]; 76 break; 77 default: 78 break; 79 } 80 } 81 } 82 } 83 std::vector<std::vector<float>> shape_matrix; 84 bool scalar = false; 85 }; 86 #endif 87 88 class ShapeFusionPass { 89 public: ShapeFusionPass(InnerContext * ctx,LiteModel * model,std::vector<lite::Tensor * > * src_tensors)90 ShapeFusionPass(InnerContext *ctx, LiteModel *model, std::vector<lite::Tensor *> *src_tensors) 91 : context_(ctx), lite_model_(model), all_nodes_(&(model->graph_.all_nodes_)), src_tensors_(src_tensors) { 92 MS_ASSERT(model != nullptr && src_tensors != nullptr); 93 for (auto node : model->graph_.all_nodes_) { 94 for (auto input_idx : node->input_indices_) { 95 used_nodes_[input_idx].push_back(node); 96 } 97 } 98 } 99 ~ShapeFusionPass() = default; 100 Run(LiteGraph::Node * node,size_t subgraph_index)101 void Run(LiteGraph::Node *node, size_t subgraph_index) { 102 #ifndef RUNTIME_PASS_CLIP 103 // gpu does not support to run fused shape op. 104 if (context_->IsDeviceTypeEnabled(DeviceType::DT_GPU)) { 105 return; 106 } 107 if (ConvertToShapeFusion(node) != RET_OK) { 108 MS_LOG(INFO) << "Convert to built-in shape failed: " << node->name_; 109 } else if (FusePostNodes(node, subgraph_index) != RET_OK) { 110 MS_LOG(INFO) << "Fused to built-in shape failed: " << node->name_; 111 } 112 std::transform(node->output_indices_.begin(), node->output_indices_.end(), 113 std::back_inserter(shape_fusion_outputs_), 114 [&](uint32_t idx) { return this->src_tensors_->at(idx); }); 115 #endif 116 } 117 StoreStateAndReset()118 void StoreStateAndReset() { 119 #ifndef RUNTIME_PASS_CLIP 120 std::vector<lite::Tensor *> shape_fusion_outputs = shape_fusion_outputs_; 121 shape_fusion_outputs_.clear(); 122 for (auto output : shape_fusion_outputs) { 123 if (output->IsConst()) { 124 shape_fusion_outputs_.push_back(output); 125 datas_.push_back(output->data()); 126 output->set_data(nullptr); 127 output->set_category(VAR); 128 } 129 } 130 #endif 131 } 132 RestoreState()133 void RestoreState() { 134 #ifndef RUNTIME_PASS_CLIP 135 size_t count = std::min(shape_fusion_outputs_.size(), datas_.size()); 136 for (size_t i = 0; i < count; ++i) { 137 shape_fusion_outputs_[i]->set_data(datas_[i]); 138 shape_fusion_outputs_[i]->set_category(CONST_TENSOR); 139 } 140 #endif 141 } 142 143 private: 144 #ifndef RUNTIME_PASS_CLIP 145 int ConvertToShapeFusion(LiteGraph::Node *node); 146 int FusePostNodes(LiteGraph::Node *node, size_t subgraph_index); 147 Tensor *BuildTensorFromShapeFusionMatrix(const ShapeFusionMatrix &shape_fusion_matrix); 148 bool CheckArithmetic(const LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node, uint32_t input_idx); 149 bool CheckCanFused(const LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node, uint32_t input_idx, 150 size_t subgraph_index); 151 int DoFuse(LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node, std::vector<uint32_t> *input_indices, 152 size_t subgraph_index); 153 int GenerateFusedShapeFusionMatrix(LiteGraph::Node *shape_fusion, const LiteGraph::Node *post_node, 154 std::vector<uint32_t> *input_indices, ShapeFusionMatrix *shape_fusion_matrix); 155 int UpdateShapeFusionMatrix(const LiteGraph::Node *post_node, ShapeFusionMatrix *shape_fusion_matrix); 156 int GetFusionMatrixFromConstantTensor(const lite::Tensor *tensor, const std::vector<size_t> &shape, int node_type, 157 ShapeFusionMatrix *constant_matrix); 158 159 private: 160 std::map<uint32_t, ShapeFusionMatrix> shape_fusion_matrices_; 161 std::vector<lite::Tensor *> shape_fusion_outputs_; 162 std::vector<void *> datas_; 163 int is_div_ = 0; 164 #endif 165 InnerContext *context_ = nullptr; 166 LiteModel *lite_model_ = nullptr; 167 const std::vector<LiteGraph::Node *> *all_nodes_ = nullptr; 168 std::vector<lite::Tensor *> *src_tensors_ = nullptr; 169 std::map<uint32_t, std::vector<LiteGraph::Node *>> used_nodes_; 170 }; 171 } // namespace mindspore::lite 172 #endif // MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_SHAPE_FUSION_PASS_H_ 173