• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <memory>
18 #include <vector>
19 #include "tools/converter/quantizer/fse_decoder.h"
20 #include "include/errorcode.h"
21 #include "src/common/log_adapter.h"
22 #include "src/common/log_util.h"
23 #include "nnacl/op_base.h"
24 
25 namespace mindspore::lite::quant {
namespace {
// Shift amount and additive term of the symbol-spread step
// ((table_size >> 1) + (table_size >> kTableExtend) + kTableExtend); also reused in this file as the
// shift that rounds a byte offset to a multiple of 8 (1 << 3).
constexpr size_t kTableExtend{3};
// Added to a byte offset before shifting so that the 8-byte rounding goes upwards.
constexpr size_t kAlignOffset{7};
// Number of bytes skipped after the 1-byte CurrBitCount so that ChunkEndsCount starts on a 32-bit boundary.
constexpr size_t kThreeBytes{3};
}  // namespace
FSECreateStatesForDecoding(const uint32_t * symbol_frequency,int symbol_frequency_count,size_t table_log,uint16_t * new_state_baseline,uint8_t * bit_count,uint16_t * symbol_table)31 int FSEDecoder::FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int symbol_frequency_count,
32                                            size_t table_log, uint16_t *new_state_baseline, uint8_t *bit_count,
33                                            uint16_t *symbol_table) {
34   CHECK_NULL_RETURN(symbol_frequency);
35   CHECK_NULL_RETURN(new_state_baseline);
36   CHECK_NULL_RETURN(bit_count);
37   CHECK_NULL_RETURN(symbol_table);
38   const size_t table_size = 1u << table_log;
39   const size_t table_mask = table_size - 1;
40   size_t step = ((table_size >> 1) + (table_size >> kTableExtend) + kTableExtend);
41   size_t pos = 0;
42   for (int sym = 0; sym < symbol_frequency_count; sym++) {
43     for (uint32_t i = 0; i < symbol_frequency[sym]; i++) {
44       symbol_table[pos] = sym;
45       pos = (pos + step) & table_mask;
46       while (pos > table_mask) {
47         pos = (pos + step) & table_mask;
48       }
49     }
50   }
51   if (pos != 0) {
52     MS_LOG(ERROR) << "pos must equal 0.";
53     return RET_ERROR;
54   }
55   // defensive copy to not mutate frequency:
56   std::vector<uint32_t> frequency(symbol_frequency, symbol_frequency + symbol_frequency_count);
57 
58   for (size_t i = 0; i < table_size; i++) {
59     uint16_t sym = symbol_table[i];
60     uint32_t x = frequency[sym];
61     frequency[sym] += 1;
62     MS_CHECK_GE(table_log, FSEBitStream::CountBits(x), RET_ERROR);
63     bit_count[i] = static_cast<uint8_t>(table_log - FSEBitStream::CountBits(x));
64     new_state_baseline[i] = (x << bit_count[i]) - table_size;
65   }
66   return RET_OK;
67 }
68 
DecodeBuffer(int8_t * buffer,size_t data_size,FSEBuffer * fse_buffer)69 int FSEDecoder::DecodeBuffer(int8_t *buffer, size_t data_size, FSEBuffer *fse_buffer) {
70   CHECK_NULL_RETURN(buffer);
71   CHECK_NULL_RETURN(fse_buffer);
72   if (data_size < sizeof(uint16_t)) {
73     MS_LOG(ERROR) << "data_size is invalid.";
74     return RET_ERROR;
75   }
76   size_t i = 0;
77   // 16bit for frequency_count
78   fse_buffer->frequency_count = *(reinterpret_cast<uint16_t *>(buffer + i));
79   i += sizeof(uint16_t);
80   if (i > data_size) {
81     MS_LOG(ERROR) << "index over total size"
82                   << " index:" << i << " total size:" << data_size;
83     return RET_ERROR;
84   }
85   // 16bit for table_log
86   fse_buffer->table_log = *(reinterpret_cast<uint16_t *>(buffer + i));
87   i += sizeof(uint16_t);
88   if (i > data_size) {
89     MS_LOG(ERROR) << "index over total size"
90                   << " index:" << i << " total size:" << data_size;
91     return RET_ERROR;
92   }
93   // 32bit for ChunkCount
94   fse_buffer->chunk_count = *(reinterpret_cast<uint32_t *>(buffer + i));
95   const size_t offset = 2;
96   // 32bit for CurrChunkIndex
97   fse_buffer->curr_chunk_index = fse_buffer->chunk_count - offset;
98   i += sizeof(uint32_t);
99   if (i > data_size) {
100     MS_LOG(ERROR) << "index over total size"
101                   << " index:" << i << " total size:" << data_size;
102     return RET_ERROR;
103   }
104   // 32bit * frequency_count for frequency
105   fse_buffer->frequency = reinterpret_cast<uint32_t *>(buffer + i);
106   i += fse_buffer->frequency_count * sizeof(uint32_t);
107   // Used for 8-byte(64bit) alignment
108   i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
109   if (i > data_size) {
110     MS_LOG(ERROR) << "index over total size"
111                   << " index:" << i << " total size:" << data_size;
112     return RET_ERROR;
113   }
114   // 32bit * frequency_count for centroids
115   fse_buffer->centroids = reinterpret_cast<void *>(buffer + i);
116   fse_buffer->centroid_size = fse_buffer->frequency_count * sizeof(float);
117   i += fse_buffer->centroid_size;
118   // Used for 8-byte(64bit) alignment
119   i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
120   if (i > data_size) {
121     MS_LOG(ERROR) << "index over total size"
122                   << " index:" << i << " total size:" << data_size;
123     return RET_ERROR;
124   }
125   // 64bit * bs_.GetCurrChunkIndex() + 1 for Chunks.
126   fse_buffer->chunks = reinterpret_cast<uint64_t *>(buffer + i);
127   fse_buffer->chunk_size = (fse_buffer->curr_chunk_index + 1) * sizeof(uint64_t);
128   i += fse_buffer->chunk_size;
129   if (i > data_size) {
130     MS_LOG(ERROR) << "index over total size"
131                   << " index:" << i << " total size:" << data_size;
132     return RET_ERROR;
133   }
134   // 64bit for CurrChunk
135   fse_buffer->curr_chunk = *(reinterpret_cast<uint64_t *>(buffer + i));
136   i += sizeof(uint64_t);
137   if (i > data_size) {
138     MS_LOG(ERROR) << "index over total size"
139                   << " index:" << i << " total size:" << data_size;
140     return RET_ERROR;
141   }
142   // 8bit for CurrBitCount
143   fse_buffer->curr_bit_count = *(reinterpret_cast<uint8_t *>(buffer + i));
144   i += sizeof(uint8_t);
145 
146   if (i < data_size) {                   // There is more data after what was extracted
147     i += kThreeBytes * sizeof(uint8_t);  // Align to 32 bit for ChunkEndsCount
148     if (i > data_size) {
149       MS_LOG(ERROR) << " index:" << i << " is over total size:" << data_size;
150       return RET_ERROR;
151     }
152     uint32_t chunk_ends_count = *(reinterpret_cast<uint32_t *>(buffer + i));
153     if ((i + sizeof(uint32_t) + chunk_ends_count * sizeof(uint64_t)) > data_size) {
154       MS_LOG(ERROR) << " index:" << i << " is over total size:" << data_size;
155       return RET_ERROR;
156     }
157     fse_buffer->chunk_ends_count = chunk_ends_count;
158     i += sizeof(uint32_t);
159     fse_buffer->chunk_ends = reinterpret_cast<uint64_t *>(buffer + i);
160   }
161   return RET_OK;
162 }
163 
DeCompress(const SchemaTensorWrapper & src_tensor,Tensor * dst_tensor,schema::WeightQuantCompressType compress_type)164 int FSEDecoder::DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor,
165                            schema::WeightQuantCompressType compress_type) {
166   CHECK_NULL_RETURN(src_tensor.handler());
167   CHECK_NULL_RETURN(src_tensor.data());
168   CHECK_NULL_RETURN(dst_tensor);
169   if (dst_tensor->MutableData() == nullptr) {
170     MS_LOG(ERROR) << "tensor data is nullptr.";
171     return RET_ERROR;
172   }
173   auto total_size = src_tensor.length();
174   int out_sz = dst_tensor->ElementsNum();
175   MS_CHECK_GT(out_sz, 0, RET_ERROR);
176   // deserialize from `data`:
177   FSEBitStream bs;
178 
179   size_t i = 0;
180   auto data8 = reinterpret_cast<int8_t *>(const_cast<void *>(src_tensor.data()));
181   CHECK_NULL_RETURN(data8);
182   // 16bit for frequency_count
183   uint16_t frequency_count = *(reinterpret_cast<uint16_t *>(&data8[i]));
184   i += sizeof(uint16_t);
185   if (i > total_size) {
186     MS_LOG(ERROR) << "index over total size"
187                   << " index:" << i << " total size:" << total_size;
188     return RET_ERROR;
189   }
190   // 16bit for table_log
191   size_t table_log = *(reinterpret_cast<uint16_t *>(&data8[i]));
192   i += sizeof(uint16_t);
193   if (i > total_size) {
194     MS_LOG(ERROR) << "index over total size"
195                   << " index:" << i << " total size:" << total_size;
196     return RET_ERROR;
197   }
198   // 32bit for ChunkCount
199   bs.SetChunkCount(*(reinterpret_cast<uint32_t *>(&data8[i])));
200   const int offset = 2;
201   // 32bit for CurrChunkIndex
202   bs.SetCurrChunkIndex(bs.GetChunkCount() - offset);
203   i += sizeof(uint32_t);
204   if (i > total_size) {
205     MS_LOG(ERROR) << "index over total size"
206                   << " index:" << i << " total size:" << total_size;
207     return RET_ERROR;
208   }
209   // 32bit * frequency_count for frequency
210   auto *frequency = reinterpret_cast<uint32_t *>(&data8[i]);
211   i += frequency_count * sizeof(uint32_t);
212   // Used for 8-byte(64bit) alignment
213   i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
214   if (i > total_size) {
215     MS_LOG(ERROR) << "index over total size"
216                   << " index:" << i << " total size:" << total_size;
217     return RET_ERROR;
218   }
219   // 32bit * frequency_count for centroids
220   auto centroids = reinterpret_cast<void *>(&data8[i]);
221   i += frequency_count * sizeof(float);
222   // Used for 8-byte(64bit) alignment
223   i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
224   if (i > total_size) {
225     MS_LOG(ERROR) << "index over total size"
226                   << " index:" << i << " total size:" << total_size;
227     return RET_ERROR;
228   }
229   // 64bit * bs.GetCurrChunkIndex() + 1 for Chunks.
230   bs.SetChunks(reinterpret_cast<uint64_t *>(&data8[i]));
231   i += (bs.GetCurrChunkIndex() + 1) * sizeof(uint64_t);
232   if (i > total_size) {
233     MS_LOG(ERROR) << "index over total size"
234                   << " index:" << i << " total size:" << total_size;
235     return RET_ERROR;
236   }
237   // 64bit for CurrChunk
238   bs.SetCurrChunk(*(reinterpret_cast<uint64_t *>(&data8[i])));
239   i += sizeof(uint64_t);
240   if (i > total_size) {
241     MS_LOG(ERROR) << "index over total size"
242                   << " index:" << i << " total size:" << total_size;
243     return RET_ERROR;
244   }
245   // 8bit for CurrBitCount
246   bs.SetCurrBitCount(*(reinterpret_cast<uint8_t *>(&data8[i])));
247   int ret;
248   if (compress_type == schema::WeightQuantCompressType_FSE) {
249     ret = FSEDecode<float, float>(&bs, static_cast<float *>(dst_tensor->data()), out_sz, frequency, frequency_count,
250                                   static_cast<float *>(centroids), table_log);
251   } else {  // WeightQuantCompressType_FSE_INT
252     if (src_tensor.handler()->dataType() == kNumberTypeInt8) {
253       ret = FSEDecode<int, int8_t>(&bs, static_cast<int8_t *>(dst_tensor->data()), out_sz, frequency, frequency_count,
254                                    static_cast<int *>(centroids), table_log);
255     } else {  // kNumberTypeInt16
256       ret = FSEDecode<int, int16_t>(&bs, static_cast<int16_t *>(dst_tensor->data()), out_sz, frequency, frequency_count,
257                                     static_cast<int *>(centroids), table_log);
258     }
259   }
260   if (ret != RET_OK) {
261     MS_LOG(ERROR) << "FSE Decode failed.";
262     return RET_ERROR;
263   }
264   return RET_OK;
265 }
266 }  // namespace mindspore::lite::quant
267