/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cmath>
#include <string>
#include "src/litert/weight_decoder.h"
#include "src/litert/huffman_decode.h"
#include "tools/converter/quantizer/fse_decoder.h"
#include "nnacl/conv_parameter.h"

namespace mindspore::lite {
#ifndef WEIGHT_DECODE_CLIP
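// Dequantize a const tensor in place: convert its int8/int16/int32 payload to
// float32 (or float16 on ARM builds with FP16 enabled) using the tensor's
// quant params along preferred_dim, then install the new data buffer.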
int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, int preferred_dim, TypeId dst_data_type) {
  MS_ASSERT(input_tensor != nullptr);
  if (input_tensor->quant_params().empty()) {
    MS_LOG(ERROR) << "No quant param.";
    return RET_ERROR;
  }
  if (input_tensor->data_type() == kNumberTypeInt16 && dst_data_type == kNumberTypeFloat32) {
    auto new_const_data = DequantData<int16_t, float>(input_tensor, preferred_dim);
    CHECK_NULL_RETURN(new_const_data);
    input_tensor->FreeData();
    input_tensor->set_data(new_const_data);
    input_tensor->set_own_data(true);
    input_tensor->set_data_type(dst_data_type);
  } else if (input_tensor->data_type() == kNumberTypeInt16 && dst_data_type == kNumberTypeFloat16) {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
    auto new_const_data = DequantData<int16_t, float16_t>(input_tensor, preferred_dim);
    CHECK_NULL_RETURN(new_const_data);
    input_tensor->FreeData();
    input_tensor->set_data(new_const_data);
    input_tensor->set_own_data(true);
    input_tensor->set_data_type(dst_data_type);
#else
    MS_LOG(ERROR) << "Float16 is not supported";
    return RET_NOT_SUPPORT;
#endif
  } else if (input_tensor->data_type() == kNumberTypeInt8 && dst_data_type == kNumberTypeFloat32) {
    auto new_const_data = DequantData<int8_t, float>(input_tensor, preferred_dim);
    CHECK_NULL_RETURN(new_const_data);
    input_tensor->FreeData();
    input_tensor->set_data(new_const_data);
    input_tensor->set_own_data(true);
    input_tensor->set_data_type(dst_data_type);
  } else if (input_tensor->data_type() == kNumberTypeInt8 && dst_data_type == kNumberTypeFloat16) {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
    auto new_const_data = DequantData<int8_t, float16_t>(input_tensor, preferred_dim);
    CHECK_NULL_RETURN(new_const_data);
    input_tensor->FreeData();
    input_tensor->set_data(new_const_data);
    input_tensor->set_own_data(true);
    input_tensor->set_data_type(dst_data_type);
#else
    MS_LOG(ERROR) << "Float16 is not supported";
    return RET_NOT_SUPPORT;
#endif
  } else if (input_tensor->data_type() == kNumberTypeInt32 && dst_data_type == kNumberTypeFloat32) {
    auto new_const_data = DequantData<int32_t, float>(input_tensor, preferred_dim);
    CHECK_NULL_RETURN(new_const_data);
    input_tensor->FreeData();
    input_tensor->set_data(new_const_data);
    input_tensor->set_own_data(true);
    input_tensor->set_data_type(dst_data_type);
  } else {
    MS_LOG(ERROR) << "Unsupported dequant from data_type(" << (input_tensor->data_type()) << ") to data_type("
                  << dst_data_type << ")";
    return RET_NOT_SUPPORT;
  }
  return RET_OK;
}

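// Rebuild a K-Means-clustered weight tensor: expand the stored cluster data
// into real float32/float16 values and install the new buffer on the tensor.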
// NOTE: the kNumberTypeFloat32 default for dst_data_type is expected to be
// supplied by the declaration; repeating it on an out-of-line definition is
// ill-formed when the declaration already provides it.
int WeightDecoder::DecodeKMeansWeight(lite::Tensor *tensor, TypeId dst_data_type) {
  void *dequant_data = nullptr;
  if (dst_data_type == kNumberTypeFloat32) {
    auto dequant_data_ptr = static_cast<float *>(dequant_data);
    auto ret = DecodeKMeansData(tensor, &dequant_data_ptr);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Decode KMeans data failed.";
      return RET_ERROR;
    }
    dequant_data = dequant_data_ptr;
  } else if (dst_data_type == kNumberTypeFloat16) {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
    auto dequant_data_ptr = static_cast<float16_t *>(dequant_data);
    auto ret = DecodeKMeansData(tensor, &dequant_data_ptr);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Decode KMeans data failed.";
      return RET_ERROR;
    }
    dequant_data = dequant_data_ptr;
#else
    MS_LOG(ERROR) << "Current library or hardware doesn't support FP16.";
    return RET_ERROR;
#endif
  } else {
    MS_LOG(ERROR) << dst_data_type << " data type does not support KMeans decoding.";
    return RET_ERROR;
  }
  tensor->FreeData();
  tensor->set_data(dequant_data);
  tensor->set_own_data(true);
  tensor->set_data_type(dst_data_type);
  return RET_OK;
}

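// Huffman-decode a const tensor's payload into freshly allocated tensor data.
// Returns RET_NO_CHANGE when the tensor is not Huffman-encoded, so callers
// can fall through to other decode paths.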
int WeightDecoder::DecodeHuffmanCode(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor) {
  MS_ASSERT(src_tensor.handler() != nullptr);
  MS_ASSERT(src_tensor.data() != nullptr);
  MS_ASSERT(dst_tensor != nullptr);
  if (!dst_tensor->IsConst() || !src_tensor.handler()->enableHuffmanCode()) {
    return RET_NO_CHANGE;
  }
  if (src_tensor.data() == nullptr) {
    return RET_NO_CHANGE;
  }
  auto data = reinterpret_cast<const char *>(src_tensor.data());
  std::string encode_str(data, src_tensor.length());
  dst_tensor->FreeData();
  dst_tensor->set_data(nullptr);
  auto ret = dst_tensor->MallocData();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Malloc tensor data failed";
    return RET_NULL_PTR;
  }
  auto dst_data = dst_tensor->data();
  MS_ASSERT(dst_data != nullptr);
  ret = HuffmanDecode::DoHuffmanDecode(encode_str, dst_data, dst_tensor->Size());
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "DoHuffmanDecode failed.";
    return ret;
  }
  return RET_OK;
}

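// Unpack a bit-packed const tensor into full int8/int16 elements. Quant bit
// widths in [1, 8) unpack to int8; widths in (8, 16) unpack to int16.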
int WeightDecoder::UnPackToInt(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor) {
  MS_ASSERT(src_tensor.handler() != nullptr);
  MS_ASSERT(src_tensor.data() != nullptr);
  MS_ASSERT(dst_tensor != nullptr);
  auto quant_params = src_tensor.handler()->quantParams();
  if (quant_params == nullptr || quant_params->size() == 0) {
    return RET_NO_CHANGE;
  }
  auto quant_param = quant_params->Get(0);
  if (quant_param == nullptr) {
    return RET_NO_CHANGE;
  }
  auto dst_data = dst_tensor->data();
  if (dst_data != nullptr) {
    MS_LOG(ERROR) << "lite Tensor has already allocated data";
    return RET_ERROR;
  }
  auto ret = dst_tensor->MallocData();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Malloc tensor data failed";
    return RET_NULL_PTR;
  }
  dst_data = dst_tensor->data();
  auto dst_element_num = dst_tensor->ElementsNum();
  int origin_bit = quant_param->numBits();
  if (origin_bit < kBitNum8 && origin_bit >= kBitNum1) {
    return UnPackUtil<int8_t, uint8_t>(src_tensor, dst_element_num, origin_bit, dst_data);
  } else if (origin_bit < kBitNum16 && origin_bit > kBitNum8) {
    return UnPackUtil<int16_t, uint16_t>(src_tensor, dst_element_num, origin_bit, dst_data);
  } else {
    MS_LOG(ERROR) << "Unsupported bit number: " << origin_bit;
    return RET_NOT_SUPPORT;
  }
}

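// Per-tensor unpacking entry point: Huffman-decode when the flag is set,
// otherwise bit-unpack to int8/int16. RET_NO_CHANGE means the raw data was
// already usable as-is.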
int WeightDecoder::UnPack(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor) {
  MS_ASSERT(src_tensor.handler() != nullptr);
  MS_ASSERT(src_tensor.data() != nullptr);
  MS_CHECK_TRUE_MSG(src_tensor.handler()->dims() != nullptr, RET_ERROR, "dims is nullptr");
  MS_CHECK_TRUE_MSG(src_tensor.handler()->name() != nullptr, RET_ERROR, "name is nullptr");
  STATUS ret = RET_OK;
  if (src_tensor.handler()->enableHuffmanCode()) {
    ret = WeightDecoder::DecodeHuffmanCode(src_tensor, dst_tensor);
    if (ret != RET_OK && ret != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Decode huffman code failed: " << ret;
    }
  } else {
    if (src_tensor.handler()->dims()->size() == 0) {
      MS_LOG(ERROR) << src_tensor.handler()->name()->c_str() << " shape is empty.";
      return RET_ERROR;
    }
    ret = WeightDecoder::UnPackToInt(src_tensor, dst_tensor);
    if (ret != RET_OK && ret != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Unpack to int failed: " << ret;
      return ret;
    }
  }
  return ret;
}

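// Decompress a sparsely encoded weight tensor. As read from the parser below,
// the bitstream layout is:
//   [8 bits]   coor_best_bit    width used to store each non-zero coordinate
//   [32 bits]  nz_cnt           number of non-zero elements
//   [bit_num]  unique_value_cnt 0 means the full 2^bit_num value table
// followed by the unique values, one value index per non-zero, and the
// non-zero coordinates themselves.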
STATUS WeightDecoder::SparseDecompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor) {
  MS_ASSERT(src_tensor.handler() != nullptr);
  MS_ASSERT(src_tensor.data() != nullptr);
  MS_LOG(DEBUG) << "un-sparse weight";
  MS_CHECK_TRUE_MSG(src_tensor.handler()->quantParams() != nullptr, RET_ERROR, "quant params is nullptr");
  MS_CHECK_TRUE_MSG((*src_tensor.handler()->quantParams()).size() > 0, RET_ERROR,
                    "quant params size must be bigger than 0");
  MS_CHECK_TRUE_MSG(src_tensor.handler()->quantParams()->Get(0) != nullptr, RET_ERROR, "quant param is nullptr");
  size_t bit_num = static_cast<size_t>(src_tensor.handler()->quantParams()->Get(0)->numBits());

  std::string str(static_cast<const char *>(src_tensor.data()), src_tensor.length());
  auto bit_vec = StringToBitVector(str);
  size_t index = 0;
  // parse coor_best_bit
  size_t coor_best_bit = 0;
  for (size_t i = 0; i < kBitNum8; i++) {
    bool bit = bit_vec[index++];
    coor_best_bit |= static_cast<size_t>(bit) << (kBitNum8 - i - 1);
  }
  // parse nz_cnt
  size_t nz_cnt = 0;
  for (size_t i = 0; i < kBitNum32; i++) {
    bool bit = bit_vec[index++];
    // widen the bit to size_t before shifting; shifting a promoted int by 31
    // would overflow when the bit is set
    nz_cnt |= static_cast<size_t>(bit) << (kBitNum32 - i - 1);
  }
  // parse unique_value cnt
  size_t unique_value_cnt = 0;
  for (size_t i = 0; i < bit_num; i++) {
    bool bit = bit_vec[index++];
    unique_value_cnt |= static_cast<size_t>(bit) << (bit_num - i - 1);
  }
  if (unique_value_cnt == 0) {
    unique_value_cnt = 1u << bit_num;
  }
  // parse unique_values
  std::vector<int> unique_values;
  for (size_t i = 0; i < unique_value_cnt; i++) {
    int unique_value = 0;
    for (size_t j = 0; j < bit_num; j++) {
      bool bit = bit_vec[index++];
      unique_value |= bit << (bit_num - j - 1);
    }
    // unsigned to signed
    unique_values.push_back(unique_value - (1u << (bit_num - 1)));
  }
  // parse index
  std::vector<size_t> unique_value_index_vec;
  auto elem_cnt = dst_tensor->ElementsNum();
  size_t unique_value_bit = static_cast<size_t>(ceil(log2(unique_value_cnt)));
  for (size_t i = 0; i < nz_cnt; i++) {
    size_t unique_value_index = 0;
    for (size_t j = 0; j < unique_value_bit; j++) {
      bool bit = bit_vec[index++];
      unique_value_index |= static_cast<size_t>(bit) << (unique_value_bit - j - 1);
    }
    unique_value_index_vec.push_back(unique_value_index);
  }

  // parse coors
  std::vector<size_t> coor_vec;
  for (size_t i = 0; i < nz_cnt; i++) {
    size_t coor = 0;
    for (size_t j = 0; j < coor_best_bit; j++) {
      bool bit = bit_vec[index++];
      coor |= static_cast<size_t>(bit) << (coor_best_bit - j - 1);
    }
    coor_vec.push_back(coor);
  }

  if (dst_tensor->data() != nullptr) {
    MS_LOG(ERROR) << "data_c not null";
    return RET_ERROR;
  }
  auto ret = dst_tensor->MallocData();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Malloc tensor data failed";
    return RET_NULL_PTR;
  }
  auto dst_data = dst_tensor->data();

  if (bit_num <= kBitNum8) {
    ret =
      UnSparseTensorData<int8_t>(unique_values, unique_value_index_vec, coor_vec, src_tensor.handler()->quantParams(),
                                 elem_cnt, coor_best_bit, dst_data, dst_tensor->Size());
  } else {
    ret =
      UnSparseTensorData<int16_t>(unique_values, unique_value_index_vec, coor_vec, src_tensor.handler()->quantParams(),
                                  elem_cnt, coor_best_bit, dst_data, dst_tensor->Size());
  }
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "UnSparseTensorData error";
    return RET_ERROR;
  }
  return RET_OK;
}

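// Explode a byte string into a vector of bits, MSB first within each byte,
// so the bitstream parsers can consume it one bit at a time.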
std::vector<bool> WeightDecoder::StringToBitVector(const std::string &str) {
  std::vector<bool> vec(str.size() * kBitNum8);
  size_t index = 0;
  for (auto ch : str) {
    for (size_t shift = kBitNum8; shift > 0; shift--) {
      vec[index++] = (static_cast<unsigned char>(ch) >> (shift - 1)) & 0x1;
    }
  }
  return vec;
}

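// Decompress an index-encoded weight tensor. The bitstream holds
// unique_value_cnt (0 means the full 2^bit_num table), the unique values,
// then one ceil(log2(unique_value_cnt))-bit index per tensor element.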
STATUS WeightDecoder::IndexingDecompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor) {
  MS_ASSERT(src_tensor.handler() != nullptr);
  MS_ASSERT(src_tensor.data() != nullptr);
  MS_LOG(DEBUG) << "un-index weight";
  MS_CHECK_TRUE_MSG(src_tensor.handler()->quantParams() != nullptr, RET_ERROR, "quant params is nullptr");
  MS_CHECK_TRUE_MSG((*src_tensor.handler()->quantParams()).size() > 0, RET_ERROR,
                    "quant params size must be bigger than 0");
  MS_CHECK_TRUE_MSG(src_tensor.handler()->quantParams()->Get(0) != nullptr, RET_ERROR, "quant param is nullptr");
  auto bit_num = src_tensor.handler()->quantParams()->Get(0)->numBits();

  std::string str(static_cast<const char *>(src_tensor.data()), src_tensor.length());
  auto bit_vec = StringToBitVector(str);
  size_t index = 0;
  // parse unique_value_cnt
  size_t unique_value_cnt = 0;
  for (int i = 0; i < bit_num; i++) {
    bool bit = bit_vec[index++];
    unique_value_cnt |= static_cast<size_t>(bit) << static_cast<size_t>(bit_num - i - 1);
  }
  if (unique_value_cnt == 0) {
    unique_value_cnt = 1u << bit_num;
  }
  // parse unique_value_set
  std::vector<int> unique_values;
  for (size_t i = 0; i < unique_value_cnt; i++) {
    int unique_value = 0;
    for (int j = 0; j < bit_num; j++) {
      bool bit = bit_vec[index++];
      unique_value |= bit << static_cast<size_t>(bit_num - j - 1);
    }
    // unsigned to signed
    unique_values.push_back(unique_value - (1u << static_cast<size_t>(bit_num - 1)));
  }
  // parse index
  std::vector<size_t> unique_value_index_vec;
  auto elem_cnt = dst_tensor->ElementsNum();
  size_t unique_value_bit = static_cast<size_t>(ceil(log2(unique_value_cnt)));
  for (int i = 0; i < elem_cnt; i++) {
    size_t unique_value_index = 0;
    for (size_t j = 0; j < unique_value_bit; j++) {
      bool bit = bit_vec[index++];
      unique_value_index |= static_cast<size_t>(bit) << (unique_value_bit - j - 1);
    }
    unique_value_index_vec.push_back(unique_value_index);
  }

  MS_CHECK_FALSE_MSG(dst_tensor->data() != nullptr, RET_ERROR, "data_c not null");
  if (dst_tensor->MallocData() != RET_OK) {
    MS_LOG(ERROR) << "Malloc tensor data failed";
    return RET_NULL_PTR;
  }
  auto dst_data = dst_tensor->data();
  int ret;
  if (bit_num <= kBitNum8) {
    ret = UnIndexTensorData<int8_t>(unique_values, unique_value_index_vec, dst_data, dst_tensor->Size());
  } else {
    ret = UnIndexTensorData<int16_t>(unique_values, unique_value_index_vec, dst_data, dst_tensor->Size());
  }
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "UnIndexTensorData error";
    return RET_ERROR;
  }
  return RET_OK;
}

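// Dispatch dequantization for a const tensor: linear quant params go through
// DequantWeight, K-Means cluster params through DecodeKMeansWeight. Returns
// RET_NO_CHANGE when the tensor needs no conversion.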
int WeightDecoder::DequantTensor(Tensor *tensor, int preferred_dim, TypeId dst_data_type) {
  MS_ASSERT(tensor != nullptr);
  if (!tensor->IsConst() ||
      !(dst_data_type == TypeId::kNumberTypeFloat32 || dst_data_type == TypeId::kNumberTypeFloat16)) {
    return RET_NO_CHANGE;
  }
  if (!tensor->quant_params().empty()) {
    bool need_dequant = tensor->quant_params().front().inited &&
                        (tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeInt16 ||
                         tensor->data_type() == kNumberTypeInt32);
    if (!need_dequant) {
      return RET_NO_CHANGE;
    }
    auto ret = WeightDecoder::DequantWeight(tensor, preferred_dim, dst_data_type);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << tensor->tensor_name() << " Dequant data failed: " << ret;
      return ret;
    }
  } else if (!tensor->quant_clusters().empty()) {
    auto ret = DecodeKMeansWeight(tensor, dst_data_type);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << tensor->tensor_name() << " Decode KMeans weight failed: " << ret;
      return ret;
    }
  }
  return RET_OK;
}

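// Pick the axis that carries per-channel quant params for a MatMul input:
// the non-reduced (output) axis of each operand, i.e. the row axis of A and
// the column axis of B, adjusted for the transpose flags.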
int WeightDecoder::GetMatMulPreferredDim(const OpParameter *op_parameter, int input_index,
                                         const std::vector<int> &dims) {
  int last_first_index = static_cast<int>(dims.size()) - 1;
  int last_second_index = static_cast<int>(dims.size()) - 2;
  auto matmul_parameter = reinterpret_cast<const MatMulParameter *>(op_parameter);
  MS_ASSERT(matmul_parameter != nullptr);
  // For MatMul A
  if (input_index == 0) {
    if (matmul_parameter->a_transpose_) {
      return last_first_index;
    } else {
      return last_second_index;
    }
  }
  // For MatMul B
  if (input_index == 1) {
    if (matmul_parameter->b_transpose_) {
      return last_second_index;
    } else {
      return last_first_index;
    }
  }
  return 0;
}

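// Per-channel axis for deconvolution weights, chosen from the weight layout
// (depthwise vs. regular DeConv) as noted inline below.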
int WeightDecoder::GetDeConvPreferredDim(const OpParameter *op_parameter, const std::vector<int> &dims) {
  MS_ASSERT(op_parameter != nullptr);
  auto parameter = reinterpret_cast<const ConvParameter *>(op_parameter);
  if (parameter->input_channel_ == parameter->group_ && parameter->output_channel_ == parameter->group_) {
    // DepthWise-DeConv: (CO/CI) KH KW 1
    return 0;
  } else {
    // DeConv: CI KH KW CO
    return dims.size() - 1;
  }
}

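// Whether the quantized channel axis comes first in the tensor layout; only
// MatMulFusion inputs depend on the transpose flags, everything else defaults
// to channel-first.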
bool WeightDecoder::IsChannelFirst(int index, const OpParameter *op_parameter) {
  MS_ASSERT(op_parameter != nullptr);
  if (op_parameter->type_ == schema::PrimitiveType_MatMulFusion) {
    const auto *param = reinterpret_cast<const MatMulParameter *>(op_parameter);
    if (index == 0) {
      return !(param->a_transpose_);
    } else if (index == 1) {
      return param->b_transpose_;
    }
  }
  return true;
}

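// Map a (bucket_index, bucket_in_index) pair back to a flat offset in the
// original tensor, where buckets run along preferred_dim: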
// A * stride_a + bucket_index * stride_b + C
int WeightDecoder::GetDataIndex(const std::vector<int> &dims, int preferred_dim, int bucket_index,
                                int bucket_in_index) {
  int stride_a = 1;
  for (size_t i = static_cast<size_t>(preferred_dim); i < dims.size(); i++) {
    stride_a *= dims[i];
  }
  int stride_b = 1;
  for (size_t i = static_cast<size_t>(preferred_dim) + 1; i < dims.size(); i++) {
    stride_b *= dims[i];
  }
  MS_ASSERT(stride_b > 0);
  int A = bucket_in_index / stride_b;
  int C = bucket_in_index % stride_b;
  return A * stride_a + bucket_index * stride_b + C;
}
#endif

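// A tensor needs bit-unpacking when it is Huffman-coded or when its first
// quant param uses a bit width that is not byte-aligned (anything other than
// 8 or 16 bits).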
bool NeedBitUppackCheck(const SchemaTensorWrapper &src_tensor) {
  MS_ASSERT(src_tensor.handler() != nullptr);
  MS_ASSERT(src_tensor.data() != nullptr);
  if (src_tensor.handler()->enableHuffmanCode()) {
    return true;
  }
  bool need_bit_unpack = src_tensor.handler()->quantParams() != nullptr &&
                         src_tensor.handler()->quantParams()->size() > 0 &&
                         src_tensor.handler()->quantParams()->Get(0) != nullptr;
  if (need_bit_unpack) {
    auto num_bits = src_tensor.handler()->quantParams()->Get(0)->numBits();
    need_bit_unpack = ((num_bits >= kBitNum1 && num_bits < kBitNum8) || (num_bits > kBitNum8 && num_bits < kBitNum16));
  }

  return need_bit_unpack;
}

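// Dequantize all weight-quantized inputs of one node back to float. Skips
// nodes that are not weight-quantized (or QUANT_ALL nodes outside float mode)
// and clears each tensor's quant params once its data has been converted.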
int WeightDecoder::DequantNode(const OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors,
                               TypeId dst_data_type, const std::string &model_version, bool float_mode) {
#ifndef WEIGHT_DECODE_CLIP
  if (op_parameter->quant_type_ != static_cast<int>(schema::QuantType_QUANT_WEIGHT) &&
      !(op_parameter->quant_type_ == static_cast<int>(schema::QuantType_QUANT_ALL) && float_mode)) {
    return RET_OK;
  }
  int index = 0;
  for (auto &tensor : in_tensors) {
    MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
    auto preferred_dim = GetPreferredDim(in_tensors, op_parameter, index++, tensor->shape(), model_version);
    auto ret = WeightDecoder::DequantTensor(tensor, preferred_dim, dst_data_type);
    if (ret != RET_OK && ret != RET_NO_CHANGE) {
      MS_LOG(DEBUG) << "Dequant tensor failed";
      return RET_ERROR;
    }
    tensor->ClearQuantParam();
  }
  return RET_OK;
#else
  if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT &&
      !(op_parameter->quant_type_ == schema::QuantType_QUANT_ALL && float_mode)) {
    return RET_OK;
  } else {
    MS_LOG(ERROR) << "Do not support dequant node.";
    return RET_NOT_SUPPORT;
  }
#endif
}

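// Top-level decompression dispatch: FSE and FSE_INT streams go to the FSE
// decoder, then INDEXING and SPARSE streams, and finally plain bit-unpacking;
// RET_NO_CHANGE means the raw data can be used directly.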
int WeightDecoder::DecompressTensor(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor) {
  MS_ASSERT(src_tensor.handler() != nullptr);
  MS_ASSERT(dst_tensor != nullptr);
#ifndef WEIGHT_DECODE_CLIP
  if (src_tensor.handler()->weightQuantCompressType() == schema::WeightQuantCompressType_FSE ||
      src_tensor.handler()->weightQuantCompressType() == schema::WeightQuantCompressType_FSE_INT) {
    return quant::FSEDecoder::DeCompress(src_tensor, dst_tensor, src_tensor.handler()->weightQuantCompressType());
  } else if (src_tensor.handler()->weightQuantCompressType() == schema::WeightQuantCompressType_INDEXING) {
    return IndexingDecompress(src_tensor, dst_tensor);
  } else if (src_tensor.handler()->weightQuantCompressType() == schema::WeightQuantCompressType_SPARSE) {
    return SparseDecompress(src_tensor, dst_tensor);
  }
  if (!NeedBitUppackCheck(src_tensor)) {
    return RET_NO_CHANGE;
  } else {
    return WeightDecoder::UnPack(src_tensor, dst_tensor);
  }
#else
  if (src_tensor.handler()->weightQuantCompressType() != schema::WeightQuantCompressType_NONE) {
    MS_LOG(ERROR) << unsupport_weight_decode_log;
    return RET_ERROR;
  }
  if (NeedBitUppackCheck(src_tensor)) {
    MS_LOG(ERROR) << unsupport_weight_decode_log;
    return RET_ERROR;
  }
  return RET_NO_CHANGE;
#endif
}
}  // namespace mindspore::lite