1 // Copyright 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/decoder/quantization.h"
16 #include "src/base/math_utils.h"
17
18 #include <algorithm>
19 #include <array>
20 #include <cassert>
21 #include <map>
22 #include <memory>
23 #include <vector>
24
25 namespace astc_codec {
26
27 namespace {
28
// Unquantizes a color endpoint value that was encoded as one trit plus a
// number of bits, following the trit procedure of Section C.2.13.
//   trit  - the trit digit.
//   bits  - the accompanying bits. Bit 0 selects the A mask (all ones or all
//           zeros); the remaining bits are spread into the B term.
//   range - the maximum quantized value; one of 5, 11, 23, 47, 95 or 191.
// The result is (A & 0x80) | (((trit * C + B) ^ A) >> 2), where mask_a,
// term_b and scale_c below hold the spec's A, B and C.
int GetUnquantizedTritValue(int trit, int bits, int range) {
  const int mask_a = (bits & 1) ? 0x1FF : 0;
  const int upper = bits >> 1;  // the bits above the A selector
  int term_b = 0;
  int scale_c = 0;

  if (range == 5) {
    scale_c = 204;
  } else if (range == 11) {
    const int x = upper & 0x1;
    term_b = (x << 8) | (x << 4) | (x << 2) | (x << 1);
    scale_c = 93;
  } else if (range == 23) {
    const int x = upper & 0x3;
    term_b = (x << 7) | (x << 2) | x;
    scale_c = 44;
  } else if (range == 47) {
    const int x = upper & 0x7;
    term_b = (x << 6) | x;
    scale_c = 22;
  } else if (range == 95) {
    const int x = upper & 0xF;
    term_b = (x << 5) | (x >> 2);
    scale_c = 11;
  } else if (range == 191) {
    const int x = upper & 0x1F;
    term_b = (x << 4) | (x >> 4);
    scale_c = 5;
  } else {
    assert(false && "Illegal trit encoding");
  }

  const int t = (trit * scale_c + term_b) ^ mask_a;
  return (mask_a & 0x80) | (t >> 2);
}
85
// Unquantizes a color endpoint value that was encoded as one quint plus a
// number of bits, following the quint procedure of Section C.2.13.
//   quint - the quint digit.
//   bits  - the accompanying bits. Bit 0 selects the A mask (all ones or all
//           zeros); the remaining bits are spread into the B term.
//   range - the maximum quantized value; one of 9, 19, 39, 79 or 159.
// The result is (A & 0x80) | (((quint * C + B) ^ A) >> 2), where mask_a,
// term_b and scale_c below hold the spec's A, B and C.
int GetUnquantizedQuintValue(int quint, int bits, int range) {
  const int mask_a = (bits & 1) ? 0x1FF : 0;
  const int upper = bits >> 1;  // the bits above the A selector
  int term_b = 0;
  int scale_c = 0;

  if (range == 9) {
    scale_c = 113;
  } else if (range == 19) {
    const int x = upper & 0x1;
    term_b = (x << 8) | (x << 3) | (x << 2);
    scale_c = 54;
  } else if (range == 39) {
    const int x = upper & 0x3;
    term_b = (x << 7) | (x << 1) | (x >> 1);
    scale_c = 26;
  } else if (range == 79) {
    const int x = upper & 0x7;
    term_b = (x << 6) | (x >> 1);
    scale_c = 13;
  } else if (range == 159) {
    const int x = upper & 0xF;
    term_b = (x << 5) | (x >> 3);
    scale_c = 6;
  } else {
    assert(false && "Illegal quint encoding");
  }

  const int t = (quint * scale_c + term_b) ^ mask_a;
  return (mask_a & 0x80) | (t >> 2);
}
135
// Unquantizes a weight that was encoded as one trit plus a number of bits,
// following the trit procedure of Section C.2.17. The locals mask_a, term_b
// and scale_c correspond to the columns A, B and C of the specification.
//   trit  - the trit digit.
//   bits  - the accompanying bits; bit 0 selects the A mask.
//   range - the maximum quantized value; one of 2, 5, 11 or 23.
// Returns a weight computed as (A & 0x20) | (((trit * C + B) ^ A) >> 2),
// except for range 2, which uses a fixed table.
int GetUnquantizedTritWeight(int trit, int bits, int range) {
  // A bare trit (range 0..2) carries no extra bits and maps through a table.
  if (range == 2) {
    static constexpr int kBareTritWeights[3] = {0, 32, 63};
    return kBareTritWeights[trit];
  }

  const int mask_a = (bits & 1) ? 0x7F : 0;
  const int upper = bits >> 1;  // the bits above the A selector
  int term_b = 0;
  int scale_c = 0;

  if (range == 5) {
    scale_c = 50;
  } else if (range == 11) {
    const int x = upper & 0x1;
    term_b = (x << 6) | (x << 2) | x;
    scale_c = 23;
  } else if (range == 23) {
    const int x = upper & 0x3;
    term_b = (x << 5) | x;
    scale_c = 11;
  } else {
    assert(false && "Illegal trit encoding");
  }

  const int t = (trit * scale_c + term_b) ^ mask_a;
  return (mask_a & 0x20) | (t >> 2);
}
175
// Unquantizes a weight that was encoded as one quint plus a number of bits,
// following the quint procedure of Section C.2.17. The locals mask_a, term_b
// and scale_c correspond to the columns A, B and C of the specification.
//   quint - the quint digit.
//   bits  - the accompanying bits; bit 0 selects the A mask.
//   range - the maximum quantized value; one of 4, 9 or 19.
// Returns a weight computed as (A & 0x20) | (((quint * C + B) ^ A) >> 2),
// except for range 4, which uses a fixed table.
int GetUnquantizedQuintWeight(int quint, int bits, int range) {
  // A bare quint (range 0..4) carries no extra bits and maps through a table.
  if (range == 4) {
    static constexpr int kBareQuintWeights[5] = {0, 16, 32, 47, 63};
    return kBareQuintWeights[quint];
  }

  const int mask_a = (bits & 1) ? 0x7F : 0;
  int term_b = 0;
  int scale_c = 0;

  if (range == 9) {
    scale_c = 28;
  } else if (range == 19) {
    const int x = (bits >> 1) & 0x1;
    term_b = (x << 6) | (x << 1);
    scale_c = 13;
  } else {
    assert(false && "Illegal quint encoding");
  }

  const int t = (quint * scale_c + term_b) ^ mask_a;
  return (mask_a & 0x20) | (t >> 2);
}
208
// A Quantization map allows us to convert to/from values that are quantized
// according to the ASTC spec. Derived classes populate |unquantization_map_|
// (quantized index -> unquantized value) and |quantization_map_|
// (unquantized value -> quantized index), the latter either directly or by
// calling GenerateQuantizationMap().
class QuantizationMap {
 public:
  // Returns the quantized index stored for unquantized value |x|, or 0 if |x|
  // is outside the quantization table (including negative values).
  int Quantize(int x) const {
    return IsIndexOf(x, quantization_map_) ? quantization_map_[x] : 0;
  }

  // Returns the unquantized value stored for quantized index |x|, or 0 if |x|
  // is outside the unquantization table (including negative values).
  int Unquantize(int x) const {
    return IsIndexOf(x, unquantization_map_) ? unquantization_map_[x] : 0;
  }

 protected:
  QuantizationMap() { }
  std::vector<int> quantization_map_;
  std::vector<int> unquantization_map_;

  // Fills |quantization_map_| for every value in [0, 255] with the index of
  // the closest entry in |unquantization_map_| (by squared distance). On a
  // tie the lowest index wins. Requires at least two unquantized values.
  void GenerateQuantizationMap() {
    assert(unquantization_map_.size() > 1);
    quantization_map_.clear();
    quantization_map_.reserve(256);

    // TODO(google) For weights, we don't need quantization values all the
    // way up to 256, but it doesn't hurt -- just wastes memory, but the code
    // is much cleaner this way
    for (int i = 0; i < 256; ++i) {
      // No cap is placed on the distance: coarse ranges (e.g. the 0..5
      // endpoint range, whose levels are 51 apart) leave values far more
      // than 15 away from every level, and those must still map to the
      // nearest one rather than defaulting to index 0.
      const auto closer_to_i = [i](int lhs, int rhs) {
        return (i - lhs) * (i - lhs) < (i - rhs) * (i - rhs);
      };
      // std::min_element returns the first minimum, so ties keep the lowest
      // index.
      const auto best = std::min_element(unquantization_map_.begin(),
                                         unquantization_map_.end(),
                                         closer_to_i);
      quantization_map_.push_back(
          static_cast<int>(best - unquantization_map_.begin()));
    }
  }

 private:
  // True when |x| is a valid index into |table|; avoids the signed/unsigned
  // comparison of an int against size().
  static bool IsIndexOf(int x, const std::vector<int>& table) {
    return x >= 0 && x < static_cast<int>(table.size());
  }
};
251
252 template<int (*UnquantizationFunc)(int, int, int)>
253 class TritQuantizationMap : public QuantizationMap {
254 public:
TritQuantizationMap(int range)255 explicit TritQuantizationMap(int range) : QuantizationMap() {
256 assert((range + 1) % 3 == 0);
257 const int num_bits_pow_2 = (range + 1) / 3;
258 const int num_bits =
259 num_bits_pow_2 == 0 ? 0 : base::Log2Floor(num_bits_pow_2);
260
261 for (int trit = 0; trit < 3; ++trit) {
262 for (int bits = 0; bits < (1 << num_bits); ++bits) {
263 unquantization_map_.push_back(UnquantizationFunc(trit, bits, range));
264 }
265 }
266
267 GenerateQuantizationMap();
268 }
269 };
270
271 template<int (*UnquantizationFunc)(int, int, int)>
272 class QuintQuantizationMap : public QuantizationMap {
273 public:
QuintQuantizationMap(int range)274 explicit QuintQuantizationMap(int range) : QuantizationMap() {
275 assert((range + 1) % 5 == 0);
276 const int num_bits_pow_2 = (range + 1) / 5;
277 const int num_bits =
278 num_bits_pow_2 == 0 ? 0 : base::Log2Floor(num_bits_pow_2);
279
280 for (int quint = 0; quint < 5; ++quint) {
281 for (int bits = 0; bits < (1 << num_bits); ++bits) {
282 unquantization_map_.push_back(UnquantizationFunc(quint, bits, range));
283 }
284 }
285
286 GenerateQuantizationMap();
287 }
288 };
289
290 template<int TotalUnquantizedBits>
291 class BitQuantizationMap : public QuantizationMap {
292 public:
BitQuantizationMap(int range)293 explicit BitQuantizationMap<TotalUnquantizedBits>(int range)
294 : QuantizationMap() {
295 // Make sure that if we're using bits then we have a positive power of two.
296 assert(base::CountOnes(range + 1) == 1);
297
298 const int num_bits = base::Log2Floor(range + 1);
299 for (int bits = 0; bits <= range; ++bits) {
300 // Need to replicate bits until we fill up the bits
301 int unquantized = bits;
302 int num_unquantized_bits = num_bits;
303 while (num_unquantized_bits < TotalUnquantizedBits) {
304 const int num_dst_bits_to_shift_up =
305 std::min(num_bits, TotalUnquantizedBits - num_unquantized_bits);
306 const int num_src_bits_to_shift_down =
307 num_bits - num_dst_bits_to_shift_up;
308 unquantized <<= num_dst_bits_to_shift_up;
309 unquantized |= bits >> num_src_bits_to_shift_down;
310 num_unquantized_bits += num_dst_bits_to_shift_up;
311 }
312 assert(num_unquantized_bits == TotalUnquantizedBits);
313
314 unquantization_map_.push_back(unquantized);
315
316 // Fill half of the quantization map with the previous value for bits
317 // and the other half with the current value for bits
318 if (bits > 0) {
319 const int prev_unquant = unquantization_map_.at(bits - 1);
320 while (quantization_map_.size() <= (prev_unquant + unquantized) / 2) {
321 quantization_map_.push_back(bits - 1);
322 }
323 }
324 while (quantization_map_.size() <= unquantized) {
325 quantization_map_.push_back(bits);
326 }
327 }
328
329 assert(quantization_map_.size() == 1 << TotalUnquantizedBits);
330 }
331 };
332
333 using QMap = std::shared_ptr<QuantizationMap>;
334
335 // Returns the quantization map for quantizing color values in [0, 255] with the
336 // smallest range that can accommodate |r|
GetQuantMapForValueRange(int r)337 static const QuantizationMap* GetQuantMapForValueRange(int r) {
338 // Endpoint values can be quantized using bits, trits, or quints. Here we
339 // store the quantization maps for each of the ranges that are supported by
340 // such an encoding. That way we can choose the proper quantization procedure
341 // based on the range of values rather than by having complicated switches and
342 // logic. We must use a std::map here instead of a std::unordered_map because
343 // of the assumption made in std::upper_bound about the iterators being from a
344 // poset.
345 static const auto* const kASTCEndpointQuantization = new std::map<int, QMap> {
346 { 5, QMap(new TritQuantizationMap<GetUnquantizedTritValue>(5)) },
347 { 7, QMap(new BitQuantizationMap<8>(7)) },
348 { 9, QMap(new QuintQuantizationMap<GetUnquantizedQuintValue>(9)) },
349 { 11, QMap(new TritQuantizationMap<GetUnquantizedTritValue>(11)) },
350 { 15, QMap(new BitQuantizationMap<8>(15)) },
351 { 19, QMap(new QuintQuantizationMap<GetUnquantizedQuintValue>(19)) },
352 { 23, QMap(new TritQuantizationMap<GetUnquantizedTritValue>(23)) },
353 { 31, QMap(new BitQuantizationMap<8>(31)) },
354 { 39, QMap(new QuintQuantizationMap<GetUnquantizedQuintValue>(39)) },
355 { 47, QMap(new TritQuantizationMap<GetUnquantizedTritValue>(47)) },
356 { 63, QMap(new BitQuantizationMap<8>(63)) },
357 { 79, QMap(new QuintQuantizationMap<GetUnquantizedQuintValue>(79)) },
358 { 95, QMap(new TritQuantizationMap<GetUnquantizedTritValue>(95)) },
359 { 127, QMap(new BitQuantizationMap<8>(127)) },
360 { 159, QMap(new QuintQuantizationMap<GetUnquantizedQuintValue>(159)) },
361 { 191, QMap(new TritQuantizationMap<GetUnquantizedTritValue>(191)) },
362 { 255, QMap(new BitQuantizationMap<8>(255)) },
363 };
364
365 assert(r < 256);
366 auto itr = kASTCEndpointQuantization->upper_bound(r);
367 if (itr != kASTCEndpointQuantization->begin()) {
368 return (--itr)->second.get();
369 }
370 return nullptr;
371 }
372
373 // Returns the quantization map for weight values in [0, 63] with the smallest
374 // range that can accommodate |r|
GetQuantMapForWeightRange(int r)375 static const QuantizationMap* GetQuantMapForWeightRange(int r) {
376 // Similar to endpoint quantization, weights can also be stored using trits,
377 // quints, or bits. Here we store the quantization maps for each of the ranges
378 // that are supported by such an encoding.
379 static const auto* const kASTCWeightQuantization = new std::map<int, QMap> {
380 { 1, QMap(new BitQuantizationMap<6>(1)) },
381 { 2, QMap(new TritQuantizationMap<GetUnquantizedTritWeight>(2)) },
382 { 3, QMap(new BitQuantizationMap<6>(3)) },
383 { 4, QMap(new QuintQuantizationMap<GetUnquantizedQuintWeight>(4)) },
384 { 5, QMap(new TritQuantizationMap<GetUnquantizedTritWeight>(5)) },
385 { 7, QMap(new BitQuantizationMap<6>(7)) },
386 { 9, QMap(new QuintQuantizationMap<GetUnquantizedQuintWeight>(9)) },
387 { 11, QMap(new TritQuantizationMap<GetUnquantizedTritWeight>(11)) },
388 { 15, QMap(new BitQuantizationMap<6>(15)) },
389 { 19, QMap(new QuintQuantizationMap<GetUnquantizedQuintWeight>(19)) },
390 { 23, QMap(new TritQuantizationMap<GetUnquantizedTritWeight>(23)) },
391 { 31, QMap(new BitQuantizationMap<6>(31)) },
392 };
393
394 assert(r < 32);
395 auto itr = kASTCWeightQuantization->upper_bound(r);
396 if (itr != kASTCWeightQuantization->begin()) {
397 return (--itr)->second.get();
398 }
399 return nullptr;
400 }
401
402 } // namespace
403
404 ////////////////////////////////////////////////////////////////////////////////
405
QuantizeCEValueToRange(int value,int range_max_value)406 int QuantizeCEValueToRange(int value, int range_max_value) {
407 assert(range_max_value >= kEndpointRangeMinValue);
408 assert(range_max_value <= 255);
409 assert(value >= 0);
410 assert(value <= 255);
411
412 const QuantizationMap* map = GetQuantMapForValueRange(range_max_value);
413 return map ? map->Quantize(value) : 0;
414 }
415
UnquantizeCEValueFromRange(int value,int range_max_value)416 int UnquantizeCEValueFromRange(int value, int range_max_value) {
417 assert(range_max_value >= kEndpointRangeMinValue);
418 assert(range_max_value <= 255);
419 assert(value >= 0);
420 assert(value <= range_max_value);
421
422 const QuantizationMap* map = GetQuantMapForValueRange(range_max_value);
423 return map ? map->Unquantize(value) : 0;
424 }
425
QuantizeWeightToRange(int weight,int range_max_value)426 int QuantizeWeightToRange(int weight, int range_max_value) {
427 assert(range_max_value >= 1);
428 assert(range_max_value <= kWeightRangeMaxValue);
429 assert(weight >= 0);
430 assert(weight <= 64);
431
432 // The quantization maps that define weight unquantization expect values in
433 // the range [0, 64), but the specification quantizes them to the range
434 // [0, 64] according to C.2.17. This is a slight hack similar to the one in
435 // the unquantization procedure to return the passed in unquantized value to
436 // [0, 64) prior to running it through the quantization procedure.
437 if (weight > 33) {
438 weight -= 1;
439 }
440 const QuantizationMap* map = GetQuantMapForWeightRange(range_max_value);
441 return map ? map->Quantize(weight) : 0;
442 }
443
UnquantizeWeightFromRange(int weight,int range_max_value)444 int UnquantizeWeightFromRange(int weight, int range_max_value) {
445 assert(range_max_value >= 1);
446 assert(range_max_value <= kWeightRangeMaxValue);
447 assert(weight >= 0);
448 assert(weight <= range_max_value);
449 const QuantizationMap* map = GetQuantMapForWeightRange(range_max_value);
450 int dq = map ? map->Unquantize(weight) : 0;
451
452 // Quantized weights are returned in the range [0, 64), but they should be
453 // returned in the range [0, 64], so according to C.2.17 we need to add one
454 // to the result.
455 assert(dq < 64);
456 if (dq > 32) {
457 dq += 1;
458 }
459 return dq;
460 }
461
462 } // namespace astc_codec
463