/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/litert/runtime_packed_node_pass.h"
#include "nnacl/op_base.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/nnacl_kernel.h"
#include "nnacl/kernel/matmul_struct.h"
#include "common/string_utils.h"

using RecoveryWeightFunc = void (*)(void *, void *, int, int, bool);
namespace mindspore {
namespace {
constexpr size_t kFlatbuffersBuilderInitSize = 1024;
constexpr auto kActivationType = "activation_type";
constexpr auto kTransposeA = "transpose_a";
constexpr auto kTransposeB = "transpose_b";
constexpr auto kArm64SimdDot = "ARM64SIMD_DOT";
const std::vector<std::string> kAttrString = {"activation_type", "transpose_a", "transpose_b", "b_batch", "col", "deep",
                                              "col_align",       "deep_align"};
}  // namespace

namespace lite {
PackedNodePass::~PackedNodePass() {
  for (auto &pack_info : node_pack_info_map_) {
    delete pack_info.second;
  }
  node_pack_info_map_.clear();
}

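// Converts every custom "MatmulFusionPacked" node back into a MatMulFusion primitive and
// records its packing attributes (batch, col/deep, alignment, transpose, cpu option) in
// node_pack_info_map_, keyed by node name. For dynamically quantized nodes, the extra
// weight-sums input is detached and, if the model buffer is not kept, copied out.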
void PackedNodePass::Run(Model *model, const std::vector<Tensor *> &tensors) {
  CHECK_NULL_RETURN_VOID(model);
  LiteModel *lite_model = reinterpret_cast<LiteModel *>(model);
  CHECK_NULL_RETURN_VOID(lite_model);
  for (auto &node : model->graph_.all_nodes_) {
    MS_ASSERT(node != nullptr);
    if (node->node_type_ != static_cast<int>(schema::PrimitiveType_Custom)) {
      continue;
    }
    auto *primitive = reinterpret_cast<const schema::Primitive *>(node->primitive_);
    if (primitive == nullptr) {
      MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!";
      return;
    }
    auto custom = primitive->value_as_Custom();
    if (custom == nullptr || custom->type() == nullptr) {
      MS_LOG(ERROR) << "Custom node is nullptr";
      return;
    }
    auto custom_type = custom->type()->str();
    if (custom_type != "MatmulFusionPacked") {
      continue;
    }
    flatbuffers::FlatBufferBuilder fbb(kFlatbuffersBuilderInitSize);

    auto custom_attr = custom->attr();
    std::map<std::string, std::string> attr_map;
    for (uint32_t i = 0; i < custom_attr->size(); ++i) {
      auto attr = custom_attr->Get(i);
      auto attr_key = attr->name()->str();
      auto data_bytes = attr->data();
      uint32_t data_size = data_bytes->size();
      std::string attr_value;
      for (uint32_t j = 0; j < data_size; j++) {
        attr_value.push_back(static_cast<char>(data_bytes->Get(j)));
      }
      attr_map[attr_key] = attr_value;
    }
    for (auto &str : kAttrString) {
      if (attr_map.find(str) == attr_map.end() || !IsStrNumeric(attr_map[str])) {
        MS_LOG(ERROR) << "Custom attr error.";
        return;
      }
    }
    auto val_offset = schema::CreateMatMulFusion(
      fbb, static_cast<bool>(std::stoi(attr_map[kTransposeA])), static_cast<bool>(std::stoi(attr_map[kTransposeB])),
      static_cast<schema::ActivationType>(std::stoi(attr_map[kActivationType])));
    auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_MatMulFusion, val_offset.o);
    fbb.Finish(prim_offset);
    void *prim = malloc(fbb.GetSize());
    if (prim == nullptr) {
      MS_LOG(ERROR) << "malloc primitive failed.";
      return;
    }
    (void)memcpy(prim, reinterpret_cast<void *>(fbb.GetBufferPointer()), fbb.GetSize());
    auto custom_primitive = flatbuffers::GetRoot<schema::Primitive>(prim);
    fbb.Clear();
    PackInfo *pack_info = new (std::nothrow) PackInfo();
    if (pack_info == nullptr) {
      free(prim);
      MS_LOG(ERROR) << "new PackInfo failed.";
      return;
    }
    lite_model->node_bufs_.push_back(prim);
    node->primitive_ = custom_primitive;
    pack_info->is_packed_ = true;
    pack_info->b_batch_ = std::stoi(attr_map["b_batch"]);
    pack_info->col_ = std::stoi(attr_map["col"]);
    pack_info->deep_ = std::stoi(attr_map["deep"]);
    pack_info->col_align_ = std::stoi(attr_map["col_align"]);
    pack_info->deep_align_ = std::stoi(attr_map["deep_align"]);
    pack_info->b_transpose_ = static_cast<bool>(std::stoi(attr_map[kTransposeB]));
    pack_info->cpu_option_ = attr_map["cpu_option"];
    AddNodePackInfo(node->name_, pack_info);
    if (node->quant_type_ == static_cast<int>(schema::QuantType_QUANT_DYNAMIC)) {
      pack_info->weight_sums_index_ = static_cast<int>(node->input_indices_.back());
      node->input_indices_.pop_back();
      if (!(lite_model->keep_model_buf())) {
        auto index = pack_info->weight_sums_index_;
        if (index < 0 || index >= static_cast<int>(tensors.size())) {
          free(prim);
          MS_LOG(ERROR) << "weight sums tensor index is invalid.";
          return;
        }
        auto tensor = tensors[static_cast<size_t>(index)];
        CopyWeightBiasSumsTensor(tensor);
      }
    }

    node->node_type_ = static_cast<int>(schema::PrimitiveType_MatMulFusion);
  }
}

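// Ensures a constant weight-sums tensor owns its data: if it still points into the model
// buffer, the data is copied into freshly allocated memory so the model buffer can be freed.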
void PackedNodePass::CopyWeightBiasSumsTensor(Tensor *tensor) {
  if (!tensor->IsConst() && tensor->data() != nullptr) {
    return;
  }
  if (!tensor->IsConst() || tensor->own_data()) {
    return;
  }
  if (tensor->data_type() == kObjectTypeTensorType) {
    MS_ASSERT(tensor->data() == nullptr);
  } else {
    auto copy_tensor = Tensor::CopyTensor(*tensor, true);
    if (copy_tensor == nullptr) {
      MS_LOG(ERROR) << "Copy tensor failed";
      return;
    }
    tensor->FreeData();
    tensor->set_data(copy_tensor->data());
    tensor->set_own_data(true);
    copy_tensor->set_data(nullptr);
    delete copy_tensor;
  }
}

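// Inverse of RowMajor2Col4x16MajorInt8 / RowMajor2Row4x16MajorInt8: reads the int8 weight
// packed for the ARM64 SDOT kernel from dst and writes the recovered row-major data to src.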
void MatmulDynamicSdotInt8Unpack(void *src, void *dst, int row, int col, bool transpose) {
  auto src_int8 = static_cast<int8_t *>(src);
  auto dst_int8 = static_cast<int8_t *>(dst);
  if (!transpose) {
    // RowMajor2Col4x16MajorInt8
    int row_4 = UP_ROUND(row, C4NUM);
    int stride = C16NUM * C4NUM;
    for (int r = 0; r < row_4; ++r) {
      for (int c = 0; c < col; ++c) {
        int stride_idx = c / C16NUM * (row_4 / C4NUM) + r / C4NUM;
        if (r < row) {
          int src_idx = r * col + c;
          src_int8[src_idx] = dst_int8[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM];
        }
      }
    }
  } else {
    int temp = row;
    row = col;
    col = temp;
    // RowMajor2Row4x16MajorInt8
    int col4 = UP_ROUND(col, C4NUM);
    for (int r = 0; r < row; r++) {
      int rd16 = r / C16NUM;
      int rm16 = r % C16NUM;
      for (int c = 0; c < col; c++) {
        int cd4 = c / C4NUM;
        int cm4 = c % C4NUM;
        int dst_index = rd16 * col4 * C16NUM + cd4 * C16NUM * C4NUM + rm16 * C4NUM + cm4;
        int src_index = r * col + c;
        src_int8[src_index] = dst_int8[dst_index];
      }
    }
  }
}

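// Inverse of RowMajor2Row8MajorParallel / RowMajor2Col8MajorParallel: reads the fp32 weight
// packed for the base matmul kernel from dst and writes the recovered row-major data to src.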
void MatmulFp32BaseUnpack(void *src, void *dst, int row, int col, bool transpose) {
  if (!transpose) {
    // RowMajor2Row8MajorParallel
    auto src_r = static_cast<float *>(src);
    auto dst_r = static_cast<float *>(dst);
    for (int r = 0; r < row; r++) {
      float *src_c = src_r + r * col;
      int c = 0;
      for (; c < col; c++) {
        int cd8 = c / C8NUM;
        int cm8 = c % C8NUM;
        src_c[c] = dst_r[cd8 * C8NUM * row + r * C8NUM + cm8];
      }
    }
    return;
  }
  // RowMajor2Col8MajorParallel
  auto src_r = static_cast<float *>(src);
  auto dst_r = static_cast<float *>(dst);
  int row8 = row / C8NUM * C8NUM;
  int col_skip = col / C4NUM * C4NUM;
  int skip_size = C4NUM;

  int ri = 0;
  for (; ri < row8; ri += C8NUM) {
    int ci = 0;
    for (; ci < col_skip; ci += skip_size) {
      float *src_c = src_r + ci;
      float *dst_c = dst_r + ci * C8NUM;
      for (int tr = 0; tr < C8NUM; tr++) {
        for (int tc = 0; tc < C4NUM; tc++) {
          src_c[tr * col + tc] = dst_c[tc * C8NUM + tr];
        }
      }
    }
    for (; ci < col; ci++) {
      float *src_c = src_r + ci;
      float *dst_c = dst_r + ci * C8NUM;
      for (int i = 0; i < C8NUM; i++) {
        src_c[i * col] = dst_c[i];
      }
    }
    src_r += C8NUM * col;
    dst_r += C8NUM * col;
  }
  for (; ri < row; ri++, src_r += col, dst_r++) {
    for (int i = 0; i < col; i++) {
      src_r[i] = dst_r[i * C8NUM];
    }
  }
}

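// Selects the weight-recovery routine that matches the offline packing: ARM64 SDOT int8
// dynamic-quant matmul or ARM64 SDOT fp32 matmul; returns nullptr when the combination of
// quant type, data type, node type and cpu option is not supported.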
RecoveryWeightFunc GetRecoveryWeightFunc(const int quant_type, const TypeId data_type, const int node_type,
                                         const std::string &cpu_option) {
  if (cpu_option == kArm64SimdDot && node_type == static_cast<int>(schema::PrimitiveType_MatMulFusion) &&
      quant_type == static_cast<int>(schema::QuantType_QUANT_DYNAMIC) && data_type == kNumberTypeInt8) {
    return MatmulDynamicSdotInt8Unpack;
  }

  if (cpu_option == kArm64SimdDot && node_type == static_cast<int>(schema::PrimitiveType_MatMulFusion) &&
      data_type == kNumberTypeFloat32) {
    return MatmulFp32BaseUnpack;
  }
  return nullptr;
}

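// For a MatMulFusion kernel with an offline-packed weight: if the selected runtime kernel
// cannot consume the packed layout, the weight is unpacked back to row-major; otherwise the
// pre-computed weight-sums tensor is handed to the kernel via PreparePackedWeight.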
int PackedMatmulKernelExec(kernel::KernelExec *kernel_exec, const std::vector<Tensor *> &tensors) {
  auto pack_info = PackedNodePass::GetInstance().GetNodePackInfo(kernel_exec->name());
  if (pack_info == nullptr) {
    return RET_OK;
  }
  MS_CHECK_TRUE_MSG(kernel_exec->in_tensors().size() >= kInputSize1, lite::RET_ERROR,
                    "kernel doesn't have weight tensor.");
  auto dst_tensor = kernel_exec->in_tensors()[SECOND_INPUT];
  auto kernel = reinterpret_cast<Kernel *>(kernel_exec->kernel());
  MS_CHECK_TRUE_MSG(kernel != nullptr, lite::RET_NULL_PTR, "kernel is nullptr.");
  MS_CHECK_TRUE_MSG(kernel_exec->op_parameter() != nullptr, lite::RET_NULL_PTR, "kernel parameter is nullptr.");
  auto param = reinterpret_cast<MatMulParameter *>(kernel_exec->op_parameter());
  const KernelBase *kernel_base = reinterpret_cast<const nnacl::NNACLKernel *>(kernel_exec->kernel())->Kernel();
  if (dst_tensor->data_type() == kNumberTypeFloat32) {
    const MatmulStruct *matmul = reinterpret_cast<const MatmulStruct *>(kernel_base);
    if (matmul->matmul_type_ == kNotImplemented) {
      return RecoveryPackedWeight(dst_tensor, static_cast<int>(kernel->quant_type()), dst_tensor->data_type(),
                                  static_cast<int>(schema::PrimitiveType_MatMulFusion), *pack_info);
    }
  }

  if (dst_tensor->data_type() == kNumberTypeInt8 && param->matmul_type_ != kMatmulDynamicSdotInt8Cpu &&
      pack_info->cpu_option_ == kArm64SimdDot) {
    return RecoveryPackedWeight(dst_tensor, static_cast<int>(kernel->quant_type()), dst_tensor->data_type(),
                                static_cast<int>(schema::PrimitiveType_MatMulFusion), *pack_info);
  }

  auto lite_kernel = static_cast<kernel::LiteKernel *>(kernel);
  lite::Tensor *weight_sums = nullptr;
  auto index = pack_info->weight_sums_index_;
  if (index >= 0 && static_cast<size_t>(index) < tensors.size()) {
    weight_sums = tensors.at(index);
  }
  return lite_kernel->PreparePackedWeight(weight_sums);
}

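// Unpacks an offline-packed weight tensor batch by batch back into its original row-major
// layout, then replaces the tensor's data with the newly allocated unpacked buffer.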
int RecoveryPackedWeight(Tensor *weight, const int quant_type, const TypeId data_type, const int node_type,
                         const PackInfo &pack_info) {
  auto recovery_func = GetRecoveryWeightFunc(quant_type, data_type, node_type, pack_info.cpu_option_);
  if (recovery_func == nullptr) {
    MS_LOG(ERROR) << "unsupported recovery func.";
    return RET_NULL_PTR;
  }
  void *unpack_data = malloc(weight->Size());
  if (unpack_data == nullptr) {
    MS_LOG(ERROR) << "malloc unpack_data failed.";
    return RET_NULL_PTR;
  }
  void *pack_b_ptr = weight->data();
  for (int i = 0; i < pack_info.b_batch_; i++) {
    void *current_weight;
    void *current_b_pack;
    if (weight->data_type() == kNumberTypeInt8) {
      current_weight = static_cast<void *>(static_cast<int8_t *>(unpack_data) + i * pack_info.deep_ * pack_info.col_);
      current_b_pack =
        static_cast<void *>(static_cast<int8_t *>(pack_b_ptr) + i * pack_info.col_align_ * pack_info.deep_align_);
    } else if (weight->data_type() == kNumberTypeFloat32) {
      current_weight = static_cast<void *>(static_cast<float *>(unpack_data) + i * pack_info.deep_ * pack_info.col_);
      current_b_pack =
        static_cast<void *>(static_cast<float *>(pack_b_ptr) + i * pack_info.col_align_ * pack_info.deep_);
    } else {
      free(unpack_data);
      MS_LOG(ERROR) << "unsupported data type.";
      return RET_ERROR;
    }
    recovery_func(current_weight, current_b_pack, pack_info.deep_, pack_info.col_, pack_info.b_transpose_);
  }
  weight->FreeData();
  weight->set_data(unpack_data);
  return RET_OK;
}

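// Dispatches packed-weight handling by kernel type; only MatMulFusion kernels carry
// offline-packed weights at present.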
int PackKernelExec(kernel::KernelExec *kernel_exec, const std::vector<Tensor *> &tensors) {
  if (kernel_exec->type() == schema::PrimitiveType_MatMulFusion) {
    return PackedMatmulKernelExec(kernel_exec, tensors);
  }
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore