1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <cmath>
17 #include <string>
18 #include "src/weight_decoder.h"
19 #include "src/huffman_decode.h"
20
21 namespace mindspore::lite {
constexpr int kBit8 = 8;    // bits per byte; also the widest value that still fits int8_t storage
constexpr int kBit32 = 32;  // bit width of the nz_cnt field in the sparse-compressed bit stream
// Expands each byte of `str` into its 8 bits, most-significant bit first.
// Returns a vector of exactly str.size() * 8 booleans (empty for an empty
// string).
std::vector<bool> StringToBitVector(const std::string &str) {
  constexpr size_t kBitsPerByte = 8;  // matches kBit8
  std::vector<bool> vec(str.size() * kBitsPerByte);
  size_t index = 0;
  for (auto ch : str) {
    // Cast to unsigned char before shifting: right-shifting a negative
    // (signed) char is implementation-defined, while the unsigned shift is
    // well-defined and extracts the same bit pattern.
    auto byte = static_cast<unsigned char>(ch);
    for (size_t shift = kBitsPerByte; shift > 0; shift--) {
      vec[index++] = (byte >> (shift - 1)) & 0x1;
    }
  }
  return vec;
}
34
IndexingDecompress(const schema::Tensor & src_tensor,Tensor * dst_tensor)35 STATUS IndexingDecompress(const schema::Tensor &src_tensor, Tensor *dst_tensor) {
36 MS_LOG(DEBUG) << "un-index weight";
37 MS_CHECK_TRUE_MSG(src_tensor.quantParams() != nullptr, RET_ERROR, "quant params is nullptr");
38 MS_CHECK_TRUE_MSG((*src_tensor.quantParams()).size() > 0, RET_ERROR, "quant params size need bigger than 0");
39 MS_CHECK_TRUE_MSG(src_tensor.quantParams()->Get(0) != nullptr, RET_ERROR, "quant param is nullptr");
40 auto bit_num = src_tensor.quantParams()->Get(0)->numBits();
41
42 std::string str(reinterpret_cast<const char *>(src_tensor.data()->data()), src_tensor.data()->size());
43 auto bit_vec = StringToBitVector(str);
44 size_t index = 0;
45 // parse unique_value_cnt
46 size_t unique_value_cnt = 0;
47 for (int i = 0; i < bit_num; i++) {
48 bool bit = bit_vec[index++];
49 unique_value_cnt |= bit << static_cast<size_t>((bit_num - i - 1));
50 }
51 if (unique_value_cnt == 0) {
52 unique_value_cnt = 1 << bit_num;
53 }
54 // parse unique_value_set
55 std::vector<int> unique_values;
56 for (size_t i = 0; i < unique_value_cnt; i++) {
57 int unique_value = 0;
58 for (int j = 0; j < bit_num; j++) {
59 bool bit = bit_vec[index++];
60 unique_value |= bit << static_cast<size_t>((bit_num - j - 1));
61 }
62 // unsigned to signed
63 unique_values.push_back(unique_value - (1 << static_cast<size_t>((bit_num - 1))));
64 }
65 // parse index
66 std::vector<size_t> unique_value_index_vec;
67 auto elem_cnt = dst_tensor->ElementsNum();
68 size_t unique_value_bit = ceil(log2(unique_value_cnt));
69 for (int i = 0; i < elem_cnt; i++) {
70 size_t unique_value_index = 0;
71 for (size_t j = 0; j < unique_value_bit; j++) {
72 bool bit = bit_vec[index++];
73 unique_value_index |= bit << (static_cast<size_t>(unique_value_bit - j - 1));
74 }
75 unique_value_index_vec.push_back(unique_value_index);
76 }
77
78 if (dst_tensor->data() != nullptr) {
79 MS_LOG(ERROR) << "data_c not null";
80 return RET_ERROR;
81 }
82 auto ret = dst_tensor->MallocData();
83 if (ret != RET_OK) {
84 MS_LOG(ERROR) << "Malloc tensor data failed";
85 return RET_NULL_PTR;
86 }
87 auto dst_data = dst_tensor->data();
88 if (bit_num <= kBit8) {
89 ret = UnIndexTensorData<int8_t>(unique_values, unique_value_index_vec, dst_data, dst_tensor->Size());
90 } else {
91 ret = UnIndexTensorData<int16_t>(unique_values, unique_value_index_vec, dst_data, dst_tensor->Size());
92 }
93 if (ret != RET_OK) {
94 MS_LOG(ERROR) << "UnIndexTensorData error";
95 return RET_ERROR;
96 }
97 return RET_OK;
98 }
99
SparseDecompress(const schema::Tensor & src_tensor,Tensor * dst_tensor)100 STATUS SparseDecompress(const schema::Tensor &src_tensor, Tensor *dst_tensor) {
101 MS_LOG(DEBUG) << "un-sparse weight";
102 MS_CHECK_TRUE_MSG(src_tensor.quantParams() != nullptr, RET_ERROR, "quant params is nullptr");
103 MS_CHECK_TRUE_MSG((*src_tensor.quantParams()).size() > 0, RET_ERROR, "quant params size need bigger than 0");
104 MS_CHECK_TRUE_MSG(src_tensor.quantParams()->Get(0) != nullptr, RET_ERROR, "quant param is nullptr");
105 size_t bit_num = src_tensor.quantParams()->Get(0)->numBits();
106
107 std::string str(reinterpret_cast<const char *>(src_tensor.data()->data()), src_tensor.data()->size());
108 auto bit_vec = StringToBitVector(str);
109 size_t index = 0;
110 // parse coor_best_bit
111 size_t coor_best_bit = 0;
112 for (size_t i = 0; i < kBit8; i++) {
113 bool bit = bit_vec[index++];
114 coor_best_bit |= bit << static_cast<size_t>((kBit8 - i - 1));
115 }
116 // parse nz_cnt
117 size_t nz_cnt = 0;
118 for (size_t i = 0; i < kBit32; i++) {
119 bool bit = bit_vec[index++];
120 nz_cnt |= bit << static_cast<size_t>((kBit32 - i - 1));
121 }
122 // parse unique_value cnt
123 size_t unique_value_cnt = 0;
124 for (size_t i = 0; i < bit_num; i++) {
125 bool bit = bit_vec[index++];
126 unique_value_cnt |= bit << static_cast<size_t>((bit_num - i - 1));
127 }
128 if (unique_value_cnt == 0) {
129 unique_value_cnt = 1 << bit_num;
130 }
131 // parse unique_values
132 std::vector<int> unique_values;
133 for (size_t i = 0; i < unique_value_cnt; i++) {
134 int unique_value = 0;
135 for (size_t j = 0; j < bit_num; j++) {
136 bool bit = bit_vec[index++];
137 unique_value |= bit << static_cast<size_t>((bit_num - j - 1));
138 }
139 // unsigned to signed
140 unique_values.push_back(unique_value - (1 << static_cast<size_t>((bit_num - 1))));
141 }
142 // parse index
143 std::vector<size_t> unique_value_index_vec;
144 auto elem_cnt = dst_tensor->ElementsNum();
145 size_t unique_value_bit = static_cast<size_t>(ceil(log2(unique_value_cnt)));
146 for (size_t i = 0; i < nz_cnt; i++) {
147 size_t unique_value_index = 0;
148 for (size_t j = 0; j < unique_value_bit; j++) {
149 bool bit = bit_vec[index++];
150 unique_value_index |= bit << (unique_value_bit - j - 1);
151 }
152 unique_value_index_vec.push_back(unique_value_index);
153 }
154
155 // parse coors
156 std::vector<size_t> coor_vec;
157 for (size_t i = 0; i < nz_cnt; i++) {
158 size_t coor = 0;
159 for (size_t j = 0; j < coor_best_bit; j++) {
160 bool bit = bit_vec[index++];
161 coor |= bit << static_cast<size_t>((coor_best_bit - j - 1));
162 }
163 coor_vec.push_back(coor);
164 }
165
166 if (dst_tensor->data() != nullptr) {
167 MS_LOG(ERROR) << "data_c not null";
168 return RET_ERROR;
169 }
170 auto ret = dst_tensor->MallocData();
171 if (ret != RET_OK) {
172 MS_LOG(ERROR) << "Malloc tensor data failed";
173 return RET_NULL_PTR;
174 }
175 auto dst_data = dst_tensor->data();
176
177 if (bit_num <= kBit8) {
178 ret = UnSparseTensorData<int8_t>(unique_values, unique_value_index_vec, coor_vec, src_tensor.quantParams(),
179 elem_cnt, coor_best_bit, dst_data, dst_tensor->Size());
180 } else {
181 ret = UnSparseTensorData<int16_t>(unique_values, unique_value_index_vec, coor_vec, src_tensor.quantParams(),
182 elem_cnt, coor_best_bit, dst_data, dst_tensor->Size());
183 }
184 if (ret != RET_OK) {
185 MS_LOG(ERROR) << "UnSparseTensorData error";
186 return RET_ERROR;
187 }
188 return RET_OK;
189 }
190
// Dequantizes `input_tensor` in place: converts its int8/int16 payload to
// dst_data_type (float32, or float16 on ARM builds with ENABLE_FP16) using
// the tensor's quant params, frees the old buffer, installs the new one as
// owned data, and updates the tensor's data type.
// Returns RET_ERROR on wrong input type or missing quant params,
// RET_NOT_SUPPORT for unsupported type combinations (including any fp16
// request on a build without FP16 support), RET_OK on success.
// NOTE(review): DequantData's failure behavior is not visible here; if it can
// return nullptr the tensor would be left holding null data — confirm.
int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type) {
  MS_ASSERT(input_tensor != nullptr);
  // Only quantized int8/int16 weights can be dequantized.
  if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
    MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
    return RET_ERROR;
  }
  if (input_tensor->quant_params().empty()) {
    MS_LOG(ERROR) << "No quant param.";
    return RET_ERROR;
  }
  // Each branch: allocate the dequantized buffer, release the quantized one,
  // then hand ownership of the new buffer to the tensor.
  if (input_tensor->data_type() == kNumberTypeInt16 && dst_data_type == kNumberTypeFloat32) {
    auto new_const_data = DequantData<int16_t, float>(input_tensor, channel_first);
    input_tensor->FreeData();
    input_tensor->set_data(new_const_data);
    input_tensor->set_own_data(true);
    input_tensor->set_data_type(dst_data_type);
  } else if (input_tensor->data_type() == kNumberTypeInt16 && dst_data_type == kNumberTypeFloat16) {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
    auto new_const_data = DequantData<int16_t, float16_t>(input_tensor, channel_first);
    input_tensor->FreeData();
    input_tensor->set_data(new_const_data);
    input_tensor->set_own_data(true);
    input_tensor->set_data_type(dst_data_type);
#else
    MS_LOG(ERROR) << "Float16 is not supported";
    return RET_NOT_SUPPORT;
#endif
  } else if (input_tensor->data_type() == kNumberTypeInt8 && dst_data_type == kNumberTypeFloat32) {
    auto new_const_data = DequantData<int8_t, float>(input_tensor, channel_first);
    input_tensor->FreeData();
    input_tensor->set_data(new_const_data);
    input_tensor->set_own_data(true);
    input_tensor->set_data_type(dst_data_type);
  } else if (input_tensor->data_type() == kNumberTypeInt8 && dst_data_type == kNumberTypeFloat16) {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
    auto new_const_data = DequantData<int8_t, float16_t>(input_tensor, channel_first);
    input_tensor->FreeData();
    input_tensor->set_data(new_const_data);
    input_tensor->set_own_data(true);
    input_tensor->set_data_type(dst_data_type);
#else
    MS_LOG(ERROR) << "Float16 is not supported";
    return RET_NOT_SUPPORT;
#endif
  } else {
    MS_LOG(ERROR) << "Unsupported dequant from data_type(" << (input_tensor->data_type()) << ") to data_type("
                  << dst_data_type << ")";
    return RET_NOT_SUPPORT;
  }
  return RET_OK;
}
242
DecodeHuffmanCode(const schema::Tensor & src_tensor,lite::Tensor * dst_tensor)243 int WeightDecoder::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
244 MS_ASSERT(dst_tensor != nullptr);
245 if (!dst_tensor->IsConst() || !src_tensor.enableHuffmanCode()) {
246 return RET_NO_CHANGE;
247 }
248 if (src_tensor.data() == nullptr) {
249 return RET_NO_CHANGE;
250 }
251 auto data = reinterpret_cast<const char *>(src_tensor.data()->data());
252 if (data == nullptr) {
253 return RET_NO_CHANGE;
254 }
255 std::string encode_str(data, src_tensor.data()->size());
256 dst_tensor->FreeData();
257 dst_tensor->set_data(nullptr);
258 auto ret = dst_tensor->MallocData();
259 if (ret != RET_OK) {
260 MS_LOG(ERROR) << "Malloc tensor data failed";
261 return RET_NULL_PTR;
262 }
263 auto dst_data = dst_tensor->data();
264 MS_ASSERT(dst_data != nullptr);
265 ret = HuffmanDecode::DoHuffmanDecode(encode_str, dst_data, dst_tensor->Size());
266 if (ret != RET_OK) {
267 MS_LOG(ERROR) << "DoHuffmanDecode failed.";
268 return ret;
269 }
270 return RET_OK;
271 }
272
UnPackToInt(const schema::Tensor & src_tensor,lite::Tensor * dst_tensor)273 int WeightDecoder::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
274 MS_ASSERT(dst_tensor != nullptr);
275 auto quant_params = src_tensor.quantParams();
276 if (quant_params == nullptr || quant_params->size() == 0) {
277 return RET_NO_CHANGE;
278 }
279 auto quant_param = quant_params->Get(0);
280 if (quant_param == nullptr) {
281 return RET_NO_CHANGE;
282 }
283 auto dst_data = dst_tensor->data();
284 if (dst_data != nullptr) {
285 MS_LOG(ERROR) << "lite Tensor has already malloced data";
286 return RET_ERROR;
287 }
288 auto ret = dst_tensor->MallocData();
289 if (ret != RET_OK) {
290 MS_LOG(ERROR) << "Malloc tensor data failed";
291 return RET_NULL_PTR;
292 }
293 dst_data = dst_tensor->data();
294 int origin_bit = quant_param->numBits();
295 if (origin_bit < kBitNum8 && origin_bit >= kBitNum1) {
296 UnPackUtil<int8_t, uint8_t>(&src_tensor, origin_bit, dst_data);
297 return RET_OK;
298 } else if (origin_bit < kBitNum16 && origin_bit > kBitNum8) {
299 UnPackUtil<int16_t, uint16_t>(&src_tensor, origin_bit, dst_data);
300 return RET_OK;
301 } else {
302 MS_LOG(ERROR) << "Unsupported bit number: " << origin_bit;
303 return RET_NOT_SUPPORT;
304 }
305 }
306
UnPack(const schema::Tensor & src_tensor,lite::Tensor * dst_tensor)307 int WeightDecoder::UnPack(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
308 STATUS ret = RET_OK;
309 if (src_tensor.enableHuffmanCode()) {
310 ret = WeightDecoder::DecodeHuffmanCode(src_tensor, dst_tensor);
311 if (ret != RET_OK && ret != RET_NO_CHANGE) {
312 MS_LOG(ERROR) << "Decode huffman code failed: " << ret;
313 }
314 } else {
315 ret = WeightDecoder::UnPackToInt(src_tensor, dst_tensor);
316 if (ret != RET_OK && ret != RET_NO_CHANGE) {
317 MS_LOG(ERROR) << "Unpack to int8 failed: " << ret;
318 }
319 }
320 return ret;
321 }
322
DequantNode(OpParameter * op_parameter,const std::vector<Tensor * > & in_tensors,TypeId dst_data_type)323 int WeightDecoder::DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors,
324 TypeId dst_data_type) {
325 if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT) {
326 return RET_OK;
327 }
328 int index = 0;
329 for (auto &tensor : in_tensors) {
330 MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
331 auto channel_first = IsChannelFirst(index++, op_parameter);
332 auto ret = WeightDecoder::DequantTensor(tensor, channel_first, dst_data_type);
333 if (ret != RET_OK && ret != RET_NO_CHANGE) {
334 MS_LOG(DEBUG) << "Dequant tensor failed";
335 return RET_ERROR;
336 }
337 }
338 return RET_OK;
339 }
340
DequantTensor(Tensor * tensor,bool channel_first,TypeId dst_data_type)341 int WeightDecoder::DequantTensor(Tensor *tensor, bool channel_first, TypeId dst_data_type) {
342 MS_ASSERT(tensor != nullptr);
343 if (!tensor->IsConst() ||
344 !(dst_data_type == TypeId::kNumberTypeFloat32 || dst_data_type == TypeId::kNumberTypeFloat16)) {
345 return RET_NO_CHANGE;
346 }
347 bool need_dequant = !tensor->quant_params().empty() && tensor->quant_params().front().inited &&
348 (tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeInt16);
349 if (!need_dequant) {
350 return RET_NO_CHANGE;
351 }
352 auto ret = WeightDecoder::DequantWeight(tensor, channel_first, dst_data_type);
353 if (ret != RET_OK) {
354 MS_LOG(ERROR) << "Dequant data failed: " << ret;
355 return ret;
356 }
357 return RET_OK;
358 }
359 } // namespace mindspore::lite
360