• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <cmath>
17 #include <string>
18 #include "src/litert/weight_decoder.h"
19 #include "src/litert/huffman_decode.h"
20 #include "tools/converter/quantizer/fse_decoder.h"
21 #include "nnacl/conv_parameter.h"
22 
23 namespace mindspore::lite {
24 #ifndef WEIGHT_DECODE_CLIP
DequantWeight(lite::Tensor * input_tensor,int preferred_dim,TypeId dst_data_type)25 int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, int preferred_dim, TypeId dst_data_type) {
26   MS_ASSERT(input_tensor != nullptr);
27   if (input_tensor->quant_params().empty()) {
28     MS_LOG(ERROR) << "No quant param.";
29     return RET_ERROR;
30   }
31   if (input_tensor->data_type() == kNumberTypeInt16 && dst_data_type == kNumberTypeFloat32) {
32     auto new_const_data = DequantData<int16_t, float>(input_tensor, preferred_dim);
33     CHECK_NULL_RETURN(new_const_data);
34     input_tensor->FreeData();
35     input_tensor->set_data(new_const_data);
36     input_tensor->set_own_data(true);
37     input_tensor->set_data_type(dst_data_type);
38   } else if (input_tensor->data_type() == kNumberTypeInt16 && dst_data_type == kNumberTypeFloat16) {
39 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
40     auto new_const_data = DequantData<int16_t, float16_t>(input_tensor, preferred_dim);
41     CHECK_NULL_RETURN(new_const_data);
42     input_tensor->FreeData();
43     input_tensor->set_data(new_const_data);
44     input_tensor->set_own_data(true);
45     input_tensor->set_data_type(dst_data_type);
46 #else
47     MS_LOG(ERROR) << "Float16 is not supported";
48     return RET_NOT_SUPPORT;
49 #endif
50   } else if (input_tensor->data_type() == kNumberTypeInt8 && dst_data_type == kNumberTypeFloat32) {
51     auto new_const_data = DequantData<int8_t, float>(input_tensor, preferred_dim);
52     CHECK_NULL_RETURN(new_const_data);
53     input_tensor->FreeData();
54     input_tensor->set_data(new_const_data);
55     input_tensor->set_own_data(true);
56     input_tensor->set_data_type(dst_data_type);
57   } else if (input_tensor->data_type() == kNumberTypeInt8 && dst_data_type == kNumberTypeFloat16) {
58 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
59     auto new_const_data = DequantData<int8_t, float16_t>(input_tensor, preferred_dim);
60     CHECK_NULL_RETURN(new_const_data);
61     input_tensor->FreeData();
62     input_tensor->set_data(new_const_data);
63     input_tensor->set_own_data(true);
64     input_tensor->set_data_type(dst_data_type);
65 #else
66     MS_LOG(ERROR) << "Float16 is not supported";
67     return RET_NOT_SUPPORT;
68 #endif
69   } else if (input_tensor->data_type() == kNumberTypeInt32 && dst_data_type == kNumberTypeFloat32) {
70     auto new_const_data = DequantData<int32_t, float>(input_tensor, preferred_dim);
71     CHECK_NULL_RETURN(new_const_data);
72     input_tensor->FreeData();
73     input_tensor->set_data(new_const_data);
74     input_tensor->set_own_data(true);
75     input_tensor->set_data_type(dst_data_type);
76   } else {
77     MS_LOG(ERROR) << "Unsupported dequant from data_type(" << (input_tensor->data_type()) << ") to data_type("
78                   << dst_data_type << ")";
79     return RET_NOT_SUPPORT;
80   }
81   return RET_OK;
82 }
83 
DecodeKMeansWeight(lite::Tensor * tensor,TypeId dst_data_type=kNumberTypeFloat32)84 int WeightDecoder::DecodeKMeansWeight(lite::Tensor *tensor, TypeId dst_data_type = kNumberTypeFloat32) {
85   void *dequant_data = nullptr;
86   if (dst_data_type == kNumberTypeFloat32) {
87     auto dequant_data_ptr = static_cast<float *>(dequant_data);
88     auto ret = DecodeKMeansData(tensor, &dequant_data_ptr);
89     if (ret != RET_OK) {
90       MS_LOG(ERROR) << "Decode Kmeans data failed.";
91       return RET_ERROR;
92     }
93     dequant_data = dequant_data_ptr;
94   } else if (dst_data_type == kNumberTypeFloat16) {
95 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
96     auto dequant_data_ptr = static_cast<float16_t *>(dequant_data);
97     DecodeKMeansData(tensor, &dequant_data_ptr);
98     dequant_data = dequant_data_ptr;
99 #else
100     MS_LOG(ERROR) << "Current library or hardware don't support FP16.";
101     return RET_ERROR;
102 #endif
103   } else {
104     MS_LOG(ERROR) << dst_data_type << " data type is not support KMeans.";
105     return RET_ERROR;
106   }
107   tensor->FreeData();
108   tensor->set_data(dequant_data);
109   tensor->set_own_data(true);
110   tensor->set_data_type(dst_data_type);
111   return RET_OK;
112 }
113 
DecodeHuffmanCode(const SchemaTensorWrapper & src_tensor,lite::Tensor * dst_tensor)114 int WeightDecoder::DecodeHuffmanCode(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor) {
115   MS_ASSERT(src_tensor.handler() != nullptr);
116   MS_ASSERT(src_tensor.data() != nullptr);
117   MS_ASSERT(dst_tensor != nullptr);
118   if (!dst_tensor->IsConst() || !src_tensor.handler()->enableHuffmanCode()) {
119     return RET_NO_CHANGE;
120   }
121   if (src_tensor.data() == nullptr) {
122     return RET_NO_CHANGE;
123   }
124   auto data = reinterpret_cast<const char *>(src_tensor.data());
125   std::string encode_str(data, src_tensor.length());
126   dst_tensor->FreeData();
127   dst_tensor->set_data(nullptr);
128   auto ret = dst_tensor->MallocData();
129   if (ret != RET_OK) {
130     MS_LOG(ERROR) << "Malloc tensor data failed";
131     return RET_NULL_PTR;
132   }
133   auto dst_data = dst_tensor->data();
134   MS_ASSERT(dst_data != nullptr);
135   ret = HuffmanDecode::DoHuffmanDecode(encode_str, dst_data, dst_tensor->Size());
136   if (ret != RET_OK) {
137     MS_LOG(ERROR) << "DoHuffmanDecode failed.";
138     return ret;
139   }
140   return RET_OK;
141 }
142 
UnPackToInt(const SchemaTensorWrapper & src_tensor,lite::Tensor * dst_tensor)143 int WeightDecoder::UnPackToInt(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor) {
144   MS_ASSERT(src_tensor.handler() != nullptr);
145   MS_ASSERT(src_tensor.data() != nullptr);
146   MS_ASSERT(dst_tensor != nullptr);
147   auto quant_params = src_tensor.handler()->quantParams();
148   if (quant_params == nullptr || quant_params->size() == 0) {
149     return RET_NO_CHANGE;
150   }
151   auto quant_param = quant_params->Get(0);
152   if (quant_param == nullptr) {
153     return RET_NO_CHANGE;
154   }
155   auto dst_data = dst_tensor->data();
156   if (dst_data != nullptr) {
157     MS_LOG(ERROR) << "lite Tensor has already malloced data";
158     return RET_ERROR;
159   }
160   auto ret = dst_tensor->MallocData();
161   if (ret != RET_OK) {
162     MS_LOG(ERROR) << "Malloc tensor data failed";
163     return RET_NULL_PTR;
164   }
165   dst_data = dst_tensor->data();
166   auto dst_element_num = dst_tensor->ElementsNum();
167   int origin_bit = quant_param->numBits();
168   if (origin_bit < kBitNum8 && origin_bit >= kBitNum1) {
169     return UnPackUtil<int8_t, uint8_t>(src_tensor, dst_element_num, origin_bit, dst_data);
170   } else if (origin_bit < kBitNum16 && origin_bit > kBitNum8) {
171     return UnPackUtil<int16_t, uint16_t>(src_tensor, dst_element_num, origin_bit, dst_data);
172   } else {
173     MS_LOG(ERROR) << "Unsupported bit number: " << origin_bit;
174     return RET_NOT_SUPPORT;
175   }
176 }
177 
UnPack(const SchemaTensorWrapper & src_tensor,lite::Tensor * dst_tensor)178 int WeightDecoder::UnPack(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor) {
179   MS_ASSERT(src_tensor.handler() != nullptr);
180   MS_ASSERT(src_tensor.data() != nullptr);
181   MS_CHECK_TRUE_MSG(src_tensor.handler()->dims() != nullptr, RET_ERROR, "dims is nullptr");
182   MS_CHECK_TRUE_MSG(src_tensor.handler()->name() != nullptr, RET_ERROR, "name is nullptr");
183   STATUS ret = RET_OK;
184   if (src_tensor.handler()->enableHuffmanCode()) {
185     ret = WeightDecoder::DecodeHuffmanCode(src_tensor, dst_tensor);
186     if (ret != RET_OK && ret != RET_NO_CHANGE) {
187       MS_LOG(ERROR) << "Decode huffman code failed: " << ret;
188     }
189   } else {
190     if (src_tensor.handler()->dims()->size() == 0) {
191       MS_LOG(ERROR) << src_tensor.handler()->name()->c_str() << " shape is empty.";
192       return RET_ERROR;
193     }
194     ret = WeightDecoder::UnPackToInt(src_tensor, dst_tensor);
195     if (ret != RET_OK && ret != RET_NO_CHANGE) {
196       MS_LOG(ERROR) << "Unpack to int8 failed: " << ret;
197       return ret;
198     }
199   }
200   return ret;
201 }
202 
SparseDecompress(const SchemaTensorWrapper & src_tensor,Tensor * dst_tensor)203 STATUS WeightDecoder::SparseDecompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor) {
204   MS_ASSERT(src_tensor.handler() != nullptr);
205   MS_ASSERT(src_tensor.data() != nullptr);
206   MS_LOG(DEBUG) << "un-sparse weight";
207   MS_CHECK_TRUE_MSG(src_tensor.handler()->quantParams() != nullptr, RET_ERROR, "quant params is nullptr");
208   MS_CHECK_TRUE_MSG((*src_tensor.handler()->quantParams()).size() > 0, RET_ERROR,
209                     "quant params size need bigger than 0");
210   MS_CHECK_TRUE_MSG(src_tensor.handler()->quantParams()->Get(0) != nullptr, RET_ERROR, "quant param is nullptr");
211   size_t bit_num = static_cast<size_t>(src_tensor.handler()->quantParams()->Get(0)->numBits());
212 
213   std::string str(static_cast<const char *>(src_tensor.data()), src_tensor.length());
214   auto bit_vec = StringToBitVector(str);
215   size_t index = 0;
216   // parse coor_best_bit
217   size_t coor_best_bit = 0;
218   for (size_t i = 0; i < kBitNum8; i++) {
219     bool bit = bit_vec[index++];
220     coor_best_bit |= bit << static_cast<size_t>((kBitNum8 - i - 1));
221   }
222   // parse nz_cnt
223   size_t nz_cnt = 0;
224   for (size_t i = 0; i < kBitNum32; i++) {
225     bool bit = bit_vec[index++];
226     nz_cnt |= bit << static_cast<size_t>((kBitNum32 - i - 1));
227   }
228   // parse unique_value cnt
229   size_t unique_value_cnt = 0;
230   for (size_t i = 0; i < bit_num; i++) {
231     bool bit = bit_vec[index++];
232     unique_value_cnt |= bit << (bit_num - i - 1);
233   }
234   if (unique_value_cnt == 0) {
235     unique_value_cnt = 1u << bit_num;
236   }
237   // parse unique_values
238   std::vector<int> unique_values;
239   for (size_t i = 0; i < unique_value_cnt; i++) {
240     int unique_value = 0;
241     for (size_t j = 0; j < bit_num; j++) {
242       bool bit = bit_vec[index++];
243       unique_value |= bit << static_cast<size_t>((bit_num - j - 1));
244     }
245     // unsigned to signed
246     unique_values.push_back(unique_value - (1u << static_cast<size_t>((bit_num - 1))));
247   }
248   // parse index
249   std::vector<size_t> unique_value_index_vec;
250   auto elem_cnt = dst_tensor->ElementsNum();
251   size_t unique_value_bit = static_cast<size_t>(ceil(log2(unique_value_cnt)));
252   for (size_t i = 0; i < nz_cnt; i++) {
253     size_t unique_value_index = 0;
254     for (size_t j = 0; j < unique_value_bit; j++) {
255       bool bit = bit_vec[index++];
256       unique_value_index |= bit << (unique_value_bit - j - 1);
257     }
258     unique_value_index_vec.push_back(unique_value_index);
259   }
260 
261   // parse coors
262   std::vector<size_t> coor_vec;
263   for (size_t i = 0; i < nz_cnt; i++) {
264     size_t coor = 0;
265     for (size_t j = 0; j < coor_best_bit; j++) {
266       bool bit = bit_vec[index++];
267       coor |= bit << static_cast<size_t>((coor_best_bit - j - 1));
268     }
269     coor_vec.push_back(coor);
270   }
271 
272   if (dst_tensor->data() != nullptr) {
273     MS_LOG(ERROR) << "data_c not null";
274     return RET_ERROR;
275   }
276   auto ret = dst_tensor->MallocData();
277   if (ret != RET_OK) {
278     MS_LOG(ERROR) << "Malloc tensor data failed";
279     return RET_NULL_PTR;
280   }
281   auto dst_data = dst_tensor->data();
282 
283   if (bit_num <= kBitNum8) {
284     ret =
285       UnSparseTensorData<int8_t>(unique_values, unique_value_index_vec, coor_vec, src_tensor.handler()->quantParams(),
286                                  elem_cnt, coor_best_bit, dst_data, dst_tensor->Size());
287   } else {
288     ret =
289       UnSparseTensorData<int16_t>(unique_values, unique_value_index_vec, coor_vec, src_tensor.handler()->quantParams(),
290                                   elem_cnt, coor_best_bit, dst_data, dst_tensor->Size());
291   }
292   if (ret != RET_OK) {
293     MS_LOG(ERROR) << "UnSparseTensorData error";
294     return RET_ERROR;
295   }
296   return RET_OK;
297 }
298 
StringToBitVector(const std::string & str)299 std::vector<bool> WeightDecoder::StringToBitVector(const std::string &str) {
300   std::vector<bool> vec(str.size() * kBitNum8);
301   size_t index = 0;
302   for (auto ch : str) {
303     for (size_t shift = kBitNum8; shift > 0; shift--) {
304       vec[index++] = (static_cast<unsigned char>(ch) >> static_cast<size_t>(shift - 1)) & 0x1;
305     }
306   }
307   return vec;
308 }
309 
// Decompresses an "indexing"-compressed weight blob into dst_tensor.
// Stream layout (MSB-first bits): [bit_num bits: unique_value_cnt, 0 meaning
// the full 2^bit_num table] [unique_value_cnt * bit_num bits: values, stored
// unsigned and re-centered by -2^(bit_num-1)] [elem_cnt *
// ceil(log2(unique_value_cnt)) bits: per-element index into the value table].
STATUS WeightDecoder::IndexingDecompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor) {
  MS_ASSERT(src_tensor.handler() != nullptr);
  MS_ASSERT(src_tensor.data() != nullptr);
  MS_LOG(DEBUG) << "un-index weight";
  MS_CHECK_TRUE_MSG(src_tensor.handler()->quantParams() != nullptr, RET_ERROR, "quant params is nullptr");
  MS_CHECK_TRUE_MSG((*src_tensor.handler()->quantParams()).size() > 0, RET_ERROR,
                    "quant params size need bigger than 0");
  MS_CHECK_TRUE_MSG(src_tensor.handler()->quantParams()->Get(0) != nullptr, RET_ERROR, "quant param is nullptr");
  auto bit_num = src_tensor.handler()->quantParams()->Get(0)->numBits();

  // Expand the raw bytes into a flat MSB-first bit stream.
  std::string str(static_cast<const char *>(src_tensor.data()), src_tensor.length());
  auto bit_vec = StringToBitVector(str);
  size_t index = 0;
  // parse unique_value_cnt
  size_t unique_value_cnt = 0;
  for (int i = 0; i < bit_num; i++) {
    bool bit = bit_vec[index++];
    unique_value_cnt |= bit << static_cast<size_t>((bit_num - i - 1));
  }
  // A stored count of zero encodes the full table of 2^bit_num values.
  if (unique_value_cnt == 0) {
    unique_value_cnt = 1u << bit_num;
  }
  // parse unique_value_set
  std::vector<int> unique_values;
  for (size_t i = 0; i < unique_value_cnt; i++) {
    int unique_value = 0;
    for (int j = 0; j < bit_num; j++) {
      bool bit = bit_vec[index++];
      unique_value |= bit << static_cast<size_t>((bit_num - j - 1));
    }
    // unsigned to signed: stored values are offset by 2^(bit_num-1).
    unique_values.push_back(unique_value - (1u << static_cast<size_t>((bit_num - 1))));
  }
  // parse index
  std::vector<size_t> unique_value_index_vec;
  auto elem_cnt = dst_tensor->ElementsNum();
  // Each element index uses just enough bits to address the unique-value table.
  size_t unique_value_bit = static_cast<size_t>(ceil(log2(unique_value_cnt)));
  for (int i = 0; i < elem_cnt; i++) {
    size_t unique_value_index = 0;
    for (size_t j = 0; j < unique_value_bit; j++) {
      bool bit = bit_vec[index++];
      unique_value_index |= bit << (static_cast<size_t>(unique_value_bit - j - 1));
    }
    unique_value_index_vec.push_back(unique_value_index);
  }

  // Destination must be empty; allocate the decode buffer fresh.
  MS_CHECK_FALSE_MSG(dst_tensor->data() != nullptr, RET_ERROR, "data_c not null");
  if (dst_tensor->MallocData() != RET_OK) {
    MS_LOG(ERROR) << "Malloc tensor data failed";
    return RET_NULL_PTR;
  }
  auto dst_data = dst_tensor->data();
  int ret;
  // numBits <= 8 means the decoded values fit int8 storage, otherwise int16.
  if (bit_num <= kBitNum8) {
    ret = UnIndexTensorData<int8_t>(unique_values, unique_value_index_vec, dst_data, dst_tensor->Size());
  } else {
    ret = UnIndexTensorData<int16_t>(unique_values, unique_value_index_vec, dst_data, dst_tensor->Size());
  }
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "UnIndexTensorData error";
    return RET_ERROR;
  }
  return RET_OK;
}
374 
DequantTensor(Tensor * tensor,int preferred_dim,TypeId dst_data_type)375 int WeightDecoder::DequantTensor(Tensor *tensor, int preferred_dim, TypeId dst_data_type) {
376   MS_ASSERT(tensor != nullptr);
377   if (!tensor->IsConst() ||
378       !(dst_data_type == TypeId::kNumberTypeFloat32 || dst_data_type == TypeId::kNumberTypeFloat16)) {
379     return RET_NO_CHANGE;
380   }
381   if (!tensor->quant_params().empty()) {
382     bool need_dequant = tensor->quant_params().front().inited &&
383                         (tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeInt16 ||
384                          tensor->data_type() == kNumberTypeInt32);
385     if (!need_dequant) {
386       return RET_NO_CHANGE;
387     }
388     auto ret = WeightDecoder::DequantWeight(tensor, preferred_dim, dst_data_type);
389     if (ret != RET_OK) {
390       MS_LOG(ERROR) << tensor->tensor_name() << " Dequant data failed: " << ret;
391       return ret;
392     }
393   } else if (!tensor->quant_clusters().empty()) {
394     auto ret = DecodeKMeansWeight(tensor, dst_data_type);
395     if (ret != RET_OK) {
396       MS_LOG(ERROR) << tensor->tensor_name() << " Decode KMeans weight failed: " << ret;
397       return ret;
398     }
399   }
400   return RET_OK;
401 }
402 
GetMatMulPreferredDim(const OpParameter * op_parameter,int input_index,const std::vector<int> & dims)403 int WeightDecoder::GetMatMulPreferredDim(const OpParameter *op_parameter, int input_index,
404                                          const std::vector<int> &dims) {
405   int last_first_index = static_cast<int>(dims.size()) - 1;
406   int last_second_index = static_cast<int>(dims.size()) - 2;
407   auto matmul_parameter = reinterpret_cast<const MatMulParameter *>(op_parameter);
408   MS_ASSERT(matmul_parameter != nullptr);
409   // For MatMul A
410   if (input_index == 0) {
411     if (matmul_parameter->a_transpose_) {
412       return last_first_index;
413     } else {
414       return last_second_index;
415     }
416   }
417   // For MatMul B
418   if (input_index == 1) {
419     if (matmul_parameter->b_transpose_) {
420       return last_second_index;
421     } else {
422       return last_first_index;
423     }
424   }
425   return 0;
426 }
427 
GetDeConvPreferredDim(const OpParameter * op_parameter,const std::vector<int> & dims)428 int WeightDecoder::GetDeConvPreferredDim(const OpParameter *op_parameter, const std::vector<int> &dims) {
429   MS_ASSERT(op_parameter != nullptr);
430   auto parameter = reinterpret_cast<const ConvParameter *>(op_parameter);
431   if (parameter->input_channel_ == parameter->group_ && parameter->output_channel_ == parameter->group_) {
432     // DepthWise-DeConv (CO\CI) KH KW 1
433     return 0;
434   } else {
435     // DeConv:CI KH KW CO
436     return dims.size() - 1;
437   }
438 }
439 
IsChannelFirst(int index,const OpParameter * op_parameter)440 bool WeightDecoder::IsChannelFirst(int index, const OpParameter *op_parameter) {
441   MS_ASSERT(op_parameter != nullptr);
442   if (op_parameter->type_ == schema::PrimitiveType_MatMulFusion) {
443     const auto *param = reinterpret_cast<const MatMulParameter *>(op_parameter);
444     if (index == 0) {
445       return !(param->a_transpose_);
446     } else if (index == 1) {
447       return param->b_transpose_;
448     }
449   }
450   return true;
451 }
452 
453 // A * stride_a + bucket_index * stride_b + C
GetDataIndex(const std::vector<int> & dims,int preferred_dim,int bucket_index,int bucket_in_index)454 int WeightDecoder::GetDataIndex(const std::vector<int> &dims, int preferred_dim, int bucket_index,
455                                 int bucket_in_index) {
456   int stride_a = 1;
457   for (size_t i = static_cast<size_t>(preferred_dim); i < dims.size(); i++) {
458     stride_a *= dims[i];
459   }
460   int stride_b = 1;
461   for (size_t i = static_cast<size_t>(preferred_dim) + 1; i < dims.size(); i++) {
462     stride_b *= dims[i];
463   }
464   MS_ASSERT(stride_b > 0);
465   int A = bucket_in_index / stride_b;
466   int C = bucket_in_index % stride_b;
467   return A * stride_a + bucket_index * stride_b + C;
468 }
469 
470 #endif
471 
NeedBitUppackCheck(const SchemaTensorWrapper & src_tensor)472 bool NeedBitUppackCheck(const SchemaTensorWrapper &src_tensor) {
473   MS_ASSERT(src_tensor.handler() != nullptr);
474   MS_ASSERT(src_tensor.data() != nullptr);
475   if (src_tensor.handler()->enableHuffmanCode()) {
476     return true;
477   }
478   bool need_bit_unpack = src_tensor.handler()->quantParams() != nullptr &&
479                          src_tensor.handler()->quantParams()->size() > 0 &&
480                          src_tensor.handler()->quantParams()->Get(0) != nullptr;
481   if (need_bit_unpack) {
482     auto num_bits = src_tensor.handler()->quantParams()->Get(0)->numBits();
483     need_bit_unpack = ((num_bits >= kBitNum1 && num_bits < kBitNum8) || (num_bits > kBitNum8 && num_bits < kBitNum16));
484   }
485 
486   return need_bit_unpack;
487 }
488 
DequantNode(const OpParameter * op_parameter,const std::vector<Tensor * > & in_tensors,TypeId dst_data_type,const std::string & model_version,bool float_mode)489 int WeightDecoder::DequantNode(const OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors,
490                                TypeId dst_data_type, const std::string &model_version, bool float_mode) {
491 #ifndef WEIGHT_DECODE_CLIP
492   if (op_parameter->quant_type_ != static_cast<int>(schema::QuantType_QUANT_WEIGHT) &&
493       !(op_parameter->quant_type_ == static_cast<int>(schema::QuantType_QUANT_ALL) && float_mode)) {
494     return RET_OK;
495   }
496   int index = 0;
497   for (auto &tensor : in_tensors) {
498     MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
499     auto preferred_dim = GetPreferredDim(in_tensors, op_parameter, index++, tensor->shape(), model_version);
500     auto ret = WeightDecoder::DequantTensor(tensor, preferred_dim, dst_data_type);
501     if (ret != RET_OK && ret != RET_NO_CHANGE) {
502       MS_LOG(DEBUG) << "Dequant tensor failed";
503       return RET_ERROR;
504     }
505     tensor->ClearQuantParam();
506   }
507   return RET_OK;
508 #else
509   if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT &&
510       !(op_parameter->quant_type_ == schema::QuantType_QUANT_ALL && float_mode)) {
511     return RET_OK;
512   } else {
513     MS_LOG(ERROR) << "Do not support dequant node.";
514     return RET_NOT_SUPPORT;
515   }
516 #endif
517 }
518 
DecompressTensor(const SchemaTensorWrapper & src_tensor,lite::Tensor * dst_tensor)519 int WeightDecoder::DecompressTensor(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor) {
520   MS_ASSERT(src_tensor.handler() != nullptr);
521   MS_ASSERT(dst_tensor != nullptr);
522 #ifndef WEIGHT_DECODE_CLIP
523   if (src_tensor.handler()->weightQuantCompressType() == schema::WeightQuantCompressType_FSE ||
524       src_tensor.handler()->weightQuantCompressType() == schema::WeightQuantCompressType_FSE_INT) {
525     return quant::FSEDecoder::DeCompress(src_tensor, dst_tensor, src_tensor.handler()->weightQuantCompressType());
526   } else if (src_tensor.handler()->weightQuantCompressType() == schema::WeightQuantCompressType_INDEXING) {
527     return IndexingDecompress(src_tensor, dst_tensor);
528   } else if (src_tensor.handler()->weightQuantCompressType() == schema::WeightQuantCompressType_SPARSE) {
529     return SparseDecompress(src_tensor, dst_tensor);
530   }
531   if (!NeedBitUppackCheck(src_tensor)) {
532     return RET_NO_CHANGE;
533   } else {
534     return WeightDecoder::UnPack(src_tensor, dst_tensor);
535   }
536 #else
537   if (src_tensor.handler()->weightQuantCompressType() != schema::WeightQuantCompressType_NONE) {
538     MS_LOG(ERROR) << unsupport_weight_decode_log;
539     return RET_ERROR;
540   }
541   if (NeedBitUppackCheck(src_tensor)) {
542     MS_LOG(ERROR) << unsupport_weight_decode_log;
543     return RET_ERROR;
544   }
545   return RET_NO_CHANGE;
546 #endif
547 }
548 }  // namespace mindspore::lite
549