1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_LITERT_RUNTIME_PACKED_NODE_PASS_ 18 #define MINDSPORE_LITE_SRC_LITERT_RUNTIME_PACKED_NODE_PASS_ 19 20 #include <string> 21 #include <map> 22 #include <vector> 23 #include "src/litert/lite_model.h" 24 #include "src/tensor.h" 25 #include "src/executor/kernel_exec.h" 26 27 namespace mindspore { 28 namespace lite { 29 struct PackInfo { 30 bool is_packed_{false}; 31 int weight_sums_index_{-1}; 32 int b_batch_; 33 int deep_; 34 int col_; 35 int deep_align_; 36 int col_align_; 37 bool b_transpose_{false}; 38 std::string cpu_option_; 39 }; 40 41 class PackedNodePass { 42 public: GetInstance()43 static PackedNodePass &GetInstance() { 44 static PackedNodePass instance{}; 45 return instance; 46 } 47 GetNodePackInfo(const std::string & node_name)48 PackInfo *GetNodePackInfo(const std::string &node_name) { 49 if (this->node_pack_info_map_.find(node_name) == this->node_pack_info_map_.end()) { 50 return nullptr; 51 } 52 return this->node_pack_info_map_[node_name]; 53 } 54 void Run(Model *model, const std::vector<Tensor *> &tensors); 55 void CopyWeightBiasSumsTensor(Tensor *tensor); 56 57 protected: AddNodePackInfo(const std::string & node_name,PackInfo * pack_info)58 void AddNodePackInfo(const std::string &node_name, PackInfo *pack_info) { 59 if (this->node_pack_info_map_.find(node_name) != this->node_pack_info_map_.end()) { 60 MS_LOG(WARNING) << "Key conflict when add weight sums index."; 61 } 62 this->node_pack_info_map_[node_name] = pack_info; 63 } 64 65 private: 66 PackedNodePass() = default; 67 ~PackedNodePass(); 68 std::map<std::string, PackInfo *> node_pack_info_map_; 69 }; 70 71 int PackKernelExec(kernel::KernelExec *kernel_exec, const std::vector<Tensor *> &tensors); 72 73 // packed weight data -> unpack 74 int RecoveryPackedWeight(Tensor *weight, const int quant_type, const TypeId data_type, const int node_type, 75 const PackInfo &packInfo); 76 } // namespace lite 77 } // namespace mindspore 78 #endif // MINDSPORE_LITE_SRC_LITERT_RUNTIME_PACKED_NODE_PASS_ 79