• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <cmath>
17 #include <string>
18 #include "src/weight_decoder.h"
19 #include "src/huffman_decode.h"
20 
21 namespace mindspore::lite {
// Bit widths used when parsing the packed compressed-weight streams below.
constexpr int kBit8 = 8;
constexpr int kBit32 = 32;
// Expands every byte of |str| into kBit8 boolean entries, most significant
// bit first, so callers can consume the stream one bit at a time.
std::vector<bool> StringToBitVector(const std::string &str) {
  std::vector<bool> bits(str.size() * kBit8);
  size_t pos = 0;
  for (const char byte : str) {
    for (int bit = kBit8 - 1; bit >= 0; bit--) {
      bits[pos++] = ((byte >> bit) & 0x1) != 0;
    }
  }
  return bits;
}
34 
IndexingDecompress(const schema::Tensor & src_tensor,Tensor * dst_tensor)35 STATUS IndexingDecompress(const schema::Tensor &src_tensor, Tensor *dst_tensor) {
36   MS_LOG(DEBUG) << "un-index weight";
37   MS_CHECK_TRUE_MSG(src_tensor.quantParams() != nullptr, RET_ERROR, "quant params is nullptr");
38   MS_CHECK_TRUE_MSG((*src_tensor.quantParams()).size() > 0, RET_ERROR, "quant params size need bigger than 0");
39   MS_CHECK_TRUE_MSG(src_tensor.quantParams()->Get(0) != nullptr, RET_ERROR, "quant param is nullptr");
40   auto bit_num = src_tensor.quantParams()->Get(0)->numBits();
41 
42   std::string str(reinterpret_cast<const char *>(src_tensor.data()->data()), src_tensor.data()->size());
43   auto bit_vec = StringToBitVector(str);
44   size_t index = 0;
45   // parse unique_value_cnt
46   size_t unique_value_cnt = 0;
47   for (int i = 0; i < bit_num; i++) {
48     bool bit = bit_vec[index++];
49     unique_value_cnt |= bit << static_cast<size_t>((bit_num - i - 1));
50   }
51   if (unique_value_cnt == 0) {
52     unique_value_cnt = 1 << bit_num;
53   }
54   // parse unique_value_set
55   std::vector<int> unique_values;
56   for (size_t i = 0; i < unique_value_cnt; i++) {
57     int unique_value = 0;
58     for (int j = 0; j < bit_num; j++) {
59       bool bit = bit_vec[index++];
60       unique_value |= bit << static_cast<size_t>((bit_num - j - 1));
61     }
62     // unsigned to signed
63     unique_values.push_back(unique_value - (1 << static_cast<size_t>((bit_num - 1))));
64   }
65   // parse index
66   std::vector<size_t> unique_value_index_vec;
67   auto elem_cnt = dst_tensor->ElementsNum();
68   size_t unique_value_bit = ceil(log2(unique_value_cnt));
69   for (int i = 0; i < elem_cnt; i++) {
70     size_t unique_value_index = 0;
71     for (size_t j = 0; j < unique_value_bit; j++) {
72       bool bit = bit_vec[index++];
73       unique_value_index |= bit << (static_cast<size_t>(unique_value_bit - j - 1));
74     }
75     unique_value_index_vec.push_back(unique_value_index);
76   }
77 
78   if (dst_tensor->data() != nullptr) {
79     MS_LOG(ERROR) << "data_c not null";
80     return RET_ERROR;
81   }
82   auto ret = dst_tensor->MallocData();
83   if (ret != RET_OK) {
84     MS_LOG(ERROR) << "Malloc tensor data failed";
85     return RET_NULL_PTR;
86   }
87   auto dst_data = dst_tensor->data();
88   if (bit_num <= kBit8) {
89     ret = UnIndexTensorData<int8_t>(unique_values, unique_value_index_vec, dst_data, dst_tensor->Size());
90   } else {
91     ret = UnIndexTensorData<int16_t>(unique_values, unique_value_index_vec, dst_data, dst_tensor->Size());
92   }
93   if (ret != RET_OK) {
94     MS_LOG(ERROR) << "UnIndexTensorData error";
95     return RET_ERROR;
96   }
97   return RET_OK;
98 }
99 
SparseDecompress(const schema::Tensor & src_tensor,Tensor * dst_tensor)100 STATUS SparseDecompress(const schema::Tensor &src_tensor, Tensor *dst_tensor) {
101   MS_LOG(DEBUG) << "un-sparse weight";
102   MS_CHECK_TRUE_MSG(src_tensor.quantParams() != nullptr, RET_ERROR, "quant params is nullptr");
103   MS_CHECK_TRUE_MSG((*src_tensor.quantParams()).size() > 0, RET_ERROR, "quant params size need bigger than 0");
104   MS_CHECK_TRUE_MSG(src_tensor.quantParams()->Get(0) != nullptr, RET_ERROR, "quant param is nullptr");
105   size_t bit_num = src_tensor.quantParams()->Get(0)->numBits();
106 
107   std::string str(reinterpret_cast<const char *>(src_tensor.data()->data()), src_tensor.data()->size());
108   auto bit_vec = StringToBitVector(str);
109   size_t index = 0;
110   // parse coor_best_bit
111   size_t coor_best_bit = 0;
112   for (size_t i = 0; i < kBit8; i++) {
113     bool bit = bit_vec[index++];
114     coor_best_bit |= bit << static_cast<size_t>((kBit8 - i - 1));
115   }
116   // parse nz_cnt
117   size_t nz_cnt = 0;
118   for (size_t i = 0; i < kBit32; i++) {
119     bool bit = bit_vec[index++];
120     nz_cnt |= bit << static_cast<size_t>((kBit32 - i - 1));
121   }
122   // parse unique_value cnt
123   size_t unique_value_cnt = 0;
124   for (size_t i = 0; i < bit_num; i++) {
125     bool bit = bit_vec[index++];
126     unique_value_cnt |= bit << static_cast<size_t>((bit_num - i - 1));
127   }
128   if (unique_value_cnt == 0) {
129     unique_value_cnt = 1 << bit_num;
130   }
131   // parse unique_values
132   std::vector<int> unique_values;
133   for (size_t i = 0; i < unique_value_cnt; i++) {
134     int unique_value = 0;
135     for (size_t j = 0; j < bit_num; j++) {
136       bool bit = bit_vec[index++];
137       unique_value |= bit << static_cast<size_t>((bit_num - j - 1));
138     }
139     // unsigned to signed
140     unique_values.push_back(unique_value - (1 << static_cast<size_t>((bit_num - 1))));
141   }
142   // parse index
143   std::vector<size_t> unique_value_index_vec;
144   auto elem_cnt = dst_tensor->ElementsNum();
145   size_t unique_value_bit = static_cast<size_t>(ceil(log2(unique_value_cnt)));
146   for (size_t i = 0; i < nz_cnt; i++) {
147     size_t unique_value_index = 0;
148     for (size_t j = 0; j < unique_value_bit; j++) {
149       bool bit = bit_vec[index++];
150       unique_value_index |= bit << (unique_value_bit - j - 1);
151     }
152     unique_value_index_vec.push_back(unique_value_index);
153   }
154 
155   // parse coors
156   std::vector<size_t> coor_vec;
157   for (size_t i = 0; i < nz_cnt; i++) {
158     size_t coor = 0;
159     for (size_t j = 0; j < coor_best_bit; j++) {
160       bool bit = bit_vec[index++];
161       coor |= bit << static_cast<size_t>((coor_best_bit - j - 1));
162     }
163     coor_vec.push_back(coor);
164   }
165 
166   if (dst_tensor->data() != nullptr) {
167     MS_LOG(ERROR) << "data_c not null";
168     return RET_ERROR;
169   }
170   auto ret = dst_tensor->MallocData();
171   if (ret != RET_OK) {
172     MS_LOG(ERROR) << "Malloc tensor data failed";
173     return RET_NULL_PTR;
174   }
175   auto dst_data = dst_tensor->data();
176 
177   if (bit_num <= kBit8) {
178     ret = UnSparseTensorData<int8_t>(unique_values, unique_value_index_vec, coor_vec, src_tensor.quantParams(),
179                                      elem_cnt, coor_best_bit, dst_data, dst_tensor->Size());
180   } else {
181     ret = UnSparseTensorData<int16_t>(unique_values, unique_value_index_vec, coor_vec, src_tensor.quantParams(),
182                                       elem_cnt, coor_best_bit, dst_data, dst_tensor->Size());
183   }
184   if (ret != RET_OK) {
185     MS_LOG(ERROR) << "UnSparseTensorData error";
186     return RET_ERROR;
187   }
188   return RET_OK;
189 }
190 
DequantWeight(lite::Tensor * input_tensor,bool channel_first,TypeId dst_data_type)191 int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type) {
192   MS_ASSERT(input_tensor != nullptr);
193   if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
194     MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
195     return RET_ERROR;
196   }
197   if (input_tensor->quant_params().empty()) {
198     MS_LOG(ERROR) << "No quant param.";
199     return RET_ERROR;
200   }
201   if (input_tensor->data_type() == kNumberTypeInt16 && dst_data_type == kNumberTypeFloat32) {
202     auto new_const_data = DequantData<int16_t, float>(input_tensor, channel_first);
203     input_tensor->FreeData();
204     input_tensor->set_data(new_const_data);
205     input_tensor->set_own_data(true);
206     input_tensor->set_data_type(dst_data_type);
207   } else if (input_tensor->data_type() == kNumberTypeInt16 && dst_data_type == kNumberTypeFloat16) {
208 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
209     auto new_const_data = DequantData<int16_t, float16_t>(input_tensor, channel_first);
210     input_tensor->FreeData();
211     input_tensor->set_data(new_const_data);
212     input_tensor->set_own_data(true);
213     input_tensor->set_data_type(dst_data_type);
214 #else
215     MS_LOG(ERROR) << "Float16 is not supported";
216     return RET_NOT_SUPPORT;
217 #endif
218   } else if (input_tensor->data_type() == kNumberTypeInt8 && dst_data_type == kNumberTypeFloat32) {
219     auto new_const_data = DequantData<int8_t, float>(input_tensor, channel_first);
220     input_tensor->FreeData();
221     input_tensor->set_data(new_const_data);
222     input_tensor->set_own_data(true);
223     input_tensor->set_data_type(dst_data_type);
224   } else if (input_tensor->data_type() == kNumberTypeInt8 && dst_data_type == kNumberTypeFloat16) {
225 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
226     auto new_const_data = DequantData<int8_t, float16_t>(input_tensor, channel_first);
227     input_tensor->FreeData();
228     input_tensor->set_data(new_const_data);
229     input_tensor->set_own_data(true);
230     input_tensor->set_data_type(dst_data_type);
231 #else
232     MS_LOG(ERROR) << "Float16 is not supported";
233     return RET_NOT_SUPPORT;
234 #endif
235   } else {
236     MS_LOG(ERROR) << "Unsupported dequant from data_type(" << (input_tensor->data_type()) << ") to data_type("
237                   << dst_data_type << ")";
238     return RET_NOT_SUPPORT;
239   }
240   return RET_OK;
241 }
242 
DecodeHuffmanCode(const schema::Tensor & src_tensor,lite::Tensor * dst_tensor)243 int WeightDecoder::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
244   MS_ASSERT(dst_tensor != nullptr);
245   if (!dst_tensor->IsConst() || !src_tensor.enableHuffmanCode()) {
246     return RET_NO_CHANGE;
247   }
248   if (src_tensor.data() == nullptr) {
249     return RET_NO_CHANGE;
250   }
251   auto data = reinterpret_cast<const char *>(src_tensor.data()->data());
252   if (data == nullptr) {
253     return RET_NO_CHANGE;
254   }
255   std::string encode_str(data, src_tensor.data()->size());
256   dst_tensor->FreeData();
257   dst_tensor->set_data(nullptr);
258   auto ret = dst_tensor->MallocData();
259   if (ret != RET_OK) {
260     MS_LOG(ERROR) << "Malloc tensor data failed";
261     return RET_NULL_PTR;
262   }
263   auto dst_data = dst_tensor->data();
264   MS_ASSERT(dst_data != nullptr);
265   ret = HuffmanDecode::DoHuffmanDecode(encode_str, dst_data, dst_tensor->Size());
266   if (ret != RET_OK) {
267     MS_LOG(ERROR) << "DoHuffmanDecode failed.";
268     return ret;
269   }
270   return RET_OK;
271 }
272 
UnPackToInt(const schema::Tensor & src_tensor,lite::Tensor * dst_tensor)273 int WeightDecoder::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
274   MS_ASSERT(dst_tensor != nullptr);
275   auto quant_params = src_tensor.quantParams();
276   if (quant_params == nullptr || quant_params->size() == 0) {
277     return RET_NO_CHANGE;
278   }
279   auto quant_param = quant_params->Get(0);
280   if (quant_param == nullptr) {
281     return RET_NO_CHANGE;
282   }
283   auto dst_data = dst_tensor->data();
284   if (dst_data != nullptr) {
285     MS_LOG(ERROR) << "lite Tensor has already malloced data";
286     return RET_ERROR;
287   }
288   auto ret = dst_tensor->MallocData();
289   if (ret != RET_OK) {
290     MS_LOG(ERROR) << "Malloc tensor data failed";
291     return RET_NULL_PTR;
292   }
293   dst_data = dst_tensor->data();
294   int origin_bit = quant_param->numBits();
295   if (origin_bit < kBitNum8 && origin_bit >= kBitNum1) {
296     UnPackUtil<int8_t, uint8_t>(&src_tensor, origin_bit, dst_data);
297     return RET_OK;
298   } else if (origin_bit < kBitNum16 && origin_bit > kBitNum8) {
299     UnPackUtil<int16_t, uint16_t>(&src_tensor, origin_bit, dst_data);
300     return RET_OK;
301   } else {
302     MS_LOG(ERROR) << "Unsupported bit number: " << origin_bit;
303     return RET_NOT_SUPPORT;
304   }
305 }
306 
UnPack(const schema::Tensor & src_tensor,lite::Tensor * dst_tensor)307 int WeightDecoder::UnPack(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
308   STATUS ret = RET_OK;
309   if (src_tensor.enableHuffmanCode()) {
310     ret = WeightDecoder::DecodeHuffmanCode(src_tensor, dst_tensor);
311     if (ret != RET_OK && ret != RET_NO_CHANGE) {
312       MS_LOG(ERROR) << "Decode huffman code failed: " << ret;
313     }
314   } else {
315     ret = WeightDecoder::UnPackToInt(src_tensor, dst_tensor);
316     if (ret != RET_OK && ret != RET_NO_CHANGE) {
317       MS_LOG(ERROR) << "Unpack to int8 failed: " << ret;
318     }
319   }
320   return ret;
321 }
322 
DequantNode(OpParameter * op_parameter,const std::vector<Tensor * > & in_tensors,TypeId dst_data_type)323 int WeightDecoder::DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors,
324                                TypeId dst_data_type) {
325   if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT) {
326     return RET_OK;
327   }
328   int index = 0;
329   for (auto &tensor : in_tensors) {
330     MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
331     auto channel_first = IsChannelFirst(index++, op_parameter);
332     auto ret = WeightDecoder::DequantTensor(tensor, channel_first, dst_data_type);
333     if (ret != RET_OK && ret != RET_NO_CHANGE) {
334       MS_LOG(DEBUG) << "Dequant tensor failed";
335       return RET_ERROR;
336     }
337   }
338   return RET_OK;
339 }
340 
DequantTensor(Tensor * tensor,bool channel_first,TypeId dst_data_type)341 int WeightDecoder::DequantTensor(Tensor *tensor, bool channel_first, TypeId dst_data_type) {
342   MS_ASSERT(tensor != nullptr);
343   if (!tensor->IsConst() ||
344       !(dst_data_type == TypeId::kNumberTypeFloat32 || dst_data_type == TypeId::kNumberTypeFloat16)) {
345     return RET_NO_CHANGE;
346   }
347   bool need_dequant = !tensor->quant_params().empty() && tensor->quant_params().front().inited &&
348                       (tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeInt16);
349   if (!need_dequant) {
350     return RET_NO_CHANGE;
351   }
352   auto ret = WeightDecoder::DequantWeight(tensor, channel_first, dst_data_type);
353   if (ret != RET_OK) {
354     MS_LOG(ERROR) << "Dequant data failed: " << ret;
355     return ret;
356   }
357   return RET_OK;
358 }
359 }  // namespace mindspore::lite
360