/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_SRC_LITERT_RUNTIME_PACKED_NODE_PASS_
18 #define MINDSPORE_LITE_SRC_LITERT_RUNTIME_PACKED_NODE_PASS_
19 
20 #include <string>
21 #include <map>
22 #include <vector>
23 #include "src/litert/lite_model.h"
24 #include "src/tensor.h"
25 #include "src/executor/kernel_exec.h"
26 
27 namespace mindspore {
28 namespace lite {
/**
 * Per-node packing record stored by PackedNodePass, keyed by node name.
 *
 * Every member has a default initializer so that a default-constructed
 * PackInfo never exposes indeterminate values.
 *
 * NOTE(review): the field meanings below are inferred from their names only
 * (they look like the dimensions of a packed matmul "B" operand) — confirm
 * against the code that fills this struct.
 */
struct PackInfo {
  bool is_packed_{false};      // true once the weight is stored in packed form (presumably).
  int weight_sums_index_{-1};  // -1 presumably means "no weight-sums tensor" — confirm.
  int b_batch_{0};
  int deep_{0};
  int col_{0};
  int deep_align_{0};
  int col_align_{0};
  bool b_transpose_{false};
  std::string cpu_option_;
};
40 
41 class PackedNodePass {
42  public:
GetInstance()43   static PackedNodePass &GetInstance() {
44     static PackedNodePass instance{};
45     return instance;
46   }
47 
GetNodePackInfo(const std::string & node_name)48   PackInfo *GetNodePackInfo(const std::string &node_name) {
49     if (this->node_pack_info_map_.find(node_name) == this->node_pack_info_map_.end()) {
50       return nullptr;
51     }
52     return this->node_pack_info_map_[node_name];
53   }
54   void Run(Model *model, const std::vector<Tensor *> &tensors);
55   void CopyWeightBiasSumsTensor(Tensor *tensor);
56 
57  protected:
AddNodePackInfo(const std::string & node_name,PackInfo * pack_info)58   void AddNodePackInfo(const std::string &node_name, PackInfo *pack_info) {
59     if (this->node_pack_info_map_.find(node_name) != this->node_pack_info_map_.end()) {
60       MS_LOG(WARNING) << "Key conflict when add weight sums index.";
61     }
62     this->node_pack_info_map_[node_name] = pack_info;
63   }
64 
65  private:
66   PackedNodePass() = default;
67   ~PackedNodePass();
68   std::map<std::string, PackInfo *> node_pack_info_map_;
69 };
70 
71 int PackKernelExec(kernel::KernelExec *kernel_exec, const std::vector<Tensor *> &tensors);
72 
73 // packed weight data -> unpack
74 int RecoveryPackedWeight(Tensor *weight, const int quant_type, const TypeId data_type, const int node_type,
75                          const PackInfo &packInfo);
76 }  // namespace lite
77 }  // namespace mindspore
78 #endif  // MINDSPORE_LITE_SRC_LITERT_RUNTIME_PACKED_NODE_PASS_
79