1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <memory>
18 #include <vector>
19 #include "tools/converter/quantizer/fse_decoder.h"
20 #include "include/errorcode.h"
21 #include "src/common/log_adapter.h"
22 #include "src/common/log_util.h"
23 #include "nnacl/op_base.h"
24
25 namespace mindspore::lite::quant {
namespace {
// Shift/offset used when sizing the symbol-spreading step, and the shift (2^3 = 8) used to round
// offsets up to an 8-byte boundary.
constexpr size_t kTableExtend{3};
// Added before shifting by kTableExtend so that offsets are rounded up to the next multiple of 8.
constexpr size_t kAlignOffset{7};
// Number of bytes skipped after the 8-bit bit count so the next field starts on a 32-bit boundary.
constexpr size_t kThreeBytes{3};
}  // namespace
FSECreateStatesForDecoding(const uint32_t * symbol_frequency,int symbol_frequency_count,size_t table_log,uint16_t * new_state_baseline,uint8_t * bit_count,uint16_t * symbol_table)31 int FSEDecoder::FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int symbol_frequency_count,
32 size_t table_log, uint16_t *new_state_baseline, uint8_t *bit_count,
33 uint16_t *symbol_table) {
34 CHECK_NULL_RETURN(symbol_frequency);
35 CHECK_NULL_RETURN(new_state_baseline);
36 CHECK_NULL_RETURN(bit_count);
37 CHECK_NULL_RETURN(symbol_table);
38 const size_t table_size = 1u << table_log;
39 const size_t table_mask = table_size - 1;
40 size_t step = ((table_size >> 1) + (table_size >> kTableExtend) + kTableExtend);
41 size_t pos = 0;
42 for (int sym = 0; sym < symbol_frequency_count; sym++) {
43 for (uint32_t i = 0; i < symbol_frequency[sym]; i++) {
44 symbol_table[pos] = sym;
45 pos = (pos + step) & table_mask;
46 while (pos > table_mask) {
47 pos = (pos + step) & table_mask;
48 }
49 }
50 }
51 if (pos != 0) {
52 MS_LOG(ERROR) << "pos must equal 0.";
53 return RET_ERROR;
54 }
55 // defensive copy to not mutate frequency:
56 std::vector<uint32_t> frequency(symbol_frequency, symbol_frequency + symbol_frequency_count);
57
58 for (size_t i = 0; i < table_size; i++) {
59 uint16_t sym = symbol_table[i];
60 uint32_t x = frequency[sym];
61 frequency[sym] += 1;
62 MS_CHECK_GE(table_log, FSEBitStream::CountBits(x), RET_ERROR);
63 bit_count[i] = static_cast<uint8_t>(table_log - FSEBitStream::CountBits(x));
64 new_state_baseline[i] = (x << bit_count[i]) - table_size;
65 }
66 return RET_OK;
67 }
68
DecodeBuffer(int8_t * buffer,size_t data_size,FSEBuffer * fse_buffer)69 int FSEDecoder::DecodeBuffer(int8_t *buffer, size_t data_size, FSEBuffer *fse_buffer) {
70 CHECK_NULL_RETURN(buffer);
71 CHECK_NULL_RETURN(fse_buffer);
72 if (data_size < sizeof(uint16_t)) {
73 MS_LOG(ERROR) << "data_size is invalid.";
74 return RET_ERROR;
75 }
76 size_t i = 0;
77 // 16bit for frequency_count
78 fse_buffer->frequency_count = *(reinterpret_cast<uint16_t *>(buffer + i));
79 i += sizeof(uint16_t);
80 if (i > data_size) {
81 MS_LOG(ERROR) << "index over total size"
82 << " index:" << i << " total size:" << data_size;
83 return RET_ERROR;
84 }
85 // 16bit for table_log
86 fse_buffer->table_log = *(reinterpret_cast<uint16_t *>(buffer + i));
87 i += sizeof(uint16_t);
88 if (i > data_size) {
89 MS_LOG(ERROR) << "index over total size"
90 << " index:" << i << " total size:" << data_size;
91 return RET_ERROR;
92 }
93 // 32bit for ChunkCount
94 fse_buffer->chunk_count = *(reinterpret_cast<uint32_t *>(buffer + i));
95 const size_t offset = 2;
96 // 32bit for CurrChunkIndex
97 fse_buffer->curr_chunk_index = fse_buffer->chunk_count - offset;
98 i += sizeof(uint32_t);
99 if (i > data_size) {
100 MS_LOG(ERROR) << "index over total size"
101 << " index:" << i << " total size:" << data_size;
102 return RET_ERROR;
103 }
104 // 32bit * frequency_count for frequency
105 fse_buffer->frequency = reinterpret_cast<uint32_t *>(buffer + i);
106 i += fse_buffer->frequency_count * sizeof(uint32_t);
107 // Used for 8-byte(64bit) alignment
108 i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
109 if (i > data_size) {
110 MS_LOG(ERROR) << "index over total size"
111 << " index:" << i << " total size:" << data_size;
112 return RET_ERROR;
113 }
114 // 32bit * frequency_count for centroids
115 fse_buffer->centroids = reinterpret_cast<void *>(buffer + i);
116 fse_buffer->centroid_size = fse_buffer->frequency_count * sizeof(float);
117 i += fse_buffer->centroid_size;
118 // Used for 8-byte(64bit) alignment
119 i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
120 if (i > data_size) {
121 MS_LOG(ERROR) << "index over total size"
122 << " index:" << i << " total size:" << data_size;
123 return RET_ERROR;
124 }
125 // 64bit * bs_.GetCurrChunkIndex() + 1 for Chunks.
126 fse_buffer->chunks = reinterpret_cast<uint64_t *>(buffer + i);
127 fse_buffer->chunk_size = (fse_buffer->curr_chunk_index + 1) * sizeof(uint64_t);
128 i += fse_buffer->chunk_size;
129 if (i > data_size) {
130 MS_LOG(ERROR) << "index over total size"
131 << " index:" << i << " total size:" << data_size;
132 return RET_ERROR;
133 }
134 // 64bit for CurrChunk
135 fse_buffer->curr_chunk = *(reinterpret_cast<uint64_t *>(buffer + i));
136 i += sizeof(uint64_t);
137 if (i > data_size) {
138 MS_LOG(ERROR) << "index over total size"
139 << " index:" << i << " total size:" << data_size;
140 return RET_ERROR;
141 }
142 // 8bit for CurrBitCount
143 fse_buffer->curr_bit_count = *(reinterpret_cast<uint8_t *>(buffer + i));
144 i += sizeof(uint8_t);
145
146 if (i < data_size) { // There is more data after what was extracted
147 i += kThreeBytes * sizeof(uint8_t); // Align to 32 bit for ChunkEndsCount
148 if (i > data_size) {
149 MS_LOG(ERROR) << " index:" << i << " is over total size:" << data_size;
150 return RET_ERROR;
151 }
152 uint32_t chunk_ends_count = *(reinterpret_cast<uint32_t *>(buffer + i));
153 if ((i + sizeof(uint32_t) + chunk_ends_count * sizeof(uint64_t)) > data_size) {
154 MS_LOG(ERROR) << " index:" << i << " is over total size:" << data_size;
155 return RET_ERROR;
156 }
157 fse_buffer->chunk_ends_count = chunk_ends_count;
158 i += sizeof(uint32_t);
159 fse_buffer->chunk_ends = reinterpret_cast<uint64_t *>(buffer + i);
160 }
161 return RET_OK;
162 }
163
DeCompress(const SchemaTensorWrapper & src_tensor,Tensor * dst_tensor,schema::WeightQuantCompressType compress_type)164 int FSEDecoder::DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor,
165 schema::WeightQuantCompressType compress_type) {
166 CHECK_NULL_RETURN(src_tensor.handler());
167 CHECK_NULL_RETURN(src_tensor.data());
168 CHECK_NULL_RETURN(dst_tensor);
169 if (dst_tensor->MutableData() == nullptr) {
170 MS_LOG(ERROR) << "tensor data is nullptr.";
171 return RET_ERROR;
172 }
173 auto total_size = src_tensor.length();
174 int out_sz = dst_tensor->ElementsNum();
175 MS_CHECK_GT(out_sz, 0, RET_ERROR);
176 // deserialize from `data`:
177 FSEBitStream bs;
178
179 size_t i = 0;
180 auto data8 = reinterpret_cast<int8_t *>(const_cast<void *>(src_tensor.data()));
181 CHECK_NULL_RETURN(data8);
182 // 16bit for frequency_count
183 uint16_t frequency_count = *(reinterpret_cast<uint16_t *>(&data8[i]));
184 i += sizeof(uint16_t);
185 if (i > total_size) {
186 MS_LOG(ERROR) << "index over total size"
187 << " index:" << i << " total size:" << total_size;
188 return RET_ERROR;
189 }
190 // 16bit for table_log
191 size_t table_log = *(reinterpret_cast<uint16_t *>(&data8[i]));
192 i += sizeof(uint16_t);
193 if (i > total_size) {
194 MS_LOG(ERROR) << "index over total size"
195 << " index:" << i << " total size:" << total_size;
196 return RET_ERROR;
197 }
198 // 32bit for ChunkCount
199 bs.SetChunkCount(*(reinterpret_cast<uint32_t *>(&data8[i])));
200 const int offset = 2;
201 // 32bit for CurrChunkIndex
202 bs.SetCurrChunkIndex(bs.GetChunkCount() - offset);
203 i += sizeof(uint32_t);
204 if (i > total_size) {
205 MS_LOG(ERROR) << "index over total size"
206 << " index:" << i << " total size:" << total_size;
207 return RET_ERROR;
208 }
209 // 32bit * frequency_count for frequency
210 auto *frequency = reinterpret_cast<uint32_t *>(&data8[i]);
211 i += frequency_count * sizeof(uint32_t);
212 // Used for 8-byte(64bit) alignment
213 i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
214 if (i > total_size) {
215 MS_LOG(ERROR) << "index over total size"
216 << " index:" << i << " total size:" << total_size;
217 return RET_ERROR;
218 }
219 // 32bit * frequency_count for centroids
220 auto centroids = reinterpret_cast<void *>(&data8[i]);
221 i += frequency_count * sizeof(float);
222 // Used for 8-byte(64bit) alignment
223 i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
224 if (i > total_size) {
225 MS_LOG(ERROR) << "index over total size"
226 << " index:" << i << " total size:" << total_size;
227 return RET_ERROR;
228 }
229 // 64bit * bs.GetCurrChunkIndex() + 1 for Chunks.
230 bs.SetChunks(reinterpret_cast<uint64_t *>(&data8[i]));
231 i += (bs.GetCurrChunkIndex() + 1) * sizeof(uint64_t);
232 if (i > total_size) {
233 MS_LOG(ERROR) << "index over total size"
234 << " index:" << i << " total size:" << total_size;
235 return RET_ERROR;
236 }
237 // 64bit for CurrChunk
238 bs.SetCurrChunk(*(reinterpret_cast<uint64_t *>(&data8[i])));
239 i += sizeof(uint64_t);
240 if (i > total_size) {
241 MS_LOG(ERROR) << "index over total size"
242 << " index:" << i << " total size:" << total_size;
243 return RET_ERROR;
244 }
245 // 8bit for CurrBitCount
246 bs.SetCurrBitCount(*(reinterpret_cast<uint8_t *>(&data8[i])));
247 int ret;
248 if (compress_type == schema::WeightQuantCompressType_FSE) {
249 ret = FSEDecode<float, float>(&bs, static_cast<float *>(dst_tensor->data()), out_sz, frequency, frequency_count,
250 static_cast<float *>(centroids), table_log);
251 } else { // WeightQuantCompressType_FSE_INT
252 if (src_tensor.handler()->dataType() == kNumberTypeInt8) {
253 ret = FSEDecode<int, int8_t>(&bs, static_cast<int8_t *>(dst_tensor->data()), out_sz, frequency, frequency_count,
254 static_cast<int *>(centroids), table_log);
255 } else { // kNumberTypeInt16
256 ret = FSEDecode<int, int16_t>(&bs, static_cast<int16_t *>(dst_tensor->data()), out_sz, frequency, frequency_count,
257 static_cast<int *>(centroids), table_log);
258 }
259 }
260 if (ret != RET_OK) {
261 MS_LOG(ERROR) << "FSE Decode failed.";
262 return RET_ERROR;
263 }
264 return RET_OK;
265 }
266 } // namespace mindspore::lite::quant
267