1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/utils/entropy_decoder.h"
16
17 #include <cassert>
18 #include <cstring>
19
20 #include "src/utils/common.h"
21 #include "src/utils/compiler_attributes.h"
22 #include "src/utils/constants.h"
23 #include "src/utils/cpu.h"
24
25 #if defined(__ARM_NEON__) || defined(__aarch64__) || \
26 (defined(_MSC_VER) && defined(_M_ARM))
27 #define LIBGAV1_ENTROPY_DECODER_ENABLE_NEON 1
28 #else
29 #define LIBGAV1_ENTROPY_DECODER_ENABLE_NEON 0
30 #endif
31
32 #if LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
33 #include <arm_neon.h>
34 #endif
35
36 #if defined(__SSE2__) || defined(LIBGAV1_X86_MSVC)
37 #define LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2 1
38 #else
39 #define LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2 0
40 #endif
41
42 #if LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
43 #include <emmintrin.h>
44 #endif
45
46 namespace libgav1 {
47 namespace {
48
// Clears the low 8 bits of values_in_range_; used by ReadBit() to compute
// ((rng >> 8) << 8) without the two shifts.
constexpr uint32_t kReadBitMask = ~uint32_t{255};
// Number of low bits dropped from each CDF entry before it is scaled by the
// range.
constexpr int kCdfPrecision = 6;
// Per-symbol amount added to every scaled CDF value.
constexpr int kMinimumProbabilityPerSymbol = 4;
52
53 // This function computes the "cur" variable as specified inside the do-while
54 // loop in Section 8.2.6 of the spec. This function is monotonically
55 // decreasing as the values of index increases (note that the |cdf| array is
56 // sorted in decreasing order).
ScaleCdf(uint32_t values_in_range_shifted,const uint16_t * const cdf,int index,int symbol_count)57 uint32_t ScaleCdf(uint32_t values_in_range_shifted, const uint16_t* const cdf,
58 int index, int symbol_count) {
59 return ((values_in_range_shifted * (cdf[index] >> kCdfPrecision)) >> 1) +
60 (kMinimumProbabilityPerSymbol * (symbol_count - index));
61 }
62
UpdateCdf(uint16_t * const cdf,const int symbol_count,const int symbol)63 void UpdateCdf(uint16_t* const cdf, const int symbol_count, const int symbol) {
64 const uint16_t count = cdf[symbol_count];
65 // rate is computed in the spec as:
66 // 3 + ( cdf[N] > 15 ) + ( cdf[N] > 31 ) + Min(FloorLog2(N), 2)
67 // In this case cdf[N] is |count|.
68 // Min(FloorLog2(N), 2) is 1 for symbol_count == {2, 3} and 2 for all
69 // symbol_count > 3. So the equation becomes:
70 // 4 + (count > 15) + (count > 31) + (symbol_count > 3).
71 // Note that the largest value for count is 32 (it is not incremented beyond
72 // 32). So using that information:
73 // count >> 4 is 0 for count from 0 to 15.
74 // count >> 4 is 1 for count from 16 to 31.
75 // count >> 4 is 2 for count == 31.
76 // Now, the equation becomes:
77 // 4 + (count >> 4) + (symbol_count > 3).
78 // Since (count >> 4) can only be 0 or 1 or 2, the addition could be replaced
79 // with bitwise or:
80 // (4 | (count >> 4)) + (symbol_count > 3).
81 // but using addition will allow the compiler to eliminate an operation when
82 // symbol_count is known and this function is inlined.
83 const int rate = (count >> 4) + 4 + static_cast<int>(symbol_count > 3);
84 // Hints for further optimizations:
85 //
86 // 1. clang can vectorize this for loop with width 4, even though the loop
87 // contains an if-else statement. Therefore, it may be advantageous to use
88 // "i < symbol_count" as the loop condition when symbol_count is 8, 12, or 16
89 // (a multiple of 4 that's not too small).
90 //
91 // 2. The for loop can be rewritten in the following form, which would enable
92 // clang to vectorize the loop with width 8:
93 //
94 // const int rounding = (1 << rate) - 1;
95 // for (int i = 0; i < symbol_count - 1; ++i) {
96 // const uint16_t a = (i < symbol) ? kCdfMaxProbability : rounding;
97 // cdf[i] += static_cast<int16_t>(a - cdf[i]) >> rate;
98 // }
99 //
100 // The subtraction (a - cdf[i]) relies on the overflow semantics of unsigned
101 // integer arithmetic. The result of the unsigned subtraction is cast to a
102 // signed integer and right-shifted. This requires the right shift of a
103 // signed integer be an arithmetic shift, which is true for clang, gcc, and
104 // Visual C++.
105 assert(symbol_count - 1 > 0);
106 int i = 0;
107 do {
108 if (i < symbol) {
109 cdf[i] += (kCdfMaxProbability - cdf[i]) >> rate;
110 } else {
111 cdf[i] -= cdf[i] >> rate;
112 }
113 } while (++i < symbol_count - 1);
114 cdf[symbol_count] += static_cast<uint16_t>(count < 32);
115 }
116
117 // Define the UpdateCdfN functions. UpdateCdfN is a specialized implementation
118 // of UpdateCdf based on the fact that symbol_count == N. UpdateCdfN uses the
119 // SIMD instruction sets if available.
120
121 #if LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
122
123 // The UpdateCdf() method contains the following for loop:
124 //
125 // for (int i = 0; i < symbol_count - 1; ++i) {
126 // if (i < symbol) {
127 // cdf[i] += (kCdfMaxProbability - cdf[i]) >> rate;
128 // } else {
129 // cdf[i] -= cdf[i] >> rate;
130 // }
131 // }
132 //
133 // It can be rewritten in the following two forms, which are amenable to SIMD
134 // implementations:
135 //
136 // const int rounding = (1 << rate) - 1;
137 // for (int i = 0; i < symbol_count - 1; ++i) {
138 // const uint16_t a = (i < symbol) ? kCdfMaxProbability : rounding;
139 // cdf[i] += static_cast<int16_t>(a - cdf[i]) >> rate;
140 // }
141 //
142 // or:
143 //
144 // const int rounding = (1 << rate) - 1;
145 // for (int i = 0; i < symbol_count - 1; ++i) {
146 // const uint16_t a = (i < symbol) ? (kCdfMaxProbability - rounding) : 0;
147 // cdf[i] -= static_cast<int16_t>(cdf[i] - a) >> rate;
148 // }
149 //
150 // The following ARM NEON implementations use a modified version of the first
151 // form, using the comparison mask and unsigned rollover to avoid the need to
152 // calculate rounding.
153 //
154 // The cdf array has symbol_count + 1 elements. The first symbol_count elements
155 // are the CDF. The last element is a count that is initialized to 0 and may
156 // grow up to 32. The for loop in UpdateCdf updates the CDF in the array. Since
157 // cdf[symbol_count - 1] is always 0, the for loop does not update
158 // cdf[symbol_count - 1]. However, it would be correct to have the for loop
159 // update cdf[symbol_count - 1] anyway: since symbol_count - 1 >= symbol, the
160 // for loop would take the else branch when i is symbol_count - 1:
161 // cdf[i] -= cdf[i] >> rate;
162 // Since cdf[symbol_count - 1] is 0, cdf[symbol_count - 1] would still be 0
163 // after the update. The ARM NEON implementations take advantage of this in the
164 // following two cases:
165 // 1. When symbol_count is 8 or 16, the vectorized code updates the first
166 // symbol_count elements in the array.
167 // 2. When symbol_count is 7, the vectorized code updates all the 8 elements in
168 // the cdf array. Since an invalid CDF value is written into cdf[7], the
169 // count in cdf[7] needs to be fixed up after the vectorized code.
170
// NEON version of UpdateCdf() for symbol_count == 5. One 64-bit vector covers
// cdf[0..3]. cdf[4] is the always-zero last CDF entry and is left untouched.
// cdf[5] is the adaptation counter.
void UpdateCdf5(uint16_t* const cdf, const int symbol) {
  const uint16_t count = cdf[5];
  // symbol_count > 3, so the spec's adaptation rate reduces to 5 + (count >> 4).
  const int rate = 5 + (count >> 4);
  const uint16x4_t lane_index = vcreate_u16(0x0003000200010000);
  // All ones in lanes where i >= symbol, zero in the others.
  const uint16x4_t ge_symbol = vcge_u16(lane_index, vdup_n_u16(symbol));
  const uint16x4_t old_cdf = vld1_u16(cdf);
  // Target value: 32768 where i < symbol, 65535 (i.e. -1) where i >= symbol.
  const uint16x4_t target = vorr_u16(ge_symbol, vdup_n_u16(kCdfMaxProbability));
  const int16x4_t gap = vreinterpret_s16_u16(vsub_u16(target, old_cdf));
  // Arithmetic right shift of the (signed) gap by |rate|.
  const uint16x4_t step = vreinterpret_u16_s16(vshl_s16(gap, vdup_n_s16(-rate)));
  // Subtracting the all-ones mask adds 1 in the i >= symbol lanes. With an
  // arithmetic shift, (cdf + 1) + ((-1 - cdf) >> rate) == cdf - (cdf >> rate),
  // so those lanes get the scalar "decay" update without computing rounding.
  vst1_u16(cdf, vadd_u16(vsub_u16(old_cdf, ge_symbol), step));
  cdf[5] = count + static_cast<uint16_t>(count < 32);
}
194
// NEON version of UpdateCdf() for |symbol_count| of 7, 8 or 9. One 128-bit
// vector covers cdf[0..7]:
//  * symbol_count == 9: cdf[8] is the always-zero last CDF entry and is not
//    loaded.
//  * symbol_count == 8: lane 7 is the always-zero last CDF entry, which the
//    decay formula maps back to 0.
//  * symbol_count == 7: lane 7 is the counter. The vector store clobbers it,
//    and it is rewritten from the saved |count| afterwards.
// See UpdateCdf5 for the lane arithmetic.
template <int symbol_count>
void UpdateCdf7To9(uint16_t* const cdf, const int symbol) {
  static_assert(symbol_count >= 7 && symbol_count <= 9, "");
  const uint16_t count = cdf[symbol_count];
  const int rate = 5 + (count >> 4);
  const uint16x8_t lane_index = vcombine_u16(vcreate_u16(0x0003000200010000),
                                             vcreate_u16(0x0007000600050004));
  const uint16x8_t ge_symbol = vcgeq_u16(lane_index, vdupq_n_u16(symbol));
  const uint16x8_t old_cdf = vld1q_u16(cdf);
  const uint16x8_t target =
      vorrq_u16(ge_symbol, vdupq_n_u16(kCdfMaxProbability));
  const int16x8_t gap = vreinterpretq_s16_u16(vsubq_u16(target, old_cdf));
  const uint16x8_t step =
      vreinterpretq_u16_s16(vshlq_s16(gap, vdupq_n_s16(-rate)));
  vst1q_u16(cdf, vaddq_u16(vsubq_u16(old_cdf, ge_symbol), step));
  cdf[symbol_count] = count + static_cast<uint16_t>(count < 32);
}

void UpdateCdf7(uint16_t* const cdf, const int symbol) {
  UpdateCdf7To9</*symbol_count=*/7>(cdf, symbol);
}

void UpdateCdf8(uint16_t* const cdf, const int symbol) {
  UpdateCdf7To9</*symbol_count=*/8>(cdf, symbol);
}

void UpdateCdf9(uint16_t* const cdf, const int symbol) {
  UpdateCdf7To9</*symbol_count=*/9>(cdf, symbol);
}
230
// NEON version of UpdateCdf() for symbol_count == 11. cdf[0] and cdf[1] are
// updated with scalar code and cdf[2..9] with one 128-bit vector. cdf[10] is
// the always-zero last CDF entry and cdf[11] is the counter. See UpdateCdf5 for
// the lane arithmetic.
void UpdateCdf11(uint16_t* const cdf, const int symbol) {
  uint16x8_t tail = vld1q_u16(cdf + 2);
  const uint16_t count = cdf[11];
  cdf[11] = count + static_cast<uint16_t>(count < 32);
  const int rate = 5 + (count >> 4);
  const int16x8_t shift = vdupq_n_s16(-rate);

  // Scalar part: an entry moves toward kCdfMaxProbability when its index is
  // below |symbol| and decays toward 0 otherwise.
  if (symbol > 1) {
    cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
    cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
  } else if (symbol != 0) {
    cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
    cdf[1] -= cdf[1] >> rate;
  } else {
    cdf[0] -= cdf[0] >> rate;
    cdf[1] -= cdf[1] >> rate;
  }

  if (symbol <= 1) {
    // Every vector lane has index >= 2 > symbol, so all of them just decay
    // (logical right shift of the unsigned values).
    tail = vsubq_u16(tail, vshlq_u16(tail, shift));
  } else {
    const uint16x8_t lane_index = vcombine_u16(
        vcreate_u16(0x0005000400030002), vcreate_u16(0x0009000800070006));
    const uint16x8_t ge_symbol = vcgeq_u16(lane_index, vdupq_n_u16(symbol));
    const uint16x8_t target =
        vorrq_u16(ge_symbol, vdupq_n_u16(kCdfMaxProbability));
    const int16x8_t gap = vreinterpretq_s16_u16(vsubq_u16(target, tail));
    const uint16x8_t step = vreinterpretq_u16_s16(vshlq_s16(gap, shift));
    tail = vaddq_u16(vsubq_u16(tail, ge_symbol), step);
  }
  vst1q_u16(cdf + 2, tail);
}
267
// NEON version of UpdateCdf() for symbol_count == 13. Two overlapping 128-bit
// vectors cover cdf[0..7] and cdf[4..11]. cdf[12] is the always-zero last CDF
// entry and cdf[13] is the counter. See UpdateCdf5 for the lane arithmetic.
void UpdateCdf13(uint16_t* const cdf, const int symbol) {
  const uint16_t count = cdf[13];
  const int rate = 5 + (count >> 4);
  const uint16x8_t max_probability = vdupq_n_u16(kCdfMaxProbability);
  const uint16x8_t symbol_vec = vdupq_n_u16(symbol);
  const int16x8_t shift = vdupq_n_s16(-rate);

  // Applies the masked CDF update to |values|, whose lanes hold the CDF
  // entries with the indices in |lane_index|.
  const auto adapt = [&](const uint16x8_t values,
                         const uint16x8_t lane_index) -> uint16x8_t {
    const uint16x8_t ge_symbol = vcgeq_u16(lane_index, symbol_vec);
    const int16x8_t gap = vreinterpretq_s16_u16(
        vsubq_u16(vorrq_u16(ge_symbol, max_probability), values));
    const uint16x8_t step = vreinterpretq_u16_s16(vshlq_s16(gap, shift));
    return vaddq_u16(vsubq_u16(values, ge_symbol), step);
  };

  // Both halves are loaded before either store. cdf[4..7] is computed twice
  // from identical inputs, so the second store rewrites the same values.
  const uint16x8_t head = vld1q_u16(cdf);
  const uint16x8_t tail = vld1q_u16(cdf + 4);
  vst1q_u16(cdf, adapt(head, vcombine_u16(vcreate_u16(0x0003000200010000),
                                          vcreate_u16(0x0007000600050004))));
  vst1q_u16(cdf + 4,
            adapt(tail, vcombine_u16(vcreate_u16(0x0007000600050004),
                                     vcreate_u16(0x000b000a00090008))));

  cdf[13] = count + static_cast<uint16_t>(count < 32);
}
300
// NEON version of UpdateCdf() for symbol_count == 16. Two 128-bit vectors cover
// cdf[0..7] and cdf[8..15]. cdf[15] is the always-zero last CDF entry, which
// the decay formula keeps at 0. cdf[16] is the counter. See UpdateCdf5 for the
// lane arithmetic.
void UpdateCdf16(uint16_t* const cdf, const int symbol) {
  const uint16_t count = cdf[16];
  const int rate = 5 + (count >> 4);
  const uint16x8_t max_probability = vdupq_n_u16(kCdfMaxProbability);
  const uint16x8_t symbol_vec = vdupq_n_u16(symbol);
  const int16x8_t shift = vdupq_n_s16(-rate);

  // Applies the masked CDF update to |values|, whose lanes hold the CDF
  // entries with the indices in |lane_index|.
  const auto adapt = [&](const uint16x8_t values,
                         const uint16x8_t lane_index) -> uint16x8_t {
    const uint16x8_t ge_symbol = vcgeq_u16(lane_index, symbol_vec);
    const int16x8_t gap = vreinterpretq_s16_u16(
        vsubq_u16(vorrq_u16(ge_symbol, max_probability), values));
    const uint16x8_t step = vreinterpretq_u16_s16(vshlq_s16(gap, shift));
    return vaddq_u16(vsubq_u16(values, ge_symbol), step);
  };

  const uint16x8_t low_half = vld1q_u16(cdf);
  vst1q_u16(cdf, adapt(low_half, vcombine_u16(vcreate_u16(0x0003000200010000),
                                              vcreate_u16(0x0007000600050004))));

  const uint16x8_t high_half = vld1q_u16(cdf + 8);
  vst1q_u16(cdf + 8,
            adapt(high_half, vcombine_u16(vcreate_u16(0x000b000a00090008),
                                          vcreate_u16(0x000f000e000d000c))));

  cdf[16] = count + static_cast<uint16_t>(count < 32);
}
333
334 #else // !LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
335
336 #if LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
337
// Thin wrappers around the SSE2 load/store intrinsics that accept untyped
// pointers. The "Lo8" variants move only the low 8 bytes of the register. The
// "Unaligned16" variants move all 16 bytes with no alignment requirement.
inline __m128i LoadLo8(const void* a) {
  const auto* const src = static_cast<const __m128i*>(a);
  return _mm_loadl_epi64(src);
}

inline __m128i LoadUnaligned16(const void* a) {
  const auto* const src = static_cast<const __m128i*>(a);
  return _mm_loadu_si128(src);
}

inline void StoreLo8(void* a, const __m128i v) {
  auto* const dst = static_cast<__m128i*>(a);
  _mm_storel_epi64(dst, v);
}

inline void StoreUnaligned16(void* a, const __m128i v) {
  auto* const dst = static_cast<__m128i*>(a);
  _mm_storeu_si128(dst, v);
}
353
UpdateCdf5(uint16_t * const cdf,const int symbol)354 void UpdateCdf5(uint16_t* const cdf, const int symbol) {
355 __m128i cdf_vec = LoadLo8(cdf);
356 const uint16_t count = cdf[5];
357 const int rate = (count >> 4) + 5;
358 const __m128i cdf_max_probability =
359 _mm_shufflelo_epi16(_mm_cvtsi32_si128(kCdfMaxProbability), 0);
360 const __m128i index = _mm_set_epi32(0x0, 0x0, 0x00040003, 0x00020001);
361 const __m128i symbol_vec = _mm_shufflelo_epi16(_mm_cvtsi32_si128(symbol), 0);
362 // i >= symbol.
363 const __m128i mask = _mm_cmpgt_epi16(index, symbol_vec);
364 // i < symbol: 32768, i >= symbol: 65535.
365 const __m128i a = _mm_or_si128(mask, cdf_max_probability);
366 // i < symbol: 32768 - cdf, i >= symbol: 65535 - cdf.
367 const __m128i diff = _mm_sub_epi16(a, cdf_vec);
368 // i < symbol: cdf - 0, i >= symbol: cdf - 65535.
369 const __m128i cdf_offset = _mm_sub_epi16(cdf_vec, mask);
370 // i < symbol: (32768 - cdf) >> rate, i >= symbol: (65535 (-1) - cdf) >> rate.
371 const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
372 // i < symbol: (cdf - 0) + ((32768 - cdf) >> rate).
373 // i >= symbol: (cdf - 65535) + ((65535 - cdf) >> rate).
374 cdf_vec = _mm_add_epi16(cdf_offset, delta);
375 StoreLo8(cdf, cdf_vec);
376 cdf[5] = count + static_cast<uint16_t>(count < 32);
377 }
378
379 // This version works for |symbol_count| = 7, 8, or 9.
380 // See UpdateCdf5 for implementation details.
381 template <int symbol_count>
UpdateCdf7To9(uint16_t * const cdf,const int symbol)382 void UpdateCdf7To9(uint16_t* const cdf, const int symbol) {
383 static_assert(symbol_count >= 7 && symbol_count <= 9, "");
384 __m128i cdf_vec = LoadUnaligned16(cdf);
385 const uint16_t count = cdf[symbol_count];
386 const int rate = (count >> 4) + 5;
387 const __m128i cdf_max_probability =
388 _mm_set1_epi16(static_cast<int16_t>(kCdfMaxProbability));
389 const __m128i index =
390 _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001);
391 const __m128i symbol_vec = _mm_set1_epi16(static_cast<int16_t>(symbol));
392 const __m128i mask = _mm_cmpgt_epi16(index, symbol_vec);
393 const __m128i a = _mm_or_si128(mask, cdf_max_probability);
394 const __m128i diff = _mm_sub_epi16(a, cdf_vec);
395 const __m128i cdf_offset = _mm_sub_epi16(cdf_vec, mask);
396 const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
397 cdf_vec = _mm_add_epi16(cdf_offset, delta);
398 StoreUnaligned16(cdf, cdf_vec);
399 cdf[symbol_count] = count + static_cast<uint16_t>(count < 32);
400 }
401
UpdateCdf7(uint16_t * const cdf,const int symbol)402 void UpdateCdf7(uint16_t* const cdf, const int symbol) {
403 UpdateCdf7To9<7>(cdf, symbol);
404 }
405
UpdateCdf8(uint16_t * const cdf,const int symbol)406 void UpdateCdf8(uint16_t* const cdf, const int symbol) {
407 UpdateCdf7To9<8>(cdf, symbol);
408 }
409
UpdateCdf9(uint16_t * const cdf,const int symbol)410 void UpdateCdf9(uint16_t* const cdf, const int symbol) {
411 UpdateCdf7To9<9>(cdf, symbol);
412 }
413
414 // See UpdateCdf5 for implementation details.
UpdateCdf11(uint16_t * const cdf,const int symbol)415 void UpdateCdf11(uint16_t* const cdf, const int symbol) {
416 __m128i cdf_vec = LoadUnaligned16(cdf + 2);
417 const uint16_t count = cdf[11];
418 cdf[11] = count + static_cast<uint16_t>(count < 32);
419 const int rate = (count >> 4) + 5;
420 if (symbol > 1) {
421 cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
422 cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
423 const __m128i cdf_max_probability =
424 _mm_set1_epi16(static_cast<int16_t>(kCdfMaxProbability));
425 const __m128i index =
426 _mm_set_epi32(0x000a0009, 0x00080007, 0x00060005, 0x00040003);
427 const __m128i symbol_vec = _mm_set1_epi16(static_cast<int16_t>(symbol));
428 const __m128i mask = _mm_cmpgt_epi16(index, symbol_vec);
429 const __m128i a = _mm_or_si128(mask, cdf_max_probability);
430 const __m128i diff = _mm_sub_epi16(a, cdf_vec);
431 const __m128i cdf_offset = _mm_sub_epi16(cdf_vec, mask);
432 const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
433 cdf_vec = _mm_add_epi16(cdf_offset, delta);
434 StoreUnaligned16(cdf + 2, cdf_vec);
435 } else {
436 if (symbol != 0) {
437 cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
438 cdf[1] -= cdf[1] >> rate;
439 } else {
440 cdf[0] -= cdf[0] >> rate;
441 cdf[1] -= cdf[1] >> rate;
442 }
443 const __m128i delta = _mm_sra_epi16(cdf_vec, _mm_cvtsi32_si128(rate));
444 cdf_vec = _mm_sub_epi16(cdf_vec, delta);
445 StoreUnaligned16(cdf + 2, cdf_vec);
446 }
447 }
448
449 // See UpdateCdf5 for implementation details.
UpdateCdf13(uint16_t * const cdf,const int symbol)450 void UpdateCdf13(uint16_t* const cdf, const int symbol) {
451 __m128i cdf_vec0 = LoadLo8(cdf);
452 __m128i cdf_vec1 = LoadUnaligned16(cdf + 4);
453 const uint16_t count = cdf[13];
454 const int rate = (count >> 4) + 5;
455 const __m128i cdf_max_probability =
456 _mm_set1_epi16(static_cast<int16_t>(kCdfMaxProbability));
457 const __m128i symbol_vec = _mm_set1_epi16(static_cast<int16_t>(symbol));
458
459 const __m128i index = _mm_set_epi32(0x0, 0x0, 0x00040003, 0x00020001);
460 const __m128i mask = _mm_cmpgt_epi16(index, symbol_vec);
461 const __m128i a = _mm_or_si128(mask, cdf_max_probability);
462 const __m128i diff = _mm_sub_epi16(a, cdf_vec0);
463 const __m128i cdf_offset = _mm_sub_epi16(cdf_vec0, mask);
464 const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
465 cdf_vec0 = _mm_add_epi16(cdf_offset, delta);
466 StoreLo8(cdf, cdf_vec0);
467
468 const __m128i index1 =
469 _mm_set_epi32(0x000c000b, 0x000a0009, 0x00080007, 0x00060005);
470 const __m128i mask1 = _mm_cmpgt_epi16(index1, symbol_vec);
471 const __m128i a1 = _mm_or_si128(mask1, cdf_max_probability);
472 const __m128i diff1 = _mm_sub_epi16(a1, cdf_vec1);
473 const __m128i cdf_offset1 = _mm_sub_epi16(cdf_vec1, mask1);
474 const __m128i delta1 = _mm_sra_epi16(diff1, _mm_cvtsi32_si128(rate));
475 cdf_vec1 = _mm_add_epi16(cdf_offset1, delta1);
476 StoreUnaligned16(cdf + 4, cdf_vec1);
477
478 cdf[13] = count + static_cast<uint16_t>(count < 32);
479 }
480
UpdateCdf16(uint16_t * const cdf,const int symbol)481 void UpdateCdf16(uint16_t* const cdf, const int symbol) {
482 __m128i cdf_vec0 = LoadUnaligned16(cdf);
483 const uint16_t count = cdf[16];
484 const int rate = (count >> 4) + 5;
485 const __m128i cdf_max_probability =
486 _mm_set1_epi16(static_cast<int16_t>(kCdfMaxProbability));
487 const __m128i symbol_vec = _mm_set1_epi16(static_cast<int16_t>(symbol));
488
489 const __m128i index =
490 _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001);
491 const __m128i mask = _mm_cmpgt_epi16(index, symbol_vec);
492 const __m128i a = _mm_or_si128(mask, cdf_max_probability);
493 const __m128i diff = _mm_sub_epi16(a, cdf_vec0);
494 const __m128i cdf_offset = _mm_sub_epi16(cdf_vec0, mask);
495 const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
496 cdf_vec0 = _mm_add_epi16(cdf_offset, delta);
497 StoreUnaligned16(cdf, cdf_vec0);
498
499 __m128i cdf_vec1 = LoadUnaligned16(cdf + 8);
500 const __m128i index1 =
501 _mm_set_epi32(0x0010000f, 0x000e000d, 0x000c000b, 0x000a0009);
502 const __m128i mask1 = _mm_cmpgt_epi16(index1, symbol_vec);
503 const __m128i a1 = _mm_or_si128(mask1, cdf_max_probability);
504 const __m128i diff1 = _mm_sub_epi16(a1, cdf_vec1);
505 const __m128i cdf_offset1 = _mm_sub_epi16(cdf_vec1, mask1);
506 const __m128i delta1 = _mm_sra_epi16(diff1, _mm_cvtsi32_si128(rate));
507 cdf_vec1 = _mm_add_epi16(cdf_offset1, delta1);
508 StoreUnaligned16(cdf + 8, cdf_vec1);
509
510 cdf[16] = count + static_cast<uint16_t>(count < 32);
511 }
512
513 #else // !LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
514
UpdateCdf5(uint16_t * const cdf,const int symbol)515 void UpdateCdf5(uint16_t* const cdf, const int symbol) {
516 UpdateCdf(cdf, 5, symbol);
517 }
518
UpdateCdf7(uint16_t * const cdf,const int symbol)519 void UpdateCdf7(uint16_t* const cdf, const int symbol) {
520 UpdateCdf(cdf, 7, symbol);
521 }
522
UpdateCdf8(uint16_t * const cdf,const int symbol)523 void UpdateCdf8(uint16_t* const cdf, const int symbol) {
524 UpdateCdf(cdf, 8, symbol);
525 }
526
UpdateCdf9(uint16_t * const cdf,const int symbol)527 void UpdateCdf9(uint16_t* const cdf, const int symbol) {
528 UpdateCdf(cdf, 9, symbol);
529 }
530
UpdateCdf11(uint16_t * const cdf,const int symbol)531 void UpdateCdf11(uint16_t* const cdf, const int symbol) {
532 UpdateCdf(cdf, 11, symbol);
533 }
534
UpdateCdf13(uint16_t * const cdf,const int symbol)535 void UpdateCdf13(uint16_t* const cdf, const int symbol) {
536 UpdateCdf(cdf, 13, symbol);
537 }
538
UpdateCdf16(uint16_t * const cdf,const int symbol)539 void UpdateCdf16(uint16_t* const cdf, const int symbol) {
540 UpdateCdf(cdf, 16, symbol);
541 }
542
543 #endif // LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
544 #endif // LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
545
// Returns |x| with its bytes in big-endian order: byte-swapped on a
// little-endian host, unchanged otherwise. Only 32- and 64-bit windows are
// supported.
inline DaalaBitReader::WindowSize HostToBigEndian(
    const DaalaBitReader::WindowSize x) {
  static_assert(sizeof(x) == 4 || sizeof(x) == 8, "");
#if defined(__GNUC__)
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
  // Not a little-endian host: |x| is already in the desired order.
  return x;
#else
  return (sizeof(x) == 8) ? __builtin_bswap64(x) : __builtin_bswap32(x);
#endif
#elif defined(_WIN32)
  // Windows targets are assumed to be little endian, so always swap.
  return static_cast<DaalaBitReader::WindowSize>(
      (sizeof(x) == 8) ? _byteswap_uint64(static_cast<unsigned __int64>(x))
                       : _byteswap_ulong(static_cast<unsigned long>(x)));
#else
#error Unknown compiler!
#endif  // defined(__GNUC__)
}
564
565 } // namespace
566
#if !LIBGAV1_CXX17
// Before C++17, a static constexpr data member needs an out-of-line
// namespace-scope definition if it is odr-used. C++17 makes such members
// implicitly inline, so this definition is only emitted for older language
// modes.
constexpr int DaalaBitReader::kWindowSize;  // static.
#endif
570
DaalaBitReader(const uint8_t * data,size_t size,bool allow_update_cdf)571 DaalaBitReader::DaalaBitReader(const uint8_t* data, size_t size,
572 bool allow_update_cdf)
573 : data_(data),
574 data_end_(data + size),
575 data_memcpy_end_((size >= sizeof(WindowSize))
576 ? data + size - sizeof(WindowSize) + 1
577 : data),
578 allow_update_cdf_(allow_update_cdf),
579 values_in_range_(kCdfMaxProbability) {
580 if (data_ < data_memcpy_end_) {
581 // This is a simplified version of PopulateBits() which loads 8 extra bits
582 // and skips the unnecessary shifts of value and window_diff_.
583 WindowSize value;
584 memcpy(&value, data_, sizeof(value));
585 data_ += sizeof(value);
586 window_diff_ = HostToBigEndian(value) ^ -1;
587 // Note the initial value of bits_ is larger than kMaxCachedBits as it's
588 // used to restore the most significant 0 bit that would be present after
589 // PopulateBits() when we extract the first symbol value.
590 // As shown in Section 8.2.2 Initialization process for symbol decoder,
591 // which uses a fixed offset to read the symbol values, the most
592 // significant bit is always 0:
593 // The variable numBits is set equal to Min( sz * 8, 15).
594 // The variable buf is read using the f(numBits) parsing process.
595 // The variable paddedBuf is set equal to ( buf << (15 - numBits) ).
596 // The variable SymbolValue is set to ((1 << 15) - 1) ^ paddedBuf.
597 bits_ = kWindowSize - 15;
598 return;
599 }
600 window_diff_ = 0;
601 bits_ = -15;
602 PopulateBits();
603 }
604
605 // This is similar to the ReadSymbol() implementation but it is optimized based
606 // on the following facts:
607 // * The probability is fixed at half. So some multiplications can be replaced
608 // with bit operations.
609 // * Symbol count is fixed at 2.
ReadBit()610 int DaalaBitReader::ReadBit() {
611 const uint32_t curr =
612 ((values_in_range_ & kReadBitMask) >> 1) + kMinimumProbabilityPerSymbol;
613 const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
614 int bit = 1;
615 if (symbol_value >= curr) {
616 values_in_range_ -= curr;
617 window_diff_ -= static_cast<WindowSize>(curr) << bits_;
618 bit = 0;
619 } else {
620 values_in_range_ = curr;
621 }
622 NormalizeRange();
623 return bit;
624 }
625
ReadLiteral(int num_bits)626 int64_t DaalaBitReader::ReadLiteral(int num_bits) {
627 assert(num_bits <= 32);
628 assert(num_bits > 0);
629 uint32_t literal = 0;
630 int bit = num_bits - 1;
631 do {
632 // ARM can combine a shift operation with a constant number of bits with
633 // some other operations, such as the OR operation.
634 // Here is an ARM disassembly example:
635 // orr w1, w0, w1, lsl #1
636 // which left shifts register w1 by 1 bit and OR the shift result with
637 // register w0.
638 // The next 2 lines are equivalent to:
639 // literal |= static_cast<uint32_t>(ReadBit()) << bit;
640 literal <<= 1;
641 literal |= static_cast<uint32_t>(ReadBit());
642 } while (--bit >= 0);
643 return literal;
644 }
645
ReadSymbol(uint16_t * const cdf,int symbol_count)646 int DaalaBitReader::ReadSymbol(uint16_t* const cdf, int symbol_count) {
647 const int symbol = ReadSymbolImpl(cdf, symbol_count);
648 if (allow_update_cdf_) {
649 UpdateCdf(cdf, symbol_count, symbol);
650 }
651 return symbol;
652 }
653
ReadSymbol(uint16_t * cdf)654 bool DaalaBitReader::ReadSymbol(uint16_t* cdf) {
655 assert(cdf[1] == 0);
656 const bool symbol = ReadSymbolImpl(cdf[0]) != 0;
657 if (allow_update_cdf_) {
658 const uint16_t count = cdf[2];
659 // rate is computed in the spec as:
660 // 3 + ( cdf[N] > 15 ) + ( cdf[N] > 31 ) + Min(FloorLog2(N), 2)
661 // In this case N is 2 and cdf[N] is |count|. So the equation becomes:
662 // 4 + (count > 15) + (count > 31)
663 // Note that the largest value for count is 32 (it is not incremented beyond
664 // 32). So using that information:
665 // count >> 4 is 0 for count from 0 to 15.
666 // count >> 4 is 1 for count from 16 to 31.
667 // count >> 4 is 2 for count == 32.
668 // Now, the equation becomes:
669 // 4 + (count >> 4).
670 // Since (count >> 4) can only be 0 or 1 or 2, the addition can be replaced
671 // with bitwise or. So the final equation is:
672 // 4 | (count >> 4).
673 const int rate = 4 | (count >> 4);
674 if (symbol) {
675 cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
676 } else {
677 cdf[0] -= cdf[0] >> rate;
678 }
679 cdf[2] += static_cast<uint16_t>(count < 32);
680 }
681 return symbol;
682 }
683
ReadSymbolWithoutCdfUpdate(uint16_t cdf)684 bool DaalaBitReader::ReadSymbolWithoutCdfUpdate(uint16_t cdf) {
685 return ReadSymbolImpl(cdf) != 0;
686 }
687
688 template <int symbol_count>
ReadSymbol(uint16_t * const cdf)689 int DaalaBitReader::ReadSymbol(uint16_t* const cdf) {
690 static_assert(symbol_count >= 3 && symbol_count <= 16, "");
691 if (symbol_count == 3 || symbol_count == 4) {
692 return ReadSymbol3Or4(cdf, symbol_count);
693 }
694 int symbol;
695 if (symbol_count == 8) {
696 symbol = ReadSymbolImpl8(cdf);
697 } else if (symbol_count <= 13) {
698 symbol = ReadSymbolImpl(cdf, symbol_count);
699 } else {
700 symbol = ReadSymbolImplBinarySearch(cdf, symbol_count);
701 }
702 if (allow_update_cdf_) {
703 if (symbol_count == 5) {
704 UpdateCdf5(cdf, symbol);
705 } else if (symbol_count == 7) {
706 UpdateCdf7(cdf, symbol);
707 } else if (symbol_count == 8) {
708 UpdateCdf8(cdf, symbol);
709 } else if (symbol_count == 9) {
710 UpdateCdf9(cdf, symbol);
711 } else if (symbol_count == 11) {
712 UpdateCdf11(cdf, symbol);
713 } else if (symbol_count == 13) {
714 UpdateCdf13(cdf, symbol);
715 } else if (symbol_count == 16) {
716 UpdateCdf16(cdf, symbol);
717 } else {
718 UpdateCdf(cdf, symbol_count, symbol);
719 }
720 }
721 return symbol;
722 }
723
ReadSymbolImpl(const uint16_t * const cdf,int symbol_count)724 int DaalaBitReader::ReadSymbolImpl(const uint16_t* const cdf,
725 int symbol_count) {
726 assert(cdf[symbol_count - 1] == 0);
727 --symbol_count;
728 uint32_t curr = values_in_range_;
729 int symbol = -1;
730 uint32_t prev;
731 const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
732 uint32_t delta = kMinimumProbabilityPerSymbol * symbol_count;
733 // Search through the |cdf| array to determine where the scaled cdf value and
734 // |symbol_value| cross over.
735 do {
736 prev = curr;
737 curr = (((values_in_range_ >> 8) * (cdf[++symbol] >> kCdfPrecision)) >> 1) +
738 delta;
739 delta -= kMinimumProbabilityPerSymbol;
740 } while (symbol_value < curr);
741 values_in_range_ = prev - curr;
742 window_diff_ -= static_cast<WindowSize>(curr) << bits_;
743 NormalizeRange();
744 return symbol;
745 }
746
ReadSymbolImplBinarySearch(const uint16_t * const cdf,int symbol_count)747 int DaalaBitReader::ReadSymbolImplBinarySearch(const uint16_t* const cdf,
748 int symbol_count) {
749 assert(cdf[symbol_count - 1] == 0);
750 assert(symbol_count > 1 && symbol_count <= 16);
751 --symbol_count;
752 const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
753 // Search through the |cdf| array to determine where the scaled cdf value and
754 // |symbol_value| cross over. Since the CDFs are sorted, we can use binary
755 // search to do this. Let |symbol| be the index of the first |cdf| array
756 // entry whose scaled cdf value is less than or equal to |symbol_value|. The
757 // binary search maintains the invariant:
758 // low <= symbol <= high + 1
759 // and terminates when low == high + 1.
760 int low = 0;
761 int high = symbol_count - 1;
762 // The binary search maintains the invariants that |prev| is the scaled cdf
763 // value for low - 1 and |curr| is the scaled cdf value for high + 1. (By
764 // convention, the scaled cdf value for -1 is values_in_range_.) When the
765 // binary search terminates, |prev| is the scaled cdf value for symbol - 1
766 // and |curr| is the scaled cdf value for |symbol|.
767 uint32_t prev = values_in_range_;
768 uint32_t curr = 0;
769 const uint32_t values_in_range_shifted = values_in_range_ >> 8;
770 do {
771 const int mid = DivideBy2(low + high);
772 const uint32_t scaled_cdf =
773 ScaleCdf(values_in_range_shifted, cdf, mid, symbol_count);
774 if (symbol_value < scaled_cdf) {
775 low = mid + 1;
776 prev = scaled_cdf;
777 } else {
778 high = mid - 1;
779 curr = scaled_cdf;
780 }
781 } while (low <= high);
782 assert(low == high + 1);
783 // At this point, |low| is the symbol that has been decoded.
784 values_in_range_ = prev - curr;
785 window_diff_ -= static_cast<WindowSize>(curr) << bits_;
786 NormalizeRange();
787 return low;
788 }
789
ReadSymbolImpl(uint16_t cdf)790 int DaalaBitReader::ReadSymbolImpl(uint16_t cdf) {
791 const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
792 const uint32_t curr =
793 (((values_in_range_ >> 8) * (cdf >> kCdfPrecision)) >> 1) +
794 kMinimumProbabilityPerSymbol;
795 const int symbol = static_cast<int>(symbol_value < curr);
796 if (symbol == 1) {
797 values_in_range_ = curr;
798 } else {
799 values_in_range_ -= curr;
800 window_diff_ -= static_cast<WindowSize>(curr) << bits_;
801 }
802 NormalizeRange();
803 return symbol;
804 }
805
// Equivalent to ReadSymbol(cdf, [3,4]), with the ReadSymbolImpl and UpdateCdf
// calls inlined.
//
// Decodes one symbol from the inverse CDF |cdf|. The CDF has 3 or 4 entries,
// sorted in decreasing order, with the last entry 0. When |allow_update_cdf_|
// is true, the CDF is also adapted in place.
//
// cdf[symbol_count] is an adaptation counter. It is incremented while it is
// below 32, and (count >> 4) increases the adaptation shift |rate|, so updates
// get smaller as the counter grows. Entries with an index below the decoded
// symbol move toward kCdfMaxProbability; all other entries decay toward 0. The
// final entry cdf[symbol_count - 1] stays 0.
int DaalaBitReader::ReadSymbol3Or4(uint16_t* const cdf,
                                   const int symbol_count) {
  assert(cdf[symbol_count - 1] == 0);
  uint32_t curr = values_in_range_;
  uint32_t prev;
  const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
  uint32_t delta = kMinimumProbabilityPerSymbol * (symbol_count - 1);
  const uint32_t values_in_range_shifted = values_in_range_ >> 8;

  // Search through the |cdf| array to determine where the scaled cdf value and
  // |symbol_value| cross over. If allow_update_cdf_ is true, update the |cdf|
  // array.
  //
  // The original code is:
  //
  //  int symbol = -1;
  //  do {
  //    prev = curr;
  //    curr =
  //        ((values_in_range_shifted * (cdf[++symbol] >> kCdfPrecision)) >> 1)
  //        + delta;
  //    delta -= kMinimumProbabilityPerSymbol;
  //  } while (symbol_value < curr);
  //  if (allow_update_cdf_) {
  //    UpdateCdf(cdf, [3,4], symbol);
  //  }
  //
  // The do-while loop is unrolled with three or four iterations, and the
  // UpdateCdf call is inlined and merged into the iterations.
  int symbol = 0;
  // Iteration 0.
  prev = curr;
  curr =
      ((values_in_range_shifted * (cdf[symbol] >> kCdfPrecision)) >> 1) + delta;
  if (symbol_value >= curr) {
    // symbol == 0.
    if (allow_update_cdf_) {
      // Inlined version of UpdateCdf(cdf, [3,4], /*symbol=*/0).
      const uint16_t count = cdf[symbol_count];
      cdf[symbol_count] += static_cast<uint16_t>(count < 32);
      // Larger shift = smaller update: the shift grows with |count| and is one
      // larger for 4-symbol CDFs.
      const int rate = (count >> 4) + 4 + static_cast<int>(symbol_count == 4);
      if (symbol_count == 4) {
        // Symbol 0: every entry decays toward 0. Lane 3 holds cdf[3] == 0 and
        // is unchanged by the subtraction.
#if LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
        // 1. On Motorola Moto G5 Plus (running 32-bit Android 8.1.0), the ARM
        // NEON code is slower. Consider using the C version if __arm__ is
        // defined.
        // 2. The ARM NEON code (compiled for arm64) is slightly slower on
        // Samsung Galaxy S8+ (SM-G955FD).
        // vshl_u16 with a negative count performs a logical right shift.
        uint16x4_t cdf_vec = vld1_u16(cdf);
        const int16x4_t negative_rate = vdup_n_s16(-rate);
        const uint16x4_t delta = vshl_u16(cdf_vec, negative_rate);
        cdf_vec = vsub_u16(cdf_vec, delta);
        vst1_u16(cdf, cdf_vec);
#elif LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
        // NOTE(review): _mm_sra_epi16 is an arithmetic shift. It matches the
        // scalar logical shift only while every cdf[i] < 0x8000, which is
        // presumably guaranteed for these inverse CDFs. Confirm.
        // LoadLo8/StoreLo8 presumably move the low 8 bytes (cdf[0..3]).
        __m128i cdf_vec = LoadLo8(cdf);
        const __m128i delta = _mm_sra_epi16(cdf_vec, _mm_cvtsi32_si128(rate));
        cdf_vec = _mm_sub_epi16(cdf_vec, delta);
        StoreLo8(cdf, cdf_vec);
#else  // !LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
        cdf[0] -= cdf[0] >> rate;
        cdf[1] -= cdf[1] >> rate;
        cdf[2] -= cdf[2] >> rate;
#endif
      } else {  // symbol_count == 3.
        cdf[0] -= cdf[0] >> rate;
        cdf[1] -= cdf[1] >> rate;
      }
    }
    goto found;
  }
  ++symbol;
  delta -= kMinimumProbabilityPerSymbol;
  // Iteration 1.
  prev = curr;
  curr =
      ((values_in_range_shifted * (cdf[symbol] >> kCdfPrecision)) >> 1) + delta;
  if (symbol_value >= curr) {
    // symbol == 1.
    if (allow_update_cdf_) {
      // Inlined version of UpdateCdf(cdf, [3,4], /*symbol=*/1).
      const uint16_t count = cdf[symbol_count];
      cdf[symbol_count] += static_cast<uint16_t>(count < 32);
      const int rate = (count >> 4) + 4 + static_cast<int>(symbol_count == 4);
      // Entry 0 (below the symbol) grows; the remaining non-final entries
      // decay.
      cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
      cdf[1] -= cdf[1] >> rate;
      if (symbol_count == 4) cdf[2] -= cdf[2] >> rate;
    }
    goto found;
  }
  ++symbol;
  if (symbol_count == 4) {
    delta -= kMinimumProbabilityPerSymbol;
    // Iteration 2.
    prev = curr;
    curr = ((values_in_range_shifted * (cdf[symbol] >> kCdfPrecision)) >> 1) +
           delta;
    if (symbol_value >= curr) {
      // symbol == 2.
      if (allow_update_cdf_) {
        // Inlined version of UpdateCdf(cdf, 4, /*symbol=*/2).
        const uint16_t count = cdf[4];
        cdf[4] += static_cast<uint16_t>(count < 32);
        const int rate = (count >> 4) + 5;
        cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
        cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
        cdf[2] -= cdf[2] >> rate;
      }
      goto found;
    }
    ++symbol;
  }
  // |delta| is 0 for the last iteration.
  // Iteration 2 (symbol_count == 3) or 3 (symbol_count == 4).
  prev = curr;
  // Since cdf[symbol_count - 1] is 0 and |delta| is 0, |curr| is also 0.
  curr = 0;
  // symbol == [2,3].
  if (allow_update_cdf_) {
    // Inlined version of UpdateCdf(cdf, [3,4], /*symbol=*/[2,3]).
    const uint16_t count = cdf[symbol_count];
    cdf[symbol_count] += static_cast<uint16_t>(count < 32);
    // Same value as (count >> 4) + 4 in the other iterations: while
    // count <= 32, (count >> 4) is at most 2 and never overlaps bit 2.
    const int rate = (4 | (count >> 4)) + static_cast<int>(symbol_count == 4);
    if (symbol_count == 4) {
      // Last symbol: every entry below it moves toward kCdfMaxProbability.
#if LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
      // On Motorola Moto G5 Plus (running 32-bit Android 8.1.0), the ARM NEON
      // code is a tiny bit slower. Consider using the C version if __arm__ is
      // defined.
      // NOTE(review): the difference is reinterpreted as signed 16-bit so that
      // vshl_s16 with a negative count does an arithmetic right shift. This
      // assumes kCdfMaxProbability - cdf[i] fits in int16 for i < 3. Confirm.
      uint16x4_t cdf_vec = vld1_u16(cdf);
      const uint16x4_t cdf_max_probability = vdup_n_u16(kCdfMaxProbability);
      const int16x4_t diff =
          vreinterpret_s16_u16(vsub_u16(cdf_max_probability, cdf_vec));
      const int16x4_t negative_rate = vdup_n_s16(-rate);
      const uint16x4_t delta =
          vreinterpret_u16_s16(vshl_s16(diff, negative_rate));
      cdf_vec = vadd_u16(cdf_vec, delta);
      vst1_u16(cdf, cdf_vec);
      // The 4-lane update also modified cdf[3], which must stay 0.
      cdf[3] = 0;
#elif LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
      // Broadcast kCdfMaxProbability into the low four 16-bit lanes.
      __m128i cdf_vec = LoadLo8(cdf);
      const __m128i cdf_max_probability =
          _mm_shufflelo_epi16(_mm_cvtsi32_si128(kCdfMaxProbability), 0);
      const __m128i diff = _mm_sub_epi16(cdf_max_probability, cdf_vec);
      const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
      cdf_vec = _mm_add_epi16(cdf_vec, delta);
      StoreLo8(cdf, cdf_vec);
      // The 4-lane update also modified cdf[3], which must stay 0.
      cdf[3] = 0;
#else  // !LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
      cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
      cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
      cdf[2] += (kCdfMaxProbability - cdf[2]) >> rate;
#endif
    } else {  // symbol_count == 3.
      cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
      cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
    }
  }
found:
  // End of unrolled do-while loop.

  // [curr, prev) is the decoded symbol's sub-interval of the scaled range.
  values_in_range_ = prev - curr;
  window_diff_ -= static_cast<WindowSize>(curr) << bits_;
  NormalizeRange();
  return symbol;
}
972
// Decodes one symbol from the 8-entry inverse CDF |cdf| without adapting it.
// |cdf| is sorted in decreasing order and cdf[7] == 0. This is an unrolled form
// of the do-while loop quoted below and must give the same result.
int DaalaBitReader::ReadSymbolImpl8(const uint16_t* const cdf) {
  assert(cdf[7] == 0);
  uint32_t curr = values_in_range_;
  uint32_t prev;
  const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
  uint32_t delta = kMinimumProbabilityPerSymbol * 7;
  // Search through the |cdf| array to determine where the scaled cdf value and
  // |symbol_value| cross over.
  //
  // The original code is:
  //
  //  int symbol = -1;
  //  do {
  //    prev = curr;
  //    curr =
  //        (((values_in_range_ >> 8) * (cdf[++symbol] >> kCdfPrecision)) >> 1)
  //        + delta;
  //    delta -= kMinimumProbabilityPerSymbol;
  //  } while (symbol_value < curr);
  //
  // The do-while loop is unrolled with eight iterations.
  int symbol = 0;

  // One unrolled step. It scales cdf[symbol], jumps to |found| if
  // |symbol_value| is at or above the scaled value, and otherwise advances to
  // the next symbol with a smaller |delta|.
#define READ_SYMBOL_ITERATION                                                \
  prev = curr;                                                               \
  curr = (((values_in_range_ >> 8) * (cdf[symbol] >> kCdfPrecision)) >> 1) + \
         delta;                                                              \
  if (symbol_value >= curr) goto found;                                      \
  ++symbol;                                                                  \
  delta -= kMinimumProbabilityPerSymbol

  READ_SYMBOL_ITERATION;  // Iteration 0.
  READ_SYMBOL_ITERATION;  // Iteration 1.
  READ_SYMBOL_ITERATION;  // Iteration 2.
  READ_SYMBOL_ITERATION;  // Iteration 3.
  READ_SYMBOL_ITERATION;  // Iteration 4.
  READ_SYMBOL_ITERATION;  // Iteration 5.

  // The last two iterations can be simplified, so they don't use the
  // READ_SYMBOL_ITERATION macro.
#undef READ_SYMBOL_ITERATION

  // Iteration 6.
  prev = curr;
  curr =
      (((values_in_range_ >> 8) * (cdf[symbol] >> kCdfPrecision)) >> 1) + delta;
  if (symbol_value >= curr) goto found;  // symbol == 6.
  ++symbol;
  // |delta| is 0 for the last iteration.
  // Iteration 7.
  prev = curr;
  // Since cdf[7] is 0 and |delta| is 0, |curr| is also 0.
  curr = 0;
  // symbol == 7.
found:
  // End of unrolled do-while loop.

  // [curr, prev) is the decoded symbol's sub-interval of the scaled range.
  values_in_range_ = prev - curr;
  window_diff_ -= static_cast<WindowSize>(curr) << bits_;
  NormalizeRange();
  return symbol;
}
1035
// Refills |window_diff_| with input bytes from |data_|. Each byte is
// bit-inverted (^ -1) before being shifted in, and |bits_| grows by 8 per byte.
// NormalizeRange() calls this once |bits_| has gone negative. On the byte-wise
// path, |bits_| ends at most kMaxCachedBits (asserted below). When the input is
// exhausted, 1 bits are shifted in instead, which is equivalent to reading 0
// data bits.
void DaalaBitReader::PopulateBits() {
  constexpr int kMaxCachedBits = kWindowSize - 16;
#if defined(__aarch64__)
  // Fast path: read eight bytes and add the first six bytes to window_diff_.
  // This fast path makes the following assumptions.
  // 1. We assume that unaligned load of uint64_t is fast.
  // 2. When there are enough bytes in data_, the for loop below reads 6 or 7
  //    bytes depending on the value of bits_. This fast path always reads 6
  //    bytes, which results in more calls to PopulateBits(). We assume that
  //    making more calls to a faster PopulateBits() is overall a win.
  // NOTE: Although this fast path could also be used on x86_64, it hurts
  // performance (measured on Lenovo ThinkStation P920 running Linux). (The
  // reason is still unknown.) Therefore this fast path is only used on arm64.
  // NOTE(review): presumably data_memcpy_end_ is chosen so that 8 bytes are
  // always readable at |data_| when data_ < data_memcpy_end_. Confirm where it
  // is initialized.
  static_assert(kWindowSize == 64, "");
  if (data_ < data_memcpy_end_) {
    uint64_t value;
    // arm64 supports unaligned loads, so this memcpy call is compiled to a
    // single ldr instruction.
    memcpy(&value, data_, sizeof(value));
    // Only kMaxCachedBits / 8 (6) of the 8 loaded bytes are consumed.
    data_ += kMaxCachedBits >> 3;
    // Put the first input byte in the most significant position, then invert.
    value = HostToBigEndian(value) ^ -1;
    // Keep only the top kMaxCachedBits bits, i.e. the first six bytes.
    value >>= kWindowSize - kMaxCachedBits;
    window_diff_ = value | (window_diff_ << kMaxCachedBits);
    bits_ += kMaxCachedBits;
    return;
  }
#endif

  const uint8_t* data = data_;
  int bits = bits_;
  WindowSize window_diff = window_diff_;

  // Bit budget for the loop: each byte costs 8 and the loop runs while the
  // budget is non-negative. The starting value keeps the final |bits| within
  // kMaxCachedBits, as checked by the assert after the loop.
  int count = kWindowSize - 9 - (bits + 15);
  // The fast path above, if compiled, would cause clang 8.0.7 to vectorize
  // this loop. Since -15 <= bits_ <= -1, this loop has at most 6 or 7
  // iterations when WindowSize is 64 bits. So it is not profitable to
  // vectorize this loop. Note that clang 8.0.7 does not vectorize this loop if
  // the fast path above is not compiled.

#ifdef __clang__
#pragma clang loop vectorize(disable) interleave(disable)
#endif
  for (; count >= 0 && data < data_end_; count -= 8) {
    const uint8_t value = *data++ ^ -1;
    window_diff = static_cast<WindowSize>(value) | (window_diff << 8);
    bits += 8;
  }
  assert(bits <= kMaxCachedBits);
  if (data == data_end_) {
    // Shift in some 1s. This is equivalent to providing fake 0 data bits.
    // ((x + 1) << n) - 1 shifts x left by n and fills the vacated low bits
    // with 1s.
    window_diff = ((window_diff + 1) << (kMaxCachedBits - bits)) - 1;
    bits = kMaxCachedBits;
  }

  data_ = data;
  bits_ = bits;
  window_diff_ = window_diff;
}
1094
NormalizeRange()1095 void DaalaBitReader::NormalizeRange() {
1096 const int bits_used = 15 ^ FloorLog2(values_in_range_);
1097 bits_ -= bits_used;
1098 values_in_range_ <<= bits_used;
1099 if (bits_ < 0) PopulateBits();
1100 }
1101
// Explicit instantiations.
// These emit DaalaBitReader::ReadSymbol<N> in this translation unit for each
// symbol count listed below. Counts that are not listed (e.g. 15) are not
// instantiated by this file.
template int DaalaBitReader::ReadSymbol<3>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<4>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<5>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<6>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<7>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<8>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<9>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<10>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<11>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<12>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<13>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<14>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<16>(uint16_t* cdf);
1116
1117 } // namespace libgav1
1118