• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "tools/converter/quantizer/fse_encoder.h"
18 #include <cstdint>
19 #include <algorithm>
20 #include <cmath>
21 #include "mindspore/core/ir/dtype/type_id.h"
22 #include "src/common/log_adapter.h"
23 #include "src/common/log_util.h"
24 #include "include/errorcode.h"
25 
26 namespace mindspore::lite::quant {
namespace {
constexpr int kInt32Mask = 31;               // index of the highest bit of a 32-bit integer
constexpr int kInt16 = 16;                   // bit width of a 16-bit field / shift amount
constexpr int kFseTableExtendSize = 3;       // extra table_log bits and spreading-step offset
constexpr int kFrenqTableExtendSize = 2;     // extra entries in the cumulative-frequency table
constexpr int kAlignSize = 8;                // byte alignment of serialized sections
constexpr float kUpRoundOffSet = 0.5;        // offset that turns floorf into round-half-up
}  // namespace
// Returns the index of the most significant `1` in the binary representation of `x`,
// i.e. floor(log2(x)) for x > 0. For example, for 0b00100 it returns 2.
// Returns -1 for x == 0, for which __builtin_clz is undefined.
// Negative inputs are treated as their 32-bit unsigned pattern and yield 31.
int fse_count_bits(int32_t x) {
  if (x == 0) {
    return -1;
  }
  return __builtin_clz(static_cast<uint32_t>(x)) ^ kInt32Mask;
}
38 
FSECreateStatesForEncoding(uint32_t * frequency,int frequency_count,int table_log,uint32_t * delta_bit_count,int16_t * delta_state,uint16_t * coding_table,uint16_t * symbol_table)39 int FSEEncoder::FSECreateStatesForEncoding(uint32_t *frequency, int frequency_count, int table_log,
40                                            uint32_t *delta_bit_count, int16_t *delta_state, uint16_t *coding_table,
41                                            uint16_t *symbol_table) {
42   CHECK_NULL_RETURN(frequency);
43   CHECK_NULL_RETURN(delta_bit_count);
44   CHECK_NULL_RETURN(delta_state);
45   CHECK_NULL_RETURN(symbol_table);
46   CHECK_NULL_RETURN(coding_table);
47   const int tablesize = 1 << table_log;
48   int tablemask = tablesize - 1;
49   int step = ((tablesize >> 1) + (tablesize >> kFseTableExtendSize) + kFseTableExtendSize);
50   int pos = 0;
51   // Separate the same symbols, coding will be better if the same characters are distributed evenly across the table.
52   for (int sym = 0; sym < frequency_count; sym++) {
53     for (uint32_t i = 0; i < frequency[sym]; i++) {
54       symbol_table[pos] = sym;
55       pos = (pos + step) & tablemask;
56       while (pos > tablemask) pos = (pos + step) & tablemask;
57     }
58   }
59   if (pos != 0) {
60     return RET_ERROR;
61   }
62 
63   std::vector<uint32_t> cfreqs(frequency_count + kFrenqTableExtendSize);
64   cfreqs[0] = 0;
65   for (int i = 1; i < frequency_count + 1; i++) {
66     cfreqs[i] = cfreqs[i - 1] + frequency[i - 1];
67   }
68   cfreqs[frequency_count + 1] = cfreqs[frequency_count] + 1;
69   for (int i = 0; i < tablesize; i++) {
70     uint16_t sym = symbol_table[i];
71     coding_table[cfreqs[sym]] = tablesize + i;
72     cfreqs[sym] += 1;
73   }
74 
75   int total = 0;
76   for (int sym = 0; sym < frequency_count; sym++) {
77     if (frequency[sym] >= kFrenqTableExtendSize) {
78       int max_bits_out = table_log - fse_count_bits(frequency[sym] - 1);
79       int min_state_plus = frequency[sym] << max_bits_out;
80       delta_bit_count[sym] = (max_bits_out << kInt16) - min_state_plus;
81       delta_state[sym] = total - frequency[sym];
82       total += frequency[sym];
83     } else {
84       // we assume minimum `frequency` is 1
85       delta_bit_count[sym] = (table_log << kInt16) - (1 << table_log);
86       delta_state[sym] = total - 1;
87       total++;
88     }
89   }
90   return RET_OK;
91 }
92 
ConvertTensor2Quant(schema::TensorT * tensor_input,FSEQuant * quants)93 int ConvertTensor2Quant(schema::TensorT *tensor_input, FSEQuant *quants) {
94   CHECK_NULL_RETURN(tensor_input);
95   CHECK_NULL_RETURN(quants);
96   std::vector<int16_t> dequants;
97   for (size_t i = 0; i < tensor_input->data.size() / sizeof(int16_t); ++i) {
98     auto data = static_cast<int16_t>(reinterpret_cast<int16_t *>(tensor_input->data.data())[i]);
99     dequants.push_back(data);
100   }
101 
102   int qmin = *min_element(dequants.begin(), dequants.end());
103   int qmax = *max_element(dequants.begin(), dequants.end());
104   int uncompressed_frequency_count = qmax - qmin + 1;
105   std::vector<int> uncompressed_frequency(uncompressed_frequency_count);
106   for (int i = 0; i < uncompressed_frequency_count; i++) {
107     uncompressed_frequency[i] = 0;
108   }
109   for (size_t i = 0; i < tensor_input->data.size() / sizeof(int16_t); i++) {
110     auto data = static_cast<int16_t>(reinterpret_cast<int16_t *>(tensor_input->data.data())[i]);
111     int q = data - qmin;
112     uncompressed_frequency[q] += 1;
113   }
114 
115   std::vector<uint16_t> uncompressed_freqs_to_compressed_sym(uncompressed_frequency_count);
116   int sym = 0;
117   for (int i = 0; i < uncompressed_frequency_count; i++) {
118     if (uncompressed_frequency[i] != 0) {
119       if (sym >= MAX_SYMS) {
120         return 1;  // too many symbols!
121       }
122       uncompressed_freqs_to_compressed_sym[i] = sym;
123       quants->frequency[sym] = uncompressed_frequency[i];
124       quants->centroids[sym] =
125         tensor_input->quantParams.front()->varCorr *
126           (tensor_input->quantParams.front()->scale - tensor_input->quantParams.front()->zeroPoint) * (i + qmin) +
127         tensor_input->quantParams.front()->meanCorr;
128       sym++;
129     }
130   }
131   quants->size = sym;
132   quants->symbol_table_count = tensor_input->data.size() / sizeof(int16_t);
133   quants->symbol_table = static_cast<uint16_t *>(malloc(quants->symbol_table_count * sizeof(uint16_t)));
134   if (quants->symbol_table == nullptr) {
135     MS_LOG(ERROR) << "malloc memory failed.";
136     return RET_ERROR;
137   }
138   for (int i = 0; i < quants->symbol_table_count; i++) {
139     auto data = static_cast<int16_t>(reinterpret_cast<int16_t *>(tensor_input->data.data())[i]);
140     int q = data - qmin;
141     sym = uncompressed_freqs_to_compressed_sym[q];
142     quants->symbol_table[i] = sym;
143   }
144   return RET_OK;
145 }
146 
Compress(schema::TensorT * tensor_input)147 int FSEEncoder::Compress(schema::TensorT *tensor_input) {
148   MS_ASSERT(tensor_input);
149   int table_log = 0;
150   FSEQuant fse_quant;
151   auto ret = ConvertTensor2Quant(tensor_input, &fse_quant);
152   if (ret != RET_OK) {
153     MS_LOG(ERROR) << "Convert tensor 2 quant failed.";
154     return ret;
155   }
156   ret = NormalizeFrequency(&fse_quant, &table_log);
157   if (ret != RET_OK) {
158     MS_LOG(ERROR) << "Normalize frequency failed.";
159     return ret;
160   }
161   BitStream bs;
162   ret = bs.Create(kInt16 * fse_quant.symbol_table_count);
163   if (ret != RET_OK) {
164     MS_LOG(ERROR) << "BitStream Create failed.";
165     free(fse_quant.symbol_table);
166     return ret;
167   }
168   ret = FSEEncode(&bs, fse_quant.symbol_table, fse_quant.symbol_table_count, fse_quant.frequency, fse_quant.size,
169                   table_log);
170   if (ret != RET_OK) {
171     MS_LOG(ERROR) << "FSE Encode failed.";
172     free(fse_quant.symbol_table);
173     return ret;
174   }
175   bs.Flush();
176   // Serializing to out:
177   ret = SerializingToOut(tensor_input, &bs, fse_quant, table_log);
178   if (ret != RET_OK) {
179     MS_LOG(ERROR) << "Serializing To Out failed.";
180     free(fse_quant.symbol_table);
181     return ret;
182   }
183   bs.Free();
184   free(fse_quant.symbol_table);
185   return RET_OK;
186 }
187 
FSEEncodeSymbolGetNewState(BitStream * bs,uint16_t sym,uint16_t state,const uint32_t * delta_bit_count,const int16_t * delta_state,uint16_t * coding_table)188 uint16_t FSEEncoder::FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uint16_t state,
189                                                 const uint32_t *delta_bit_count, const int16_t *delta_state,
190                                                 uint16_t *coding_table) {
191   MS_ASSERT(bs != nullptr);
192   MS_ASSERT(delta_bit_count != nullptr);
193   MS_ASSERT(delta_state != nullptr);
194   MS_ASSERT(coding_table != nullptr);
195   // It is to determine the number of bits to flush.
196   // This is basically one of 2 values, n or n+1, depending on state crossing a threshold.
197   uint8_t bits_out = (state + delta_bit_count[sym]) >> kInt16;
198   bs->Push(state, bits_out);
199   // subrangeID = state >> nbBitsOut
200   return coding_table[(state >> bits_out) + delta_state[sym]];
201 }
202 
// Returns the index of the largest value in arr[0, arr_count); ties resolve to the first occurrence.
// Returns -1 when `arr` is null or `arr_count` <= 0.
int GetMaxIndex(const uint32_t *arr, int arr_count) {
  if (arr == nullptr || arr_count <= 0) {
    return -1;
  }
  // Compare as uint32_t rather than float so values above 2^24 stay distinguishable.
  return static_cast<int>(std::max_element(arr, arr + arr_count) - arr);
}
215 
NormalizeFrequency(FSEQuant * q,int * table_log)216 int FSEEncoder::NormalizeFrequency(FSEQuant *q, int *table_log) {
217   CHECK_NULL_RETURN(q);
218   CHECK_NULL_RETURN(table_log);
219   // The higher the number, the more accurate we'll be to the shannon entropy,
220   // but also the larger the table, so `+3` is a good compromise.
221   *table_log = std::min(MAX_TABLE_LOG, (fse_count_bits((uint32_t)q->size) + kFseTableExtendSize));
222   const int new_table_size = 1 << (*table_log);
223   int curr_table_size = 0;
224   for (int i = 0; i < q->size; i++) {
225     curr_table_size += q->frequency[i];
226   }
227 
228   if (curr_table_size == 0) {
229     MS_LOG(ERROR) << "curr_table_size is 0";
230     return RET_ERROR;
231   }
232   // normalize
233   int updated_table_size = 0;
234   float rat = (static_cast<float>(new_table_size)) / curr_table_size;
235   for (int i = 0; i < q->size; i++) {
236     q->frequency[i] = std::max(1, static_cast<int>(floorf(kUpRoundOffSet + rat * q->frequency[i])));
237     updated_table_size += q->frequency[i];
238   }
239 
240   // If the sum of the symbol frequencies is not equal to the power of two (almost always),
241   // then the frequencies need to be normalized-they must be proportionally reduced (or increased) so that the power of
242   // two is obtained in total.
243   // shrink
244   while (updated_table_size > new_table_size) {
245     int max_ix = GetMaxIndex(q->frequency, q->size);
246     if (max_ix < 0 || max_ix > MAX_SYMS) {
247       MS_LOG(ERROR) << "max_ix is invalid.";
248       return RET_ERROR;
249     }
250     q->frequency[max_ix]--;
251     updated_table_size--;
252   }
253 
254   // grow
255   if (updated_table_size < new_table_size) {
256     int max_ix = GetMaxIndex(q->frequency, q->size);
257     if (max_ix < 0 || max_ix >= MAX_SYMS) {
258       MS_LOG(ERROR) << "max_ix is invalid.";
259       return RET_ERROR;
260     }
261     q->frequency[max_ix] += new_table_size - updated_table_size;
262   }
263   return RET_OK;
264 }
265 
266 // Encoding is therefore just a repeat of this process :
267 // - get Symbol to encode
268 // - look at current state value
269 // - determine nbBits, flush them
270 // - determine sub-Range Id
271 // - look for Symbol position of same Id : you get your next state
FSEEncode(BitStream * bs,const uint16_t * data,int data_count,uint32_t * frequency,int frequency_count,int table_log)272 int FSEEncoder::FSEEncode(BitStream *bs, const uint16_t *data, int data_count, uint32_t *frequency, int frequency_count,
273                           int table_log) {
274   MS_ASSERT(bs != nullptr);
275   MS_ASSERT(data != nullptr);
276   MS_ASSERT(frequency != nullptr);
277   int table_size = 1 << table_log;
278   // symbolTT.deltaNbBits stores a value which, when added with state,
279   // makes the result of >> 16 produces either n or n+1, as required.
280   std::vector<uint32_t> delta_number_bits(frequency_count);
281   // symbolTT.deltaFindState provides the offset to find the correct segment into the table.
282   std::vector<int16_t> delta_find_state(frequency_count);
283   // nextStateTable with symbol
284   std::vector<uint16_t> coding_table(table_size);
285   // position with symbol
286   std::vector<uint16_t> symtable(table_size);
287   int ret = FSECreateStatesForEncoding(frequency, frequency_count, table_log, delta_number_bits.data(),
288                                        delta_find_state.data(), coding_table.data(), symtable.data());
289   if (ret != RET_OK) {
290     MS_LOG(ERROR) << "Create states table for encoding failed.";
291     return ret;
292   }
293   uint16_t state = table_size;
294   // The results of the 1st symbol encoding is not flushed to the bitstream,
295   // It is just to get a valid 1 st state.
296   state = FSEEncodeSymbolGetNewState(bs, data[0], state, delta_number_bits.data(), delta_find_state.data(),
297                                      coding_table.data());
298   bs->Empty();
299   for (int i = 0; i < data_count; i++) {
300     state = FSEEncodeSymbolGetNewState(bs, data[i], state, delta_number_bits.data(), delta_find_state.data(),
301                                        coding_table.data());
302   }
303   bs->Push(state - table_size, table_log);
304   return ret;
305 }
306 
SerializingToTensor(schema::TensorT * tensor_input,BitStream * bs,const FSEQuant & fse_quant,int table_log,uint8_t * out8,size_t max_size,size_t * out_size)307 int FSEEncoder::SerializingToTensor(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant,
308                                     int table_log, uint8_t *out8, size_t max_size, size_t *out_size) {
309   MSLITE_CHECK_PTR(tensor_input);
310   MSLITE_CHECK_PTR(bs);
311   MSLITE_CHECK_PTR(out_size);
312   CHECK_MALLOC_RES(out8, RET_ERROR);
313   int offset = 0;
314   *(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)fse_quant.size;
315   offset += sizeof(uint16_t);
316   if (offset + sizeof(uint16_t) > max_size) {
317     MS_LOG(ERROR) << "offset over max size"
318                   << " offset:" << offset << " max_size:" << max_size;
319     return RET_ERROR;
320   }
321   *(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)table_log;
322   offset += sizeof(uint16_t);
323   int chunksc = bs->GetCurrChunkIndex() + sizeof(uint16_t);
324   if (offset + sizeof(uint32_t) > max_size) {
325     MS_LOG(ERROR) << "offset over max size"
326                   << " offset:" << offset << " max_size:" << max_size;
327     return RET_ERROR;
328   }
329   *(reinterpret_cast<uint32_t *>(&out8[offset])) = (uint32_t)chunksc;
330   offset += sizeof(uint32_t);
331   for (int j = 0; j < fse_quant.size; j++) {
332     if (offset + sizeof(uint32_t) > max_size) {
333       MS_LOG(ERROR) << "offset over max size"
334                     << " offset:" << offset << " max_size:" << max_size;
335       return RET_ERROR;
336     }
337     *(reinterpret_cast<uint32_t *>(&out8[offset])) = (uint32_t)fse_quant.frequency[j];
338     offset += sizeof(uint32_t);
339   }
340   while (offset % kAlignSize != 0) {
341     if (offset + sizeof(uint16_t) > max_size) {
342       MS_LOG(ERROR) << "offset over max size"
343                     << " offset:" << offset << " max_size:" << max_size;
344       return RET_ERROR;
345     }
346     *(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)0;
347     offset += sizeof(uint16_t);
348   }
349   for (int j = 0; j < fse_quant.size; j++) {
350     if (offset + sizeof(float) > max_size) {
351       MS_LOG(ERROR) << "offset over max size"
352                     << " offset:" << offset << " max_size:" << max_size;
353       return RET_ERROR;
354     }
355     *(reinterpret_cast<float *>(&out8[offset])) = static_cast<float>(fse_quant.centroids[j]);
356     offset += sizeof(float);
357   }
358   while (offset % kAlignSize != 0) {
359     if (offset + sizeof(uint16_t) > max_size) {
360       MS_LOG(ERROR) << "offset over max size"
361                     << " offset:" << offset << " max_size:" << max_size;
362       return RET_ERROR;
363     }
364     *(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)0;
365     offset += sizeof(uint16_t);
366   }
367   for (int j = 0; j < bs->GetCurrChunkIndex() + 1; j++) {
368     if (offset + sizeof(uint64_t) > max_size) {
369       MS_LOG(ERROR) << "offset over max size"
370                     << " offset:" << offset << " max_size:" << max_size;
371       return RET_ERROR;
372     }
373     *(reinterpret_cast<uint64_t *>(&out8[offset])) = (uint64_t)bs->GetChunks()[j];
374     offset += sizeof(uint64_t);
375   }
376   if (offset + sizeof(uint64_t) > max_size) {
377     MS_LOG(ERROR) << "offset over max size"
378                   << " offset:" << offset << " max_size:" << max_size;
379     return RET_ERROR;
380   }
381   *(reinterpret_cast<uint64_t *>(&out8[offset])) = (uint64_t)bs->GetCurrChunk();
382   offset += sizeof(uint64_t);
383   if (offset + sizeof(uint8_t) > max_size) {
384     MS_LOG(ERROR) << "offset over max size"
385                   << " offset:" << offset << " max_size:" << max_size;
386     return RET_ERROR;
387   }
388   *(reinterpret_cast<uint8_t *>(&out8[offset])) = (uint8_t)bs->GetCurrBitCount();
389   offset += sizeof(uint8_t);
390   if (static_cast<int>(offset) > static_cast<int>(tensor_input->data.size())) {
391     MS_LOG(ERROR) << "Too many symbol.";
392     return RET_ERROR;
393   }
394   *out_size = offset;
395   return RET_OK;
396 }
397 
SerializingToOut(schema::TensorT * tensor_input,BitStream * bs,const FSEQuant & fse_quant,int table_log)398 int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant,
399                                  int table_log) {
400   MSLITE_CHECK_PTR(tensor_input);
401   MSLITE_CHECK_PTR(bs);
402   const int extend_size = 2;
403   auto max_size = tensor_input->data.size() * extend_size;
404   auto *out8 = static_cast<uint8_t *>(malloc(max_size));
405   MSLITE_CHECK_PTR(out8);
406   size_t out_size = 0;
407   auto ret = SerializingToTensor(tensor_input, bs, fse_quant, table_log, out8, max_size, &out_size);
408   if (ret != RET_OK) {
409     MS_LOG(ERROR) << "Store data to tensor failed.";
410     free(out8);
411     return ret;
412   }
413   tensor_input->data.resize(out_size);
414   MSLITE_CHECK_PTR(tensor_input->data.data());
415   if (memcpy_s(tensor_input->data.data(), out_size, out8, out_size) != EOK) {
416     MS_LOG(ERROR) << "memcpy failed.";
417     free(out8);
418     return RET_ERROR;
419   }
420   tensor_input->quantParams.clear();
421   tensor_input->weightQunatCompressType = schema::WeightQunatCompressType_FSE;
422   tensor_input->dataType = TypeId::kNumberTypeFloat32;
423   free(out8);
424   return RET_OK;
425 }
426 }  // namespace mindspore::lite::quant
427