1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "tools/converter/quantizer/fse_encoder.h"
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
#include "mindspore/core/ir/dtype/type_id.h"
#include "src/common/log_adapter.h"
#include "src/common/log_util.h"
#include "include/errorcode.h"
25
26 namespace mindspore::lite::quant {
namespace {
// XOR-ed with a count of leading zeros of a 32-bit value to turn it into the index of the highest set bit.
constexpr int kInt32Mask{31};
// 16: shift that separates the flushed-bit count in the biased `delta_bit_count` values; also bits per symbol
// reserved when sizing the bit stream.
constexpr int kInt16{16};
// 3: extra table-log bits added on top of the symbol-count bits, and the shift/offset of the symbol-spreading step.
constexpr int kFseTableExtendSize{3};
// 2: extra slots in the cumulative-frequency table, and the frequency threshold between the two delta formulas.
constexpr int kFrenqTableExtendSize{2};
// Byte alignment of the serialized frequency / centroid / chunk sections.
constexpr int kAlignSize{8};
// Added before floorf() so scaled frequencies round to nearest.
constexpr float kUpRoundOffSet{0.5F};
}  // namespace
35 // The function gives the index of most import `1` in the binary representation.
36 // e.g. for the number 00100 it gives 2.
fse_count_bits(int32_t x)37 int fse_count_bits(int32_t x) { return __builtin_clz(x) ^ kInt32Mask; }
38
FSECreateStatesForEncoding(uint32_t * frequency,int frequency_count,int table_log,uint32_t * delta_bit_count,int16_t * delta_state,uint16_t * coding_table,uint16_t * symbol_table)39 int FSEEncoder::FSECreateStatesForEncoding(uint32_t *frequency, int frequency_count, int table_log,
40 uint32_t *delta_bit_count, int16_t *delta_state, uint16_t *coding_table,
41 uint16_t *symbol_table) {
42 CHECK_NULL_RETURN(frequency);
43 CHECK_NULL_RETURN(delta_bit_count);
44 CHECK_NULL_RETURN(delta_state);
45 CHECK_NULL_RETURN(symbol_table);
46 CHECK_NULL_RETURN(coding_table);
47 const int tablesize = 1 << table_log;
48 int tablemask = tablesize - 1;
49 int step = ((tablesize >> 1) + (tablesize >> kFseTableExtendSize) + kFseTableExtendSize);
50 int pos = 0;
51 // Separate the same symbols, coding will be better if the same characters are distributed evenly across the table.
52 for (int sym = 0; sym < frequency_count; sym++) {
53 for (uint32_t i = 0; i < frequency[sym]; i++) {
54 symbol_table[pos] = sym;
55 pos = (pos + step) & tablemask;
56 while (pos > tablemask) pos = (pos + step) & tablemask;
57 }
58 }
59 if (pos != 0) {
60 return RET_ERROR;
61 }
62
63 std::vector<uint32_t> cfreqs(frequency_count + kFrenqTableExtendSize);
64 cfreqs[0] = 0;
65 for (int i = 1; i < frequency_count + 1; i++) {
66 cfreqs[i] = cfreqs[i - 1] + frequency[i - 1];
67 }
68 cfreqs[frequency_count + 1] = cfreqs[frequency_count] + 1;
69 for (int i = 0; i < tablesize; i++) {
70 uint16_t sym = symbol_table[i];
71 coding_table[cfreqs[sym]] = tablesize + i;
72 cfreqs[sym] += 1;
73 }
74
75 int total = 0;
76 for (int sym = 0; sym < frequency_count; sym++) {
77 if (frequency[sym] >= kFrenqTableExtendSize) {
78 int max_bits_out = table_log - fse_count_bits(frequency[sym] - 1);
79 int min_state_plus = frequency[sym] << max_bits_out;
80 delta_bit_count[sym] = (max_bits_out << kInt16) - min_state_plus;
81 delta_state[sym] = total - frequency[sym];
82 total += frequency[sym];
83 } else {
84 // we assume minimum `frequency` is 1
85 delta_bit_count[sym] = (table_log << kInt16) - (1 << table_log);
86 delta_state[sym] = total - 1;
87 total++;
88 }
89 }
90 return RET_OK;
91 }
92
ConvertTensor2Quant(schema::TensorT * tensor_input,FSEQuant * quants)93 int ConvertTensor2Quant(schema::TensorT *tensor_input, FSEQuant *quants) {
94 CHECK_NULL_RETURN(tensor_input);
95 CHECK_NULL_RETURN(quants);
96 std::vector<int16_t> dequants;
97 for (size_t i = 0; i < tensor_input->data.size() / sizeof(int16_t); ++i) {
98 auto data = static_cast<int16_t>(reinterpret_cast<int16_t *>(tensor_input->data.data())[i]);
99 dequants.push_back(data);
100 }
101
102 int qmin = *min_element(dequants.begin(), dequants.end());
103 int qmax = *max_element(dequants.begin(), dequants.end());
104 int uncompressed_frequency_count = qmax - qmin + 1;
105 std::vector<int> uncompressed_frequency(uncompressed_frequency_count);
106 for (int i = 0; i < uncompressed_frequency_count; i++) {
107 uncompressed_frequency[i] = 0;
108 }
109 for (size_t i = 0; i < tensor_input->data.size() / sizeof(int16_t); i++) {
110 auto data = static_cast<int16_t>(reinterpret_cast<int16_t *>(tensor_input->data.data())[i]);
111 int q = data - qmin;
112 uncompressed_frequency[q] += 1;
113 }
114
115 std::vector<uint16_t> uncompressed_freqs_to_compressed_sym(uncompressed_frequency_count);
116 int sym = 0;
117 for (int i = 0; i < uncompressed_frequency_count; i++) {
118 if (uncompressed_frequency[i] != 0) {
119 if (sym >= MAX_SYMS) {
120 return 1; // too many symbols!
121 }
122 uncompressed_freqs_to_compressed_sym[i] = sym;
123 quants->frequency[sym] = uncompressed_frequency[i];
124 quants->centroids[sym] =
125 tensor_input->quantParams.front()->varCorr *
126 (tensor_input->quantParams.front()->scale - tensor_input->quantParams.front()->zeroPoint) * (i + qmin) +
127 tensor_input->quantParams.front()->meanCorr;
128 sym++;
129 }
130 }
131 quants->size = sym;
132 quants->symbol_table_count = tensor_input->data.size() / sizeof(int16_t);
133 quants->symbol_table = static_cast<uint16_t *>(malloc(quants->symbol_table_count * sizeof(uint16_t)));
134 if (quants->symbol_table == nullptr) {
135 MS_LOG(ERROR) << "malloc memory failed.";
136 return RET_ERROR;
137 }
138 for (int i = 0; i < quants->symbol_table_count; i++) {
139 auto data = static_cast<int16_t>(reinterpret_cast<int16_t *>(tensor_input->data.data())[i]);
140 int q = data - qmin;
141 sym = uncompressed_freqs_to_compressed_sym[q];
142 quants->symbol_table[i] = sym;
143 }
144 return RET_OK;
145 }
146
Compress(schema::TensorT * tensor_input)147 int FSEEncoder::Compress(schema::TensorT *tensor_input) {
148 MS_ASSERT(tensor_input);
149 int table_log = 0;
150 FSEQuant fse_quant;
151 auto ret = ConvertTensor2Quant(tensor_input, &fse_quant);
152 if (ret != RET_OK) {
153 MS_LOG(ERROR) << "Convert tensor 2 quant failed.";
154 return ret;
155 }
156 ret = NormalizeFrequency(&fse_quant, &table_log);
157 if (ret != RET_OK) {
158 MS_LOG(ERROR) << "Normalize frequency failed.";
159 return ret;
160 }
161 BitStream bs;
162 ret = bs.Create(kInt16 * fse_quant.symbol_table_count);
163 if (ret != RET_OK) {
164 MS_LOG(ERROR) << "BitStream Create failed.";
165 free(fse_quant.symbol_table);
166 return ret;
167 }
168 ret = FSEEncode(&bs, fse_quant.symbol_table, fse_quant.symbol_table_count, fse_quant.frequency, fse_quant.size,
169 table_log);
170 if (ret != RET_OK) {
171 MS_LOG(ERROR) << "FSE Encode failed.";
172 free(fse_quant.symbol_table);
173 return ret;
174 }
175 bs.Flush();
176 // Serializing to out:
177 ret = SerializingToOut(tensor_input, &bs, fse_quant, table_log);
178 if (ret != RET_OK) {
179 MS_LOG(ERROR) << "Serializing To Out failed.";
180 free(fse_quant.symbol_table);
181 return ret;
182 }
183 bs.Free();
184 free(fse_quant.symbol_table);
185 return RET_OK;
186 }
187
FSEEncodeSymbolGetNewState(BitStream * bs,uint16_t sym,uint16_t state,const uint32_t * delta_bit_count,const int16_t * delta_state,uint16_t * coding_table)188 uint16_t FSEEncoder::FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uint16_t state,
189 const uint32_t *delta_bit_count, const int16_t *delta_state,
190 uint16_t *coding_table) {
191 MS_ASSERT(bs != nullptr);
192 MS_ASSERT(delta_bit_count != nullptr);
193 MS_ASSERT(delta_state != nullptr);
194 MS_ASSERT(coding_table != nullptr);
195 // It is to determine the number of bits to flush.
196 // This is basically one of 2 values, n or n+1, depending on state crossing a threshold.
197 uint8_t bits_out = (state + delta_bit_count[sym]) >> kInt16;
198 bs->Push(state, bits_out);
199 // subrangeID = state >> nbBitsOut
200 return coding_table[(state >> bits_out) + delta_state[sym]];
201 }
202
// Returns the index of the largest element of arr[0..arr_count), or the first such index on ties.
// Returns -1 when `arr` is null or `arr_count` <= 0.
// Values are compared as uint32_t. The earlier float accumulator lost precision above 2^24 and could pick the
// wrong index.
int GetMaxIndex(const uint32_t *arr, int arr_count) {
  if (arr == nullptr || arr_count <= 0) {
    return -1;
  }
  const uint32_t *max_it = std::max_element(arr, arr + arr_count);
  return static_cast<int>(max_it - arr);
}
215
NormalizeFrequency(FSEQuant * q,int * table_log)216 int FSEEncoder::NormalizeFrequency(FSEQuant *q, int *table_log) {
217 CHECK_NULL_RETURN(q);
218 CHECK_NULL_RETURN(table_log);
219 // The higher the number, the more accurate we'll be to the shannon entropy,
220 // but also the larger the table, so `+3` is a good compromise.
221 *table_log = std::min(MAX_TABLE_LOG, (fse_count_bits((uint32_t)q->size) + kFseTableExtendSize));
222 const int new_table_size = 1 << (*table_log);
223 int curr_table_size = 0;
224 for (int i = 0; i < q->size; i++) {
225 curr_table_size += q->frequency[i];
226 }
227
228 if (curr_table_size == 0) {
229 MS_LOG(ERROR) << "curr_table_size is 0";
230 return RET_ERROR;
231 }
232 // normalize
233 int updated_table_size = 0;
234 float rat = (static_cast<float>(new_table_size)) / curr_table_size;
235 for (int i = 0; i < q->size; i++) {
236 q->frequency[i] = std::max(1, static_cast<int>(floorf(kUpRoundOffSet + rat * q->frequency[i])));
237 updated_table_size += q->frequency[i];
238 }
239
240 // If the sum of the symbol frequencies is not equal to the power of two (almost always),
241 // then the frequencies need to be normalized-they must be proportionally reduced (or increased) so that the power of
242 // two is obtained in total.
243 // shrink
244 while (updated_table_size > new_table_size) {
245 int max_ix = GetMaxIndex(q->frequency, q->size);
246 if (max_ix < 0 || max_ix > MAX_SYMS) {
247 MS_LOG(ERROR) << "max_ix is invalid.";
248 return RET_ERROR;
249 }
250 q->frequency[max_ix]--;
251 updated_table_size--;
252 }
253
254 // grow
255 if (updated_table_size < new_table_size) {
256 int max_ix = GetMaxIndex(q->frequency, q->size);
257 if (max_ix < 0 || max_ix >= MAX_SYMS) {
258 MS_LOG(ERROR) << "max_ix is invalid.";
259 return RET_ERROR;
260 }
261 q->frequency[max_ix] += new_table_size - updated_table_size;
262 }
263 return RET_OK;
264 }
265
266 // Encoding is therefore just a repeat of this process :
267 // - get Symbol to encode
268 // - look at current state value
269 // - determine nbBits, flush them
270 // - determine sub-Range Id
271 // - look for Symbol position of same Id : you get your next state
FSEEncode(BitStream * bs,const uint16_t * data,int data_count,uint32_t * frequency,int frequency_count,int table_log)272 int FSEEncoder::FSEEncode(BitStream *bs, const uint16_t *data, int data_count, uint32_t *frequency, int frequency_count,
273 int table_log) {
274 MS_ASSERT(bs != nullptr);
275 MS_ASSERT(data != nullptr);
276 MS_ASSERT(frequency != nullptr);
277 int table_size = 1 << table_log;
278 // symbolTT.deltaNbBits stores a value which, when added with state,
279 // makes the result of >> 16 produces either n or n+1, as required.
280 std::vector<uint32_t> delta_number_bits(frequency_count);
281 // symbolTT.deltaFindState provides the offset to find the correct segment into the table.
282 std::vector<int16_t> delta_find_state(frequency_count);
283 // nextStateTable with symbol
284 std::vector<uint16_t> coding_table(table_size);
285 // position with symbol
286 std::vector<uint16_t> symtable(table_size);
287 int ret = FSECreateStatesForEncoding(frequency, frequency_count, table_log, delta_number_bits.data(),
288 delta_find_state.data(), coding_table.data(), symtable.data());
289 if (ret != RET_OK) {
290 MS_LOG(ERROR) << "Create states table for encoding failed.";
291 return ret;
292 }
293 uint16_t state = table_size;
294 // The results of the 1st symbol encoding is not flushed to the bitstream,
295 // It is just to get a valid 1 st state.
296 state = FSEEncodeSymbolGetNewState(bs, data[0], state, delta_number_bits.data(), delta_find_state.data(),
297 coding_table.data());
298 bs->Empty();
299 for (int i = 0; i < data_count; i++) {
300 state = FSEEncodeSymbolGetNewState(bs, data[i], state, delta_number_bits.data(), delta_find_state.data(),
301 coding_table.data());
302 }
303 bs->Push(state - table_size, table_log);
304 return ret;
305 }
306
SerializingToTensor(schema::TensorT * tensor_input,BitStream * bs,const FSEQuant & fse_quant,int table_log,uint8_t * out8,size_t max_size,size_t * out_size)307 int FSEEncoder::SerializingToTensor(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant,
308 int table_log, uint8_t *out8, size_t max_size, size_t *out_size) {
309 MSLITE_CHECK_PTR(tensor_input);
310 MSLITE_CHECK_PTR(bs);
311 MSLITE_CHECK_PTR(out_size);
312 CHECK_MALLOC_RES(out8, RET_ERROR);
313 int offset = 0;
314 *(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)fse_quant.size;
315 offset += sizeof(uint16_t);
316 if (offset + sizeof(uint16_t) > max_size) {
317 MS_LOG(ERROR) << "offset over max size"
318 << " offset:" << offset << " max_size:" << max_size;
319 return RET_ERROR;
320 }
321 *(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)table_log;
322 offset += sizeof(uint16_t);
323 int chunksc = bs->GetCurrChunkIndex() + sizeof(uint16_t);
324 if (offset + sizeof(uint32_t) > max_size) {
325 MS_LOG(ERROR) << "offset over max size"
326 << " offset:" << offset << " max_size:" << max_size;
327 return RET_ERROR;
328 }
329 *(reinterpret_cast<uint32_t *>(&out8[offset])) = (uint32_t)chunksc;
330 offset += sizeof(uint32_t);
331 for (int j = 0; j < fse_quant.size; j++) {
332 if (offset + sizeof(uint32_t) > max_size) {
333 MS_LOG(ERROR) << "offset over max size"
334 << " offset:" << offset << " max_size:" << max_size;
335 return RET_ERROR;
336 }
337 *(reinterpret_cast<uint32_t *>(&out8[offset])) = (uint32_t)fse_quant.frequency[j];
338 offset += sizeof(uint32_t);
339 }
340 while (offset % kAlignSize != 0) {
341 if (offset + sizeof(uint16_t) > max_size) {
342 MS_LOG(ERROR) << "offset over max size"
343 << " offset:" << offset << " max_size:" << max_size;
344 return RET_ERROR;
345 }
346 *(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)0;
347 offset += sizeof(uint16_t);
348 }
349 for (int j = 0; j < fse_quant.size; j++) {
350 if (offset + sizeof(float) > max_size) {
351 MS_LOG(ERROR) << "offset over max size"
352 << " offset:" << offset << " max_size:" << max_size;
353 return RET_ERROR;
354 }
355 *(reinterpret_cast<float *>(&out8[offset])) = static_cast<float>(fse_quant.centroids[j]);
356 offset += sizeof(float);
357 }
358 while (offset % kAlignSize != 0) {
359 if (offset + sizeof(uint16_t) > max_size) {
360 MS_LOG(ERROR) << "offset over max size"
361 << " offset:" << offset << " max_size:" << max_size;
362 return RET_ERROR;
363 }
364 *(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)0;
365 offset += sizeof(uint16_t);
366 }
367 for (int j = 0; j < bs->GetCurrChunkIndex() + 1; j++) {
368 if (offset + sizeof(uint64_t) > max_size) {
369 MS_LOG(ERROR) << "offset over max size"
370 << " offset:" << offset << " max_size:" << max_size;
371 return RET_ERROR;
372 }
373 *(reinterpret_cast<uint64_t *>(&out8[offset])) = (uint64_t)bs->GetChunks()[j];
374 offset += sizeof(uint64_t);
375 }
376 if (offset + sizeof(uint64_t) > max_size) {
377 MS_LOG(ERROR) << "offset over max size"
378 << " offset:" << offset << " max_size:" << max_size;
379 return RET_ERROR;
380 }
381 *(reinterpret_cast<uint64_t *>(&out8[offset])) = (uint64_t)bs->GetCurrChunk();
382 offset += sizeof(uint64_t);
383 if (offset + sizeof(uint8_t) > max_size) {
384 MS_LOG(ERROR) << "offset over max size"
385 << " offset:" << offset << " max_size:" << max_size;
386 return RET_ERROR;
387 }
388 *(reinterpret_cast<uint8_t *>(&out8[offset])) = (uint8_t)bs->GetCurrBitCount();
389 offset += sizeof(uint8_t);
390 if (static_cast<int>(offset) > static_cast<int>(tensor_input->data.size())) {
391 MS_LOG(ERROR) << "Too many symbol.";
392 return RET_ERROR;
393 }
394 *out_size = offset;
395 return RET_OK;
396 }
397
SerializingToOut(schema::TensorT * tensor_input,BitStream * bs,const FSEQuant & fse_quant,int table_log)398 int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant,
399 int table_log) {
400 MSLITE_CHECK_PTR(tensor_input);
401 MSLITE_CHECK_PTR(bs);
402 const int extend_size = 2;
403 auto max_size = tensor_input->data.size() * extend_size;
404 auto *out8 = static_cast<uint8_t *>(malloc(max_size));
405 MSLITE_CHECK_PTR(out8);
406 size_t out_size = 0;
407 auto ret = SerializingToTensor(tensor_input, bs, fse_quant, table_log, out8, max_size, &out_size);
408 if (ret != RET_OK) {
409 MS_LOG(ERROR) << "Store data to tensor failed.";
410 free(out8);
411 return ret;
412 }
413 tensor_input->data.resize(out_size);
414 MSLITE_CHECK_PTR(tensor_input->data.data());
415 if (memcpy_s(tensor_input->data.data(), out_size, out8, out_size) != EOK) {
416 MS_LOG(ERROR) << "memcpy failed.";
417 free(out8);
418 return RET_ERROR;
419 }
420 tensor_input->quantParams.clear();
421 tensor_input->weightQunatCompressType = schema::WeightQunatCompressType_FSE;
422 tensor_input->dataType = TypeId::kNumberTypeFloat32;
423 free(out8);
424 return RET_OK;
425 }
426 } // namespace mindspore::lite::quant
427