/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/litert/runtime_packed_node_pass.h"
#include "nnacl/op_base.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/nnacl_kernel.h"
#include "nnacl/kernel/matmul_struct.h"
#include "common/string_utils.h"

using RecoveryWeightFunc = void (*)(void *, void *, int, int, bool);
namespace mindspore {
namespace {
constexpr size_t kFlatbuffersBuilderInitSize = 1024;
constexpr auto kActivationType = "activation_type";
constexpr auto kTransposeA = "transpose_a";
constexpr auto kTransposeB = "transpose_b";
constexpr auto kArm64SimdDot = "ARM64SIMD_DOT";
const std::vector<std::string> kAttrString = {"activation_type", "transpose_a", "transpose_b", "b_batch", "col", "deep",
                                              "col_align", "deep_align"};
}  // namespace

namespace lite {
PackedNodePass::~PackedNodePass() {
  for (auto &pack_info : node_pack_info_map_) {
    delete pack_info.second;
  }
  node_pack_info_map_.clear();
}

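// Rewrite every "MatmulFusionPacked" Custom node back into a regular MatMulFusion primitive and record its
// packing parameters (batch, col/deep and their aligned sizes, transpose flag, cpu option) in a PackInfo entry,
// so the packed weights can be restored or reused when the kernel is prepared at runtime.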
void PackedNodePass::Run(Model *model, const std::vector<Tensor *> &tensors) {
  CHECK_NULL_RETURN_VOID(model);
  LiteModel *lite_model = reinterpret_cast<LiteModel *>(model);
  CHECK_NULL_RETURN_VOID(lite_model);
  for (auto &node : model->graph_.all_nodes_) {
    MS_ASSERT(node != nullptr);
    if (node->node_type_ != static_cast<int>(schema::PrimitiveType_Custom)) {
      continue;
    }
    auto *primitive = reinterpret_cast<const schema::Primitive *>(node->primitive_);
    if (primitive == nullptr) {
      MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!";
      return;
    }
    auto custom = primitive->value_as_Custom();
    if (custom == nullptr || custom->type() == nullptr) {
      MS_LOG(ERROR) << "Custom node is nullptr";
      return;
    }
    auto custom_type = custom->type()->str();
    if (custom_type != "MatmulFusionPacked") {
      continue;
    }
    flatbuffers::FlatBufferBuilder fbb(kFlatbuffersBuilderInitSize);

    auto custom_attr = custom->attr();
    std::map<std::string, std::string> attr_map;
    for (uint32_t i = 0; i < custom_attr->size(); ++i) {
      auto attr = custom_attr->Get(i);
      auto attr_key = attr->name()->str();
      auto data_bytes = attr->data();
      uint32_t data_size = data_bytes->size();
      std::string attr_value;
      for (uint32_t j = 0; j < data_size; j++) {
        attr_value.push_back(static_cast<char>(data_bytes->Get(j)));
      }
      attr_map[attr_key] = attr_value;
    }
    for (auto &str : kAttrString) {
      if (attr_map.find(str) == attr_map.end() || !IsStrNumeric(attr_map[str])) {
        MS_LOG(ERROR) << "Custom attr error.";
        return;
      }
    }
    auto val_offset = schema::CreateMatMulFusion(
      fbb, static_cast<bool>(std::stoi(attr_map[kTransposeA])), static_cast<bool>(std::stoi(attr_map[kTransposeB])),
      static_cast<schema::ActivationType>(std::stoi(attr_map[kActivationType])));
    auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_MatMulFusion, val_offset.o);
    fbb.Finish(prim_offset);
    void *prim = malloc(fbb.GetSize());
    if (prim == nullptr) {
      MS_LOG(ERROR) << "malloc primitive failed.";
      return;
    }
    (void)memcpy(prim, reinterpret_cast<void *>(fbb.GetBufferPointer()), fbb.GetSize());
    auto custom_primitive = flatbuffers::GetRoot<schema::Primitive>(prim);
    fbb.Clear();
    PackInfo *pack_info = new (std::nothrow) PackInfo();
    if (pack_info == nullptr) {
      free(prim);
      MS_LOG(ERROR) << "new PackInfo failed.";
      return;
    }
    lite_model->node_bufs_.push_back(prim);
    node->primitive_ = custom_primitive;
    pack_info->is_packed_ = true;
    pack_info->b_batch_ = std::stoi(attr_map["b_batch"]);
    pack_info->col_ = std::stoi(attr_map["col"]);
    pack_info->deep_ = std::stoi(attr_map["deep"]);
    pack_info->col_align_ = std::stoi(attr_map["col_align"]);
    pack_info->deep_align_ = std::stoi(attr_map["deep_align"]);
    pack_info->b_transpose_ = static_cast<bool>(std::stoi(attr_map[kTransposeB]));
    pack_info->cpu_option_ = attr_map["cpu_option"];
    AddNodePackInfo(node->name_, pack_info);
    if (node->quant_type_ == static_cast<int>(schema::QuantType_QUANT_DYNAMIC)) {
      pack_info->weight_sums_index_ = static_cast<int>(node->input_indices_.back());
      node->input_indices_.pop_back();
      if (!(lite_model->keep_model_buf())) {
        auto index = pack_info->weight_sums_index_;
        if (index < 0 || index >= static_cast<int>(tensors.size())) {
          free(prim);
          MS_LOG(ERROR) << "weight sums tensor index is invalid.";
          return;
        }
        auto tensor = tensors[static_cast<size_t>(index)];
        CopyWeightBiasSumsTensor(tensor);
      }
    }

    node->node_type_ = static_cast<int>(schema::PrimitiveType_MatMulFusion);
  }
}

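// Make a constant tensor own a private copy of its data, so the weight-sums data stays valid after the
// original model buffer is released (only needed when the model buffer is not kept).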
void PackedNodePass::CopyWeightBiasSumsTensor(Tensor *tensor) {
  if (!tensor->IsConst() && tensor->data() != nullptr) {
    return;
  }
  if (!tensor->IsConst() || tensor->own_data()) {
    return;
  }
  if (tensor->data_type() == kObjectTypeTensorType) {
    MS_ASSERT(tensor->data() == nullptr);
  } else {
    auto copy_tensor = Tensor::CopyTensor(*tensor, true);
    if (copy_tensor == nullptr) {
      MS_LOG(ERROR) << "Copy tensor failed";
      return;
    }
    tensor->FreeData();
    tensor->set_data(copy_tensor->data());
    tensor->set_own_data(true);
    copy_tensor->set_data(nullptr);
    delete copy_tensor;
  }
}

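// Restore row-major int8 weights from the ARM64 SDOT packed layout: reverses RowMajor2Col4x16MajorInt8
// (transpose == false) or RowMajor2Row4x16MajorInt8 (transpose == true), reading from the packed buffer `dst`
// and writing the recovered data into `src`.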
void MatmulDynamicSdotInt8Unpack(void *src, void *dst, int row, int col, bool transpose) {
  auto src_int8 = static_cast<int8_t *>(src);
  auto dst_int8 = static_cast<int8_t *>(dst);
  if (!transpose) {
    // RowMajor2Col4x16MajorInt8
    int row_4 = UP_ROUND(row, C4NUM);
    int stride = C16NUM * C4NUM;
    for (int r = 0; r < row_4; ++r) {
      for (int c = 0; c < col; ++c) {
        int stride_idx = c / C16NUM * (row_4 / C4NUM) + r / C4NUM;
        if (r < row) {
          int src_idx = r * col + c;
          src_int8[src_idx] = dst_int8[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM];
        }
      }
    }
  } else {
    int temp = row;
    row = col;
    col = temp;
    // RowMajor2Row4x16MajorInt8
    int col4 = UP_ROUND(col, C4NUM);
    for (int r = 0; r < row; r++) {
      int rd16 = r / C16NUM;
      int rm16 = r % C16NUM;
      for (int c = 0; c < col; c++) {
        int cd4 = c / C4NUM;
        int cm4 = c % C4NUM;
        int dst_index = rd16 * col4 * C16NUM + cd4 * C16NUM * C4NUM + rm16 * C4NUM + cm4;
        int src_index = r * col + c;
        src_int8[src_index] = dst_int8[dst_index];
      }
    }
  }
}

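// Restore row-major fp32 weights from the packed layout: reverses RowMajor2Row8MajorParallel (transpose == false)
// or RowMajor2Col8MajorParallel (transpose == true), reading from the packed buffer `dst` and writing into `src`.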
void MatmulFp32BaseUnpack(void *src, void *dst, int row, int col, bool transpose) {
  if (!transpose) {
    // RowMajor2Row8MajorParallel
    auto src_r = static_cast<float *>(src);
    auto dst_r = static_cast<float *>(dst);
    for (int r = 0; r < row; r++) {
      float *src_c = src_r + r * col;
      int c = 0;
      for (; c < col; c++) {
        int cd8 = c / C8NUM;
        int cm8 = c % C8NUM;
        src_c[c] = dst_r[cd8 * C8NUM * row + r * C8NUM + cm8];
      }
    }
    return;
  }
  // RowMajor2Col8MajorParallel
  auto src_r = static_cast<float *>(src);
  auto dst_r = static_cast<float *>(dst);
  int row8 = row / C8NUM * C8NUM;
  int col_skip = col / C4NUM * C4NUM;
  int skip_size = C4NUM;

  int ri = 0;
  for (; ri < row8; ri += C8NUM) {
    int ci = 0;
    for (; ci < col_skip; ci += skip_size) {
      float *src_c = src_r + ci;
      float *dst_c = dst_r + ci * C8NUM;
      for (int tr = 0; tr < C8NUM; tr++) {
        for (int tc = 0; tc < C4NUM; tc++) {
          src_c[tr * col + tc] = dst_c[tc * C8NUM + tr];
        }
      }
    }
    for (; ci < col; ci++) {
      float *src_c = src_r + ci;
      float *dst_c = dst_r + ci * C8NUM;
      for (int i = 0; i < C8NUM; i++) {
        src_c[i * col] = dst_c[i];
      }
    }
    src_r += C8NUM * col;
    dst_r += C8NUM * col;
  }
  for (; ri < row; ri++, src_r += col, dst_r++) {
    for (int i = 0; i < col; i++) {
      src_r[i] = dst_r[i * C8NUM];
    }
  }
}

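// Select the unpack routine that matches the packed weight layout; only ARM64 SDOT MatMulFusion weights
// (dynamic-quant int8 or fp32) are supported, otherwise nullptr is returned.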
RecoveryWeightFunc GetRecoveryWeightFunc(const int quant_type, const TypeId data_type, const int node_type,
                                         const std::string &cpu_option) {
  if (cpu_option == kArm64SimdDot && node_type == static_cast<int>(schema::PrimitiveType_MatMulFusion) &&
      quant_type == static_cast<int>(schema::QuantType_QUANT_DYNAMIC) && data_type == kNumberTypeInt8) {
    return MatmulDynamicSdotInt8Unpack;
  }

  if (cpu_option == kArm64SimdDot && node_type == static_cast<int>(schema::PrimitiveType_MatMulFusion) &&
      data_type == kNumberTypeFloat32) {
    return MatmulFp32BaseUnpack;
  }
  return nullptr;
}

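// For a matmul kernel whose weights were packed offline: recover the original weights when the runtime kernel
// cannot consume the packed layout, otherwise hand the pre-computed weight-sums tensor to the kernel.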
int PackedMatmulKernelExec(kernel::KernelExec *kernel_exec, const std::vector<Tensor *> &tensors) {
  auto pack_info = PackedNodePass::GetInstance().GetNodePackInfo(kernel_exec->name());
  if (pack_info == nullptr) {
    return RET_OK;
  }
  MS_CHECK_TRUE_MSG(kernel_exec->in_tensors().size() >= kInputSize1, lite::RET_ERROR,
                    "kernel doesn't have weight tensor.");
  auto dst_tensor = kernel_exec->in_tensors()[SECOND_INPUT];
  auto kernel = reinterpret_cast<Kernel *>(kernel_exec->kernel());
  MS_CHECK_TRUE_MSG(kernel != nullptr, lite::RET_NULL_PTR, "kernel is nullptr.");
  MS_CHECK_TRUE_MSG(kernel_exec->op_parameter() != nullptr, lite::RET_NULL_PTR, "kernel parameter is nullptr.");
  auto param = reinterpret_cast<MatMulParameter *>(kernel_exec->op_parameter());
  const KernelBase *kernel_base = reinterpret_cast<const nnacl::NNACLKernel *>(kernel_exec->kernel())->Kernel();
  if (dst_tensor->data_type() == kNumberTypeFloat32) {
    const MatmulStruct *matmul = reinterpret_cast<const MatmulStruct *>(kernel_base);
    if (matmul->matmul_type_ == kNotImplemented) {
      return RecoveryPackedWeight(dst_tensor, static_cast<int>(kernel->quant_type()), dst_tensor->data_type(),
                                  static_cast<int>(schema::PrimitiveType_MatMulFusion), *pack_info);
    }
  }

  if (dst_tensor->data_type() == kNumberTypeInt8 && param->matmul_type_ != kMatmulDynamicSdotInt8Cpu &&
      pack_info->cpu_option_ == kArm64SimdDot) {
    return RecoveryPackedWeight(dst_tensor, static_cast<int>(kernel->quant_type()), dst_tensor->data_type(),
                                static_cast<int>(schema::PrimitiveType_MatMulFusion), *pack_info);
  }

  auto lite_kernel = static_cast<kernel::LiteKernel *>(kernel);
  lite::Tensor *weight_sums = nullptr;
  auto index = pack_info->weight_sums_index_;
  if (index >= 0 && static_cast<size_t>(index) < tensors.size()) {
    weight_sums = tensors.at(index);
  }
  return lite_kernel->PreparePackedWeight(weight_sums);
}

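// Undo the offline packing of the weight tensor batch by batch and replace the tensor data with the recovered
// row-major weights.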
int RecoveryPackedWeight(Tensor *weight, const int quant_type, const TypeId data_type, const int node_type,
                         const PackInfo &pack_info) {
  auto recovery_func = GetRecoveryWeightFunc(quant_type, data_type, node_type, pack_info.cpu_option_);
  if (recovery_func == nullptr) {
    MS_LOG(ERROR) << "unsupported recovery func.";
    return RET_NULL_PTR;
  }
  void *unpack_data = malloc(weight->Size());
  if (unpack_data == nullptr) {
    MS_LOG(ERROR) << "malloc unpack_data failed.";
    return RET_NULL_PTR;
  }
  void *pack_b_ptr = weight->data();
  for (int i = 0; i < pack_info.b_batch_; i++) {
    void *current_weight;
    void *current_b_pack;
    if (weight->data_type() == kNumberTypeInt8) {
      current_weight = static_cast<void *>(static_cast<int8_t *>(unpack_data) + i * pack_info.deep_ * pack_info.col_);
      current_b_pack =
        static_cast<void *>(static_cast<int8_t *>(pack_b_ptr) + i * pack_info.col_align_ * pack_info.deep_align_);
    } else if (weight->data_type() == kNumberTypeFloat32) {
      current_weight = static_cast<void *>(static_cast<float *>(unpack_data) + i * pack_info.deep_ * pack_info.col_);
      current_b_pack =
        static_cast<void *>(static_cast<float *>(pack_b_ptr) + i * pack_info.col_align_ * pack_info.deep_);
    } else {
      free(unpack_data);
      MS_LOG(ERROR) << "unsupported data type.";
      return RET_ERROR;
    }
    recovery_func(current_weight, current_b_pack, pack_info.deep_, pack_info.col_, pack_info.b_transpose_);
  }
  weight->FreeData();
  weight->set_data(unpack_data);
  return RET_OK;
}

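// Entry point of the pass at kernel-prepare time; currently only MatMulFusion kernels carry packed weights.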
int PackKernelExec(kernel::KernelExec *kernel_exec, const std::vector<Tensor *> &tensors) {
  if (kernel_exec->type() == schema::PrimitiveType_MatMulFusion) {
    return PackedMatmulKernelExec(kernel_exec, tensors);
  }
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore