• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <algorithm>
16 #include <array>
17 #include <cassert>
18 #include <cstdint>
19 #include <cstdlib>
20 #include <cstring>
21 #include <memory>
22 #include <vector>
23 
24 #include "src/buffer_pool.h"
25 #include "src/dsp/constants.h"
26 #include "src/motion_vector.h"
27 #include "src/obu_parser.h"
28 #include "src/prediction_mask.h"
29 #include "src/symbol_decoder_context.h"
30 #include "src/tile.h"
31 #include "src/utils/array_2d.h"
32 #include "src/utils/bit_mask_set.h"
33 #include "src/utils/block_parameters_holder.h"
34 #include "src/utils/common.h"
35 #include "src/utils/constants.h"
36 #include "src/utils/entropy_decoder.h"
37 #include "src/utils/logging.h"
38 #include "src/utils/segmentation.h"
39 #include "src/utils/segmentation_map.h"
40 #include "src/utils/types.h"
41 
42 namespace libgav1 {
43 namespace {
44 
45 constexpr int kDeltaQSmall = 3;
46 constexpr int kDeltaLfSmall = 3;
47 
48 constexpr uint8_t kIntraYModeContext[kIntraPredictionModesY] = {
49     0, 1, 2, 3, 4, 4, 4, 4, 3, 0, 1, 2, 0};
50 
51 constexpr uint8_t kSizeGroup[kMaxBlockSizes] = {
52     0, 0, 0, 0, 1, 1, 1, 0, 1, 2, 2, 2, 1, 2, 3, 3, 2, 3, 3, 3, 3, 3};
53 
54 constexpr int kCompoundModeNewMvContexts = 5;
55 constexpr uint8_t kCompoundModeContextMap[3][kCompoundModeNewMvContexts] = {
56     {0, 1, 1, 1, 1}, {1, 2, 3, 4, 4}, {4, 4, 5, 6, 7}};
57 
58 enum CflSign : uint8_t {
59   kCflSignZero = 0,
60   kCflSignNegative = 1,
61   kCflSignPositive = 2
62 };
63 
64 // For each possible value of the combined signs (which is read from the
65 // bitstream), this array stores the following: sign_u, sign_v, alpha_u_context,
66 // alpha_v_context. Only positive entries are used. Entry at index i is computed
67 // as follows:
68 // sign_u = i / 3
69 // sign_v = i % 3
70 // alpha_u_context = i - 2
71 // alpha_v_context = (sign_v - 1) * 3 + sign_u
72 constexpr int8_t kCflAlphaLookup[kCflAlphaSignsSymbolCount][4] = {
73     {0, 1, -2, 0}, {0, 2, -1, 3}, {1, 0, 0, -2}, {1, 1, 1, 1},
74     {1, 2, 2, 4},  {2, 0, 3, -1}, {2, 1, 4, 2},  {2, 2, 5, 5},
75 };
76 
77 constexpr BitMaskSet kPredictionModeHasNearMvMask(kPredictionModeNearMv,
78                                                   kPredictionModeNearNearMv,
79                                                   kPredictionModeNearNewMv,
80                                                   kPredictionModeNewNearMv);
81 
82 constexpr BitMaskSet kIsInterIntraModeAllowedMask(kBlock8x8, kBlock8x16,
83                                                   kBlock16x8, kBlock16x16,
84                                                   kBlock16x32, kBlock32x16,
85                                                   kBlock32x32);
86 
IsBackwardReference(ReferenceFrameType type)87 bool IsBackwardReference(ReferenceFrameType type) {
88   return type >= kReferenceFrameBackward && type <= kReferenceFrameAlternate;
89 }
90 
IsSameDirectionReferencePair(ReferenceFrameType type1,ReferenceFrameType type2)91 bool IsSameDirectionReferencePair(ReferenceFrameType type1,
92                                   ReferenceFrameType type2) {
93   return (type1 >= kReferenceFrameBackward) ==
94          (type2 >= kReferenceFrameBackward);
95 }
96 
97 // This is called neg_deinterleave() in the spec.
DecodeSegmentId(int diff,int reference,int max)98 int DecodeSegmentId(int diff, int reference, int max) {
99   if (reference == 0) return diff;
100   if (reference >= max - 1) return max - diff - 1;
101   const int value = ((diff & 1) != 0) ? reference + ((diff + 1) >> 1)
102                                       : reference - (diff >> 1);
103   const int reference2 = (reference << 1);
104   if (reference2 < max) {
105     return (diff <= reference2) ? value : diff;
106   }
107   return (diff <= ((max - reference - 1) << 1)) ? value : max - (diff + 1);
108 }
109 
110 // This is called DrlCtxStack in section 7.10.2.14 of the spec.
111 // In the spec, the weights of all the nearest mvs are incremented by a bonus
112 // weight which is larger than any natural weight, and the weights of the mvs
113 // are compared with this bonus weight to determine their contexts. We replace
114 // this procedure by introducing |nearest_mv_count| in PredictionParameters,
115 // which records the count of the nearest mvs. Since all the nearest mvs are in
116 // the beginning of the mv stack, the |index| of a mv in the mv stack can be
117 // compared with |nearest_mv_count| to get that mv's context.
GetRefMvIndexContext(int nearest_mv_count,int index)118 int GetRefMvIndexContext(int nearest_mv_count, int index) {
119   if (index + 1 < nearest_mv_count) {
120     return 0;
121   }
122   if (index + 1 == nearest_mv_count) {
123     return 1;
124   }
125   return 2;
126 }
127 
128 // Returns true if both the width and height of the block is less than 64.
IsBlockDimensionLessThan64(BlockSize size)129 bool IsBlockDimensionLessThan64(BlockSize size) {
130   return size <= kBlock32x32 && size != kBlock16x64;
131 }
132 
GetUseCompoundReferenceContext(const Tile::Block & block)133 int GetUseCompoundReferenceContext(const Tile::Block& block) {
134   if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
135     if (block.IsTopSingle() && block.IsLeftSingle()) {
136       return static_cast<int>(IsBackwardReference(block.TopReference(0))) ^
137              static_cast<int>(IsBackwardReference(block.LeftReference(0)));
138     }
139     if (block.IsTopSingle()) {
140       return 2 + static_cast<int>(IsBackwardReference(block.TopReference(0)) ||
141                                   block.IsTopIntra());
142     }
143     if (block.IsLeftSingle()) {
144       return 2 + static_cast<int>(IsBackwardReference(block.LeftReference(0)) ||
145                                   block.IsLeftIntra());
146     }
147     return 4;
148   }
149   if (block.top_available[kPlaneY]) {
150     return block.IsTopSingle()
151                ? static_cast<int>(IsBackwardReference(block.TopReference(0)))
152                : 3;
153   }
154   if (block.left_available[kPlaneY]) {
155     return block.IsLeftSingle()
156                ? static_cast<int>(IsBackwardReference(block.LeftReference(0)))
157                : 3;
158   }
159   return 1;
160 }
161 
162 // Calculates count0 by calling block.CountReferences() on the frame types from
163 // type0_start to type0_end, inclusive, and summing the results.
164 // Calculates count1 by calling block.CountReferences() on the frame types from
165 // type1_start to type1_end, inclusive, and summing the results.
166 // Compares count0 with count1 and returns 0, 1 or 2.
167 //
168 // See count_refs and ref_count_ctx in 8.3.2.
GetReferenceContext(const Tile::Block & block,ReferenceFrameType type0_start,ReferenceFrameType type0_end,ReferenceFrameType type1_start,ReferenceFrameType type1_end)169 int GetReferenceContext(const Tile::Block& block,
170                         ReferenceFrameType type0_start,
171                         ReferenceFrameType type0_end,
172                         ReferenceFrameType type1_start,
173                         ReferenceFrameType type1_end) {
174   int count0 = 0;
175   int count1 = 0;
176   for (int type = type0_start; type <= type0_end; ++type) {
177     count0 += block.CountReferences(static_cast<ReferenceFrameType>(type));
178   }
179   for (int type = type1_start; type <= type1_end; ++type) {
180     count1 += block.CountReferences(static_cast<ReferenceFrameType>(type));
181   }
182   return (count0 < count1) ? 0 : (count0 == count1 ? 1 : 2);
183 }
184 
185 }  // namespace
186 
ReadSegmentId(const Block & block)187 bool Tile::ReadSegmentId(const Block& block) {
188   int top_left = -1;
189   if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
190     top_left =
191         block_parameters_holder_.Find(block.row4x4 - 1, block.column4x4 - 1)
192             ->segment_id;
193   }
194   int top = -1;
195   if (block.top_available[kPlaneY]) {
196     top = block.bp_top->segment_id;
197   }
198   int left = -1;
199   if (block.left_available[kPlaneY]) {
200     left = block.bp_left->segment_id;
201   }
202   int pred;
203   if (top == -1) {
204     pred = (left == -1) ? 0 : left;
205   } else if (left == -1) {
206     pred = top;
207   } else {
208     pred = (top_left == top) ? top : left;
209   }
210   BlockParameters& bp = *block.bp;
211   if (bp.skip) {
212     bp.segment_id = pred;
213     return true;
214   }
215   int context = 0;
216   if (top_left < 0) {
217     context = 0;
218   } else if (top_left == top && top_left == left) {
219     context = 2;
220   } else if (top_left == top || top_left == left || top == left) {
221     context = 1;
222   }
223   uint16_t* const segment_id_cdf =
224       symbol_decoder_context_.segment_id_cdf[context];
225   const int encoded_segment_id =
226       reader_.ReadSymbol<kMaxSegments>(segment_id_cdf);
227   bp.segment_id =
228       DecodeSegmentId(encoded_segment_id, pred,
229                       frame_header_.segmentation.last_active_segment_id + 1);
230   // Check the bitstream conformance requirement in Section 6.10.8 of the spec.
231   if (bp.segment_id < 0 ||
232       bp.segment_id > frame_header_.segmentation.last_active_segment_id) {
233     LIBGAV1_DLOG(
234         ERROR,
235         "Corrupted segment_ids: encoded %d, last active %d, postprocessed %d",
236         encoded_segment_id, frame_header_.segmentation.last_active_segment_id,
237         bp.segment_id);
238     return false;
239   }
240   return true;
241 }
242 
ReadIntraSegmentId(const Block & block)243 bool Tile::ReadIntraSegmentId(const Block& block) {
244   BlockParameters& bp = *block.bp;
245   if (!frame_header_.segmentation.enabled) {
246     bp.segment_id = 0;
247     return true;
248   }
249   return ReadSegmentId(block);
250 }
251 
ReadSkip(const Block & block)252 void Tile::ReadSkip(const Block& block) {
253   BlockParameters& bp = *block.bp;
254   if (frame_header_.segmentation.segment_id_pre_skip &&
255       frame_header_.segmentation.FeatureActive(bp.segment_id,
256                                                kSegmentFeatureSkip)) {
257     bp.skip = true;
258     return;
259   }
260   int context = 0;
261   if (block.top_available[kPlaneY] && block.bp_top->skip) {
262     ++context;
263   }
264   if (block.left_available[kPlaneY] && block.bp_left->skip) {
265     ++context;
266   }
267   uint16_t* const skip_cdf = symbol_decoder_context_.skip_cdf[context];
268   bp.skip = reader_.ReadSymbol(skip_cdf);
269 }
270 
ReadSkipMode(const Block & block)271 void Tile::ReadSkipMode(const Block& block) {
272   BlockParameters& bp = *block.bp;
273   if (!frame_header_.skip_mode_present ||
274       frame_header_.segmentation.FeatureActive(bp.segment_id,
275                                                kSegmentFeatureSkip) ||
276       frame_header_.segmentation.FeatureActive(bp.segment_id,
277                                                kSegmentFeatureReferenceFrame) ||
278       frame_header_.segmentation.FeatureActive(bp.segment_id,
279                                                kSegmentFeatureGlobalMv) ||
280       IsBlockDimension4(block.size)) {
281     bp.skip_mode = false;
282     return;
283   }
284   const int context =
285       (block.left_available[kPlaneY]
286            ? static_cast<int>(block.bp_left->skip_mode)
287            : 0) +
288       (block.top_available[kPlaneY] ? static_cast<int>(block.bp_top->skip_mode)
289                                     : 0);
290   bp.skip_mode =
291       reader_.ReadSymbol(symbol_decoder_context_.skip_mode_cdf[context]);
292 }
293 
ReadCdef(const Block & block)294 void Tile::ReadCdef(const Block& block) {
295   BlockParameters& bp = *block.bp;
296   if (bp.skip || frame_header_.coded_lossless ||
297       !sequence_header_.enable_cdef || frame_header_.allow_intrabc) {
298     return;
299   }
300   const int cdef_size4x4 = kNum4x4BlocksWide[kBlock64x64];
301   const int cdef_mask4x4 = ~(cdef_size4x4 - 1);
302   const int row4x4 = block.row4x4 & cdef_mask4x4;
303   const int column4x4 = block.column4x4 & cdef_mask4x4;
304   const int row = DivideBy16(row4x4);
305   const int column = DivideBy16(column4x4);
306   if (cdef_index_[row][column] == -1) {
307     cdef_index_[row][column] =
308         frame_header_.cdef.bits > 0
309             ? static_cast<int16_t>(reader_.ReadLiteral(frame_header_.cdef.bits))
310             : 0;
311     for (int i = row4x4; i < row4x4 + block.height4x4; i += cdef_size4x4) {
312       for (int j = column4x4; j < column4x4 + block.width4x4;
313            j += cdef_size4x4) {
314         cdef_index_[DivideBy16(i)][DivideBy16(j)] = cdef_index_[row][column];
315       }
316     }
317   }
318 }
319 
ReadAndClipDelta(uint16_t * const cdf,int delta_small,int scale,int min_value,int max_value,int value)320 int Tile::ReadAndClipDelta(uint16_t* const cdf, int delta_small, int scale,
321                            int min_value, int max_value, int value) {
322   int abs = reader_.ReadSymbol<kDeltaSymbolCount>(cdf);
323   if (abs == delta_small) {
324     const int remaining_bit_count =
325         static_cast<int>(reader_.ReadLiteral(3)) + 1;
326     const int abs_remaining_bits =
327         static_cast<int>(reader_.ReadLiteral(remaining_bit_count));
328     abs = abs_remaining_bits + (1 << remaining_bit_count) + 1;
329   }
330   if (abs != 0) {
331     const bool sign = static_cast<bool>(reader_.ReadBit());
332     const int scaled_abs = abs << scale;
333     const int reduced_delta = sign ? -scaled_abs : scaled_abs;
334     value += reduced_delta;
335     value = Clip3(value, min_value, max_value);
336   }
337   return value;
338 }
339 
ReadQuantizerIndexDelta(const Block & block)340 void Tile::ReadQuantizerIndexDelta(const Block& block) {
341   assert(read_deltas_);
342   BlockParameters& bp = *block.bp;
343   if ((block.size == SuperBlockSize() && bp.skip)) {
344     return;
345   }
346   current_quantizer_index_ =
347       ReadAndClipDelta(symbol_decoder_context_.delta_q_cdf, kDeltaQSmall,
348                        frame_header_.delta_q.scale, kMinLossyQuantizer,
349                        kMaxQuantizer, current_quantizer_index_);
350 }
351 
ReadLoopFilterDelta(const Block & block)352 void Tile::ReadLoopFilterDelta(const Block& block) {
353   assert(read_deltas_);
354   BlockParameters& bp = *block.bp;
355   if (!frame_header_.delta_lf.present ||
356       (block.size == SuperBlockSize() && bp.skip)) {
357     return;
358   }
359   int frame_lf_count = 1;
360   if (frame_header_.delta_lf.multi) {
361     frame_lf_count = kFrameLfCount - (PlaneCount() > 1 ? 0 : 2);
362   }
363   bool recompute_deblock_filter_levels = false;
364   for (int i = 0; i < frame_lf_count; ++i) {
365     uint16_t* const delta_lf_abs_cdf =
366         frame_header_.delta_lf.multi
367             ? symbol_decoder_context_.delta_lf_multi_cdf[i]
368             : symbol_decoder_context_.delta_lf_cdf;
369     const int8_t old_delta_lf = delta_lf_[i];
370     delta_lf_[i] = ReadAndClipDelta(
371         delta_lf_abs_cdf, kDeltaLfSmall, frame_header_.delta_lf.scale,
372         -kMaxLoopFilterValue, kMaxLoopFilterValue, delta_lf_[i]);
373     recompute_deblock_filter_levels =
374         recompute_deblock_filter_levels || (old_delta_lf != delta_lf_[i]);
375   }
376   delta_lf_all_zero_ =
377       (delta_lf_[0] | delta_lf_[1] | delta_lf_[2] | delta_lf_[3]) == 0;
378   if (!delta_lf_all_zero_ && recompute_deblock_filter_levels) {
379     post_filter_.ComputeDeblockFilterLevels(delta_lf_, deblock_filter_levels_);
380   }
381 }
382 
ReadPredictionModeY(const Block & block,bool intra_y_mode)383 void Tile::ReadPredictionModeY(const Block& block, bool intra_y_mode) {
384   uint16_t* cdf;
385   if (intra_y_mode) {
386     const PredictionMode top_mode =
387         block.top_available[kPlaneY] ? block.bp_top->y_mode : kPredictionModeDc;
388     const PredictionMode left_mode = block.left_available[kPlaneY]
389                                          ? block.bp_left->y_mode
390                                          : kPredictionModeDc;
391     const int top_context = kIntraYModeContext[top_mode];
392     const int left_context = kIntraYModeContext[left_mode];
393     cdf = symbol_decoder_context_
394               .intra_frame_y_mode_cdf[top_context][left_context];
395   } else {
396     cdf = symbol_decoder_context_.y_mode_cdf[kSizeGroup[block.size]];
397   }
398   block.bp->y_mode = static_cast<PredictionMode>(
399       reader_.ReadSymbol<kIntraPredictionModesY>(cdf));
400 }
401 
ReadIntraAngleInfo(const Block & block,PlaneType plane_type)402 void Tile::ReadIntraAngleInfo(const Block& block, PlaneType plane_type) {
403   BlockParameters& bp = *block.bp;
404   PredictionParameters& prediction_parameters =
405       *block.bp->prediction_parameters;
406   prediction_parameters.angle_delta[plane_type] = 0;
407   const PredictionMode mode =
408       (plane_type == kPlaneTypeY) ? bp.y_mode : bp.uv_mode;
409   if (IsBlockSmallerThan8x8(block.size) || !IsDirectionalMode(mode)) return;
410   uint16_t* const cdf =
411       symbol_decoder_context_.angle_delta_cdf[mode - kPredictionModeVertical];
412   prediction_parameters.angle_delta[plane_type] =
413       reader_.ReadSymbol<kAngleDeltaSymbolCount>(cdf);
414   prediction_parameters.angle_delta[plane_type] -= kMaxAngleDelta;
415 }
416 
ReadCflAlpha(const Block & block)417 void Tile::ReadCflAlpha(const Block& block) {
418   const int signs = reader_.ReadSymbol<kCflAlphaSignsSymbolCount>(
419       symbol_decoder_context_.cfl_alpha_signs_cdf);
420   const int8_t* const cfl_lookup = kCflAlphaLookup[signs];
421   const auto sign_u = static_cast<CflSign>(cfl_lookup[0]);
422   const auto sign_v = static_cast<CflSign>(cfl_lookup[1]);
423   PredictionParameters& prediction_parameters =
424       *block.bp->prediction_parameters;
425   prediction_parameters.cfl_alpha_u = 0;
426   if (sign_u != kCflSignZero) {
427     assert(cfl_lookup[2] >= 0);
428     prediction_parameters.cfl_alpha_u =
429         reader_.ReadSymbol<kCflAlphaSymbolCount>(
430             symbol_decoder_context_.cfl_alpha_cdf[cfl_lookup[2]]) +
431         1;
432     if (sign_u == kCflSignNegative) prediction_parameters.cfl_alpha_u *= -1;
433   }
434   prediction_parameters.cfl_alpha_v = 0;
435   if (sign_v != kCflSignZero) {
436     assert(cfl_lookup[3] >= 0);
437     prediction_parameters.cfl_alpha_v =
438         reader_.ReadSymbol<kCflAlphaSymbolCount>(
439             symbol_decoder_context_.cfl_alpha_cdf[cfl_lookup[3]]) +
440         1;
441     if (sign_v == kCflSignNegative) prediction_parameters.cfl_alpha_v *= -1;
442   }
443 }
444 
ReadPredictionModeUV(const Block & block)445 void Tile::ReadPredictionModeUV(const Block& block) {
446   BlockParameters& bp = *block.bp;
447   bool chroma_from_luma_allowed;
448   if (frame_header_.segmentation.lossless[bp.segment_id]) {
449     chroma_from_luma_allowed = block.residual_size[kPlaneU] == kBlock4x4;
450   } else {
451     chroma_from_luma_allowed = IsBlockDimensionLessThan64(block.size);
452   }
453   uint16_t* const cdf =
454       symbol_decoder_context_
455           .uv_mode_cdf[static_cast<int>(chroma_from_luma_allowed)][bp.y_mode];
456   if (chroma_from_luma_allowed) {
457     bp.uv_mode = static_cast<PredictionMode>(
458         reader_.ReadSymbol<kIntraPredictionModesUV>(cdf));
459   } else {
460     bp.uv_mode = static_cast<PredictionMode>(
461         reader_.ReadSymbol<kIntraPredictionModesUV - 1>(cdf));
462   }
463 }
464 
ReadMotionVectorComponent(const Block & block,const int component)465 int Tile::ReadMotionVectorComponent(const Block& block, const int component) {
466   const int context =
467       static_cast<int>(block.bp->prediction_parameters->use_intra_block_copy);
468   const bool sign = reader_.ReadSymbol(
469       symbol_decoder_context_.mv_sign_cdf[component][context]);
470   const int mv_class = reader_.ReadSymbol<kMvClassSymbolCount>(
471       symbol_decoder_context_.mv_class_cdf[component][context]);
472   int magnitude = 1;
473   int value;
474   uint16_t* fraction_cdf;
475   uint16_t* precision_cdf;
476   if (mv_class == 0) {
477     value = static_cast<int>(reader_.ReadSymbol(
478         symbol_decoder_context_.mv_class0_bit_cdf[component][context]));
479     fraction_cdf = symbol_decoder_context_
480                        .mv_class0_fraction_cdf[component][context][value];
481     precision_cdf = symbol_decoder_context_
482                         .mv_class0_high_precision_cdf[component][context];
483   } else {
484     assert(mv_class <= kMvBitSymbolCount);
485     value = 0;
486     for (int i = 0; i < mv_class; ++i) {
487       const int bit = static_cast<int>(reader_.ReadSymbol(
488           symbol_decoder_context_.mv_bit_cdf[component][context][i]));
489       value |= bit << i;
490     }
491     magnitude += 2 << (mv_class + 2);
492     fraction_cdf = symbol_decoder_context_.mv_fraction_cdf[component][context];
493     precision_cdf =
494         symbol_decoder_context_.mv_high_precision_cdf[component][context];
495   }
496   const int fraction =
497       (frame_header_.force_integer_mv == 0)
498           ? reader_.ReadSymbol<kMvFractionSymbolCount>(fraction_cdf)
499           : 3;
500   const int precision =
501       frame_header_.allow_high_precision_mv
502           ? static_cast<int>(reader_.ReadSymbol(precision_cdf))
503           : 1;
504   magnitude += (value << 3) | (fraction << 1) | precision;
505   return sign ? -magnitude : magnitude;
506 }
507 
ReadMotionVector(const Block & block,int index)508 void Tile::ReadMotionVector(const Block& block, int index) {
509   BlockParameters& bp = *block.bp;
510   const int context =
511       static_cast<int>(block.bp->prediction_parameters->use_intra_block_copy);
512   const auto mv_joint =
513       static_cast<MvJointType>(reader_.ReadSymbol<kNumMvJointTypes>(
514           symbol_decoder_context_.mv_joint_cdf[context]));
515   if (mv_joint == kMvJointTypeHorizontalZeroVerticalNonZero ||
516       mv_joint == kMvJointTypeNonZero) {
517     bp.mv.mv[index].mv[0] = ReadMotionVectorComponent(block, 0);
518   }
519   if (mv_joint == kMvJointTypeHorizontalNonZeroVerticalZero ||
520       mv_joint == kMvJointTypeNonZero) {
521     bp.mv.mv[index].mv[1] = ReadMotionVectorComponent(block, 1);
522   }
523 }
524 
ReadFilterIntraModeInfo(const Block & block)525 void Tile::ReadFilterIntraModeInfo(const Block& block) {
526   BlockParameters& bp = *block.bp;
527   PredictionParameters& prediction_parameters =
528       *block.bp->prediction_parameters;
529   prediction_parameters.use_filter_intra = false;
530   if (!sequence_header_.enable_filter_intra || bp.y_mode != kPredictionModeDc ||
531       bp.palette_mode_info.size[kPlaneTypeY] != 0 ||
532       !IsBlockDimensionLessThan64(block.size)) {
533     return;
534   }
535   prediction_parameters.use_filter_intra = reader_.ReadSymbol(
536       symbol_decoder_context_.use_filter_intra_cdf[block.size]);
537   if (prediction_parameters.use_filter_intra) {
538     prediction_parameters.filter_intra_mode = static_cast<FilterIntraPredictor>(
539         reader_.ReadSymbol<kNumFilterIntraPredictors>(
540             symbol_decoder_context_.filter_intra_mode_cdf));
541   }
542 }
543 
DecodeIntraModeInfo(const Block & block)544 bool Tile::DecodeIntraModeInfo(const Block& block) {
545   BlockParameters& bp = *block.bp;
546   bp.skip = false;
547   if (frame_header_.segmentation.segment_id_pre_skip &&
548       !ReadIntraSegmentId(block)) {
549     return false;
550   }
551   bp.skip_mode = false;
552   ReadSkip(block);
553   if (!frame_header_.segmentation.segment_id_pre_skip &&
554       !ReadIntraSegmentId(block)) {
555     return false;
556   }
557   ReadCdef(block);
558   if (read_deltas_) {
559     ReadQuantizerIndexDelta(block);
560     ReadLoopFilterDelta(block);
561     read_deltas_ = false;
562   }
563   PredictionParameters& prediction_parameters =
564       *block.bp->prediction_parameters;
565   prediction_parameters.use_intra_block_copy = false;
566   if (frame_header_.allow_intrabc) {
567     prediction_parameters.use_intra_block_copy =
568         reader_.ReadSymbol(symbol_decoder_context_.intra_block_copy_cdf);
569   }
570   if (prediction_parameters.use_intra_block_copy) {
571     bp.is_inter = true;
572     bp.reference_frame[0] = kReferenceFrameIntra;
573     bp.reference_frame[1] = kReferenceFrameNone;
574     bp.y_mode = kPredictionModeDc;
575     bp.uv_mode = kPredictionModeDc;
576     prediction_parameters.motion_mode = kMotionModeSimple;
577     prediction_parameters.compound_prediction_type =
578         kCompoundPredictionTypeAverage;
579     bp.palette_mode_info.size[kPlaneTypeY] = 0;
580     bp.palette_mode_info.size[kPlaneTypeUV] = 0;
581     bp.interpolation_filter[0] = kInterpolationFilterBilinear;
582     bp.interpolation_filter[1] = kInterpolationFilterBilinear;
583     MvContexts dummy_mode_contexts;
584     FindMvStack(block, /*is_compound=*/false, &dummy_mode_contexts);
585     return AssignIntraMv(block);
586   }
587   bp.is_inter = false;
588   return ReadIntraBlockModeInfo(block, /*intra_y_mode=*/true);
589 }
590 
ComputePredictedSegmentId(const Block & block) const591 int8_t Tile::ComputePredictedSegmentId(const Block& block) const {
592   // If prev_segment_ids_ is null, treat it as if it pointed to a segmentation
593   // map containing all 0s.
594   if (prev_segment_ids_ == nullptr) return 0;
595 
596   const int x_limit = std::min(frame_header_.columns4x4 - block.column4x4,
597                                static_cast<int>(block.width4x4));
598   const int y_limit = std::min(frame_header_.rows4x4 - block.row4x4,
599                                static_cast<int>(block.height4x4));
600   int8_t id = 7;
601   for (int y = 0; y < y_limit; ++y) {
602     for (int x = 0; x < x_limit; ++x) {
603       const int8_t prev_segment_id =
604           prev_segment_ids_->segment_id(block.row4x4 + y, block.column4x4 + x);
605       id = std::min(id, prev_segment_id);
606     }
607   }
608   return id;
609 }
610 
ReadInterSegmentId(const Block & block,bool pre_skip)611 bool Tile::ReadInterSegmentId(const Block& block, bool pre_skip) {
612   BlockParameters& bp = *block.bp;
613   if (!frame_header_.segmentation.enabled) {
614     bp.segment_id = 0;
615     return true;
616   }
617   if (!frame_header_.segmentation.update_map) {
618     bp.segment_id = ComputePredictedSegmentId(block);
619     return true;
620   }
621   if (pre_skip) {
622     if (!frame_header_.segmentation.segment_id_pre_skip) {
623       bp.segment_id = 0;
624       return true;
625     }
626   } else if (bp.skip) {
627     bp.use_predicted_segment_id = false;
628     return ReadSegmentId(block);
629   }
630   if (frame_header_.segmentation.temporal_update) {
631     const int context =
632         (block.left_available[kPlaneY]
633              ? static_cast<int>(block.bp_left->use_predicted_segment_id)
634              : 0) +
635         (block.top_available[kPlaneY]
636              ? static_cast<int>(block.bp_top->use_predicted_segment_id)
637              : 0);
638     bp.use_predicted_segment_id = reader_.ReadSymbol(
639         symbol_decoder_context_.use_predicted_segment_id_cdf[context]);
640     if (bp.use_predicted_segment_id) {
641       bp.segment_id = ComputePredictedSegmentId(block);
642       return true;
643     }
644   }
645   return ReadSegmentId(block);
646 }
647 
ReadIsInter(const Block & block)648 void Tile::ReadIsInter(const Block& block) {
649   BlockParameters& bp = *block.bp;
650   if (bp.skip_mode) {
651     bp.is_inter = true;
652     return;
653   }
654   if (frame_header_.segmentation.FeatureActive(bp.segment_id,
655                                                kSegmentFeatureReferenceFrame)) {
656     bp.is_inter =
657         frame_header_.segmentation
658             .feature_data[bp.segment_id][kSegmentFeatureReferenceFrame] !=
659         kReferenceFrameIntra;
660     return;
661   }
662   if (frame_header_.segmentation.FeatureActive(bp.segment_id,
663                                                kSegmentFeatureGlobalMv)) {
664     bp.is_inter = true;
665     return;
666   }
667   int context = 0;
668   if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
669     context = (block.IsTopIntra() && block.IsLeftIntra())
670                   ? 3
671                   : static_cast<int>(block.IsTopIntra() || block.IsLeftIntra());
672   } else if (block.top_available[kPlaneY] || block.left_available[kPlaneY]) {
673     context = 2 * static_cast<int>(block.top_available[kPlaneY]
674                                        ? block.IsTopIntra()
675                                        : block.IsLeftIntra());
676   }
677   bp.is_inter =
678       reader_.ReadSymbol(symbol_decoder_context_.is_inter_cdf[context]);
679 }
680 
ReadIntraBlockModeInfo(const Block & block,bool intra_y_mode)681 bool Tile::ReadIntraBlockModeInfo(const Block& block, bool intra_y_mode) {
682   BlockParameters& bp = *block.bp;
683   bp.reference_frame[0] = kReferenceFrameIntra;
684   bp.reference_frame[1] = kReferenceFrameNone;
685   ReadPredictionModeY(block, intra_y_mode);
686   ReadIntraAngleInfo(block, kPlaneTypeY);
687   if (block.HasChroma()) {
688     ReadPredictionModeUV(block);
689     if (bp.uv_mode == kPredictionModeChromaFromLuma) {
690       ReadCflAlpha(block);
691     }
692     ReadIntraAngleInfo(block, kPlaneTypeUV);
693   }
694   ReadPaletteModeInfo(block);
695   ReadFilterIntraModeInfo(block);
696   return true;
697 }
698 
ReadCompoundReferenceType(const Block & block)699 CompoundReferenceType Tile::ReadCompoundReferenceType(const Block& block) {
700   // compound and inter.
701   const bool top_comp_inter = block.top_available[kPlaneY] &&
702                               !block.IsTopIntra() && !block.IsTopSingle();
703   const bool left_comp_inter = block.left_available[kPlaneY] &&
704                                !block.IsLeftIntra() && !block.IsLeftSingle();
705   // unidirectional compound.
706   const bool top_uni_comp =
707       top_comp_inter && IsSameDirectionReferencePair(block.TopReference(0),
708                                                      block.TopReference(1));
709   const bool left_uni_comp =
710       left_comp_inter && IsSameDirectionReferencePair(block.LeftReference(0),
711                                                       block.LeftReference(1));
712   int context;
713   if (block.top_available[kPlaneY] && !block.IsTopIntra() &&
714       block.left_available[kPlaneY] && !block.IsLeftIntra()) {
715     const int same_direction = static_cast<int>(IsSameDirectionReferencePair(
716         block.TopReference(0), block.LeftReference(0)));
717     if (!top_comp_inter && !left_comp_inter) {
718       context = 1 + MultiplyBy2(same_direction);
719     } else if (!top_comp_inter) {
720       context = left_uni_comp ? 3 + same_direction : 1;
721     } else if (!left_comp_inter) {
722       context = top_uni_comp ? 3 + same_direction : 1;
723     } else {
724       if (!top_uni_comp && !left_uni_comp) {
725         context = 0;
726       } else if (!top_uni_comp || !left_uni_comp) {
727         context = 2;
728       } else {
729         context = 3 + static_cast<int>(
730                           (block.TopReference(0) == kReferenceFrameBackward) ==
731                           (block.LeftReference(0) == kReferenceFrameBackward));
732       }
733     }
734   } else if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
735     if (top_comp_inter) {
736       context = 1 + MultiplyBy2(static_cast<int>(top_uni_comp));
737     } else if (left_comp_inter) {
738       context = 1 + MultiplyBy2(static_cast<int>(left_uni_comp));
739     } else {
740       context = 2;
741     }
742   } else if (top_comp_inter) {
743     context = MultiplyBy4(static_cast<int>(top_uni_comp));
744   } else if (left_comp_inter) {
745     context = MultiplyBy4(static_cast<int>(left_uni_comp));
746   } else {
747     context = 2;
748   }
749   return static_cast<CompoundReferenceType>(reader_.ReadSymbol(
750       symbol_decoder_context_.compound_reference_type_cdf[context]));
751 }
752 
753 template <bool is_single, bool is_backward, int index>
GetReferenceCdf(const Block & block,CompoundReferenceType type)754 uint16_t* Tile::GetReferenceCdf(
755     const Block& block,
756     CompoundReferenceType type /*= kNumCompoundReferenceTypes*/) {
757   int context = 0;
758   if ((type == kCompoundReferenceUnidirectional && index == 0) ||
759       (is_single && index == 1)) {
760     // uni_comp_ref and single_ref_p1.
761     context =
762         GetReferenceContext(block, kReferenceFrameLast, kReferenceFrameGolden,
763                             kReferenceFrameBackward, kReferenceFrameAlternate);
764   } else if (type == kCompoundReferenceUnidirectional && index == 1) {
765     // uni_comp_ref_p1.
766     context =
767         GetReferenceContext(block, kReferenceFrameLast2, kReferenceFrameLast2,
768                             kReferenceFrameLast3, kReferenceFrameGolden);
769   } else if ((type == kCompoundReferenceUnidirectional && index == 2) ||
770              (type == kCompoundReferenceBidirectional && index == 2) ||
771              (is_single && index == 5)) {
772     // uni_comp_ref_p2, comp_ref_p2 and single_ref_p5.
773     context =
774         GetReferenceContext(block, kReferenceFrameLast3, kReferenceFrameLast3,
775                             kReferenceFrameGolden, kReferenceFrameGolden);
776   } else if ((type == kCompoundReferenceBidirectional && index == 0) ||
777              (is_single && index == 3)) {
778     // comp_ref and single_ref_p3.
779     context =
780         GetReferenceContext(block, kReferenceFrameLast, kReferenceFrameLast2,
781                             kReferenceFrameLast3, kReferenceFrameGolden);
782   } else if ((type == kCompoundReferenceBidirectional && index == 1) ||
783              (is_single && index == 4)) {
784     // comp_ref_p1 and single_ref_p4.
785     context =
786         GetReferenceContext(block, kReferenceFrameLast, kReferenceFrameLast,
787                             kReferenceFrameLast2, kReferenceFrameLast2);
788   } else if ((is_single && index == 2) || (is_backward && index == 0)) {
789     // single_ref_p2 and comp_bwdref.
790     context = GetReferenceContext(
791         block, kReferenceFrameBackward, kReferenceFrameAlternate2,
792         kReferenceFrameAlternate, kReferenceFrameAlternate);
793   } else if ((is_single && index == 6) || (is_backward && index == 1)) {
794     // single_ref_p6 and comp_bwdref_p1.
795     context = GetReferenceContext(
796         block, kReferenceFrameBackward, kReferenceFrameBackward,
797         kReferenceFrameAlternate2, kReferenceFrameAlternate2);
798   }
799   if (is_single) {
800     // The index parameter for single references is offset by one since the spec
801     // uses 1-based index for these elements.
802     return symbol_decoder_context_.single_reference_cdf[context][index - 1];
803   }
804   if (is_backward) {
805     return symbol_decoder_context_
806         .compound_backward_reference_cdf[context][index];
807   }
808   return symbol_decoder_context_.compound_reference_cdf[type][context][index];
809 }
810 
ReadReferenceFrames(const Block & block)811 void Tile::ReadReferenceFrames(const Block& block) {
812   BlockParameters& bp = *block.bp;
813   if (bp.skip_mode) {
814     bp.reference_frame[0] = frame_header_.skip_mode_frame[0];
815     bp.reference_frame[1] = frame_header_.skip_mode_frame[1];
816     return;
817   }
818   if (frame_header_.segmentation.FeatureActive(bp.segment_id,
819                                                kSegmentFeatureReferenceFrame)) {
820     bp.reference_frame[0] = static_cast<ReferenceFrameType>(
821         frame_header_.segmentation
822             .feature_data[bp.segment_id][kSegmentFeatureReferenceFrame]);
823     bp.reference_frame[1] = kReferenceFrameNone;
824     return;
825   }
826   if (frame_header_.segmentation.FeatureActive(bp.segment_id,
827                                                kSegmentFeatureSkip) ||
828       frame_header_.segmentation.FeatureActive(bp.segment_id,
829                                                kSegmentFeatureGlobalMv)) {
830     bp.reference_frame[0] = kReferenceFrameLast;
831     bp.reference_frame[1] = kReferenceFrameNone;
832     return;
833   }
834   const bool use_compound_reference =
835       frame_header_.reference_mode_select &&
836       std::min(block.width4x4, block.height4x4) >= 2 &&
837       reader_.ReadSymbol(symbol_decoder_context_.use_compound_reference_cdf
838                              [GetUseCompoundReferenceContext(block)]);
839   if (use_compound_reference) {
840     CompoundReferenceType reference_type = ReadCompoundReferenceType(block);
841     if (reference_type == kCompoundReferenceUnidirectional) {
842       // uni_comp_ref.
843       if (reader_.ReadSymbol(
844               GetReferenceCdf<false, false, 0>(block, reference_type))) {
845         bp.reference_frame[0] = kReferenceFrameBackward;
846         bp.reference_frame[1] = kReferenceFrameAlternate;
847         return;
848       }
849       // uni_comp_ref_p1.
850       if (!reader_.ReadSymbol(
851               GetReferenceCdf<false, false, 1>(block, reference_type))) {
852         bp.reference_frame[0] = kReferenceFrameLast;
853         bp.reference_frame[1] = kReferenceFrameLast2;
854         return;
855       }
856       // uni_comp_ref_p2.
857       if (reader_.ReadSymbol(
858               GetReferenceCdf<false, false, 2>(block, reference_type))) {
859         bp.reference_frame[0] = kReferenceFrameLast;
860         bp.reference_frame[1] = kReferenceFrameGolden;
861         return;
862       }
863       bp.reference_frame[0] = kReferenceFrameLast;
864       bp.reference_frame[1] = kReferenceFrameLast3;
865       return;
866     }
867     assert(reference_type == kCompoundReferenceBidirectional);
868     // comp_ref.
869     if (reader_.ReadSymbol(
870             GetReferenceCdf<false, false, 0>(block, reference_type))) {
871       // comp_ref_p2.
872       bp.reference_frame[0] =
873           reader_.ReadSymbol(
874               GetReferenceCdf<false, false, 2>(block, reference_type))
875               ? kReferenceFrameGolden
876               : kReferenceFrameLast3;
877     } else {
878       // comp_ref_p1.
879       bp.reference_frame[0] =
880           reader_.ReadSymbol(
881               GetReferenceCdf<false, false, 1>(block, reference_type))
882               ? kReferenceFrameLast2
883               : kReferenceFrameLast;
884     }
885     // comp_bwdref.
886     if (reader_.ReadSymbol(GetReferenceCdf<false, true, 0>(block))) {
887       bp.reference_frame[1] = kReferenceFrameAlternate;
888     } else {
889       // comp_bwdref_p1.
890       bp.reference_frame[1] =
891           reader_.ReadSymbol(GetReferenceCdf<false, true, 1>(block))
892               ? kReferenceFrameAlternate2
893               : kReferenceFrameBackward;
894     }
895     return;
896   }
897   assert(!use_compound_reference);
898   bp.reference_frame[1] = kReferenceFrameNone;
899   // single_ref_p1.
900   if (reader_.ReadSymbol(GetReferenceCdf<true, false, 1>(block))) {
901     // single_ref_p2.
902     if (reader_.ReadSymbol(GetReferenceCdf<true, false, 2>(block))) {
903       bp.reference_frame[0] = kReferenceFrameAlternate;
904       return;
905     }
906     // single_ref_p6.
907     bp.reference_frame[0] =
908         reader_.ReadSymbol(GetReferenceCdf<true, false, 6>(block))
909             ? kReferenceFrameAlternate2
910             : kReferenceFrameBackward;
911     return;
912   }
913   // single_ref_p3.
914   if (reader_.ReadSymbol(GetReferenceCdf<true, false, 3>(block))) {
915     // single_ref_p5.
916     bp.reference_frame[0] =
917         reader_.ReadSymbol(GetReferenceCdf<true, false, 5>(block))
918             ? kReferenceFrameGolden
919             : kReferenceFrameLast3;
920     return;
921   }
922   // single_ref_p4.
923   bp.reference_frame[0] =
924       reader_.ReadSymbol(GetReferenceCdf<true, false, 4>(block))
925           ? kReferenceFrameLast2
926           : kReferenceFrameLast;
927 }
928 
ReadInterPredictionModeY(const Block & block,const MvContexts & mode_contexts)929 void Tile::ReadInterPredictionModeY(const Block& block,
930                                     const MvContexts& mode_contexts) {
931   BlockParameters& bp = *block.bp;
932   if (bp.skip_mode) {
933     bp.y_mode = kPredictionModeNearestNearestMv;
934     return;
935   }
936   if (frame_header_.segmentation.FeatureActive(bp.segment_id,
937                                                kSegmentFeatureSkip) ||
938       frame_header_.segmentation.FeatureActive(bp.segment_id,
939                                                kSegmentFeatureGlobalMv)) {
940     bp.y_mode = kPredictionModeGlobalMv;
941     return;
942   }
943   if (bp.reference_frame[1] > kReferenceFrameIntra) {
944     const int idx0 = mode_contexts.reference_mv >> 1;
945     const int idx1 =
946         std::min(mode_contexts.new_mv, kCompoundModeNewMvContexts - 1);
947     const int context = kCompoundModeContextMap[idx0][idx1];
948     const int offset = reader_.ReadSymbol<kNumCompoundInterPredictionModes>(
949         symbol_decoder_context_.compound_prediction_mode_cdf[context]);
950     bp.y_mode =
951         static_cast<PredictionMode>(kPredictionModeNearestNearestMv + offset);
952     return;
953   }
954   // new_mv.
955   if (!reader_.ReadSymbol(
956           symbol_decoder_context_.new_mv_cdf[mode_contexts.new_mv])) {
957     bp.y_mode = kPredictionModeNewMv;
958     return;
959   }
960   // zero_mv.
961   if (!reader_.ReadSymbol(
962           symbol_decoder_context_.zero_mv_cdf[mode_contexts.zero_mv])) {
963     bp.y_mode = kPredictionModeGlobalMv;
964     return;
965   }
966   // ref_mv.
967   bp.y_mode =
968       reader_.ReadSymbol(
969           symbol_decoder_context_.reference_mv_cdf[mode_contexts.reference_mv])
970           ? kPredictionModeNearMv
971           : kPredictionModeNearestMv;
972 }
973 
ReadRefMvIndex(const Block & block)974 void Tile::ReadRefMvIndex(const Block& block) {
975   BlockParameters& bp = *block.bp;
976   PredictionParameters& prediction_parameters =
977       *block.bp->prediction_parameters;
978   prediction_parameters.ref_mv_index = 0;
979   if (bp.y_mode != kPredictionModeNewMv &&
980       bp.y_mode != kPredictionModeNewNewMv &&
981       !kPredictionModeHasNearMvMask.Contains(bp.y_mode)) {
982     return;
983   }
984   const int start =
985       static_cast<int>(kPredictionModeHasNearMvMask.Contains(bp.y_mode));
986   prediction_parameters.ref_mv_index = start;
987   for (int i = start; i < start + 2; ++i) {
988     if (prediction_parameters.ref_mv_count <= i + 1) break;
989     // drl_mode in the spec.
990     const bool ref_mv_index_bit = reader_.ReadSymbol(
991         symbol_decoder_context_.ref_mv_index_cdf[GetRefMvIndexContext(
992             prediction_parameters.nearest_mv_count, i)]);
993     prediction_parameters.ref_mv_index = i + static_cast<int>(ref_mv_index_bit);
994     if (!ref_mv_index_bit) return;
995   }
996 }
997 
ReadInterIntraMode(const Block & block,bool is_compound)998 void Tile::ReadInterIntraMode(const Block& block, bool is_compound) {
999   BlockParameters& bp = *block.bp;
1000   PredictionParameters& prediction_parameters =
1001       *block.bp->prediction_parameters;
1002   prediction_parameters.inter_intra_mode = kNumInterIntraModes;
1003   prediction_parameters.is_wedge_inter_intra = false;
1004   if (bp.skip_mode || !sequence_header_.enable_interintra_compound ||
1005       is_compound || !kIsInterIntraModeAllowedMask.Contains(block.size)) {
1006     return;
1007   }
1008   // kSizeGroup[block.size] is guaranteed to be non-zero because of the block
1009   // size constraint enforced in the above condition.
1010   assert(kSizeGroup[block.size] - 1 >= 0);
1011   if (!reader_.ReadSymbol(
1012           symbol_decoder_context_
1013               .is_inter_intra_cdf[kSizeGroup[block.size] - 1])) {
1014     prediction_parameters.inter_intra_mode = kNumInterIntraModes;
1015     return;
1016   }
1017   prediction_parameters.inter_intra_mode =
1018       static_cast<InterIntraMode>(reader_.ReadSymbol<kNumInterIntraModes>(
1019           symbol_decoder_context_
1020               .inter_intra_mode_cdf[kSizeGroup[block.size] - 1]));
1021   bp.reference_frame[1] = kReferenceFrameIntra;
1022   prediction_parameters.angle_delta[kPlaneTypeY] = 0;
1023   prediction_parameters.angle_delta[kPlaneTypeUV] = 0;
1024   prediction_parameters.use_filter_intra = false;
1025   prediction_parameters.is_wedge_inter_intra = reader_.ReadSymbol(
1026       symbol_decoder_context_.is_wedge_inter_intra_cdf[block.size]);
1027   if (!prediction_parameters.is_wedge_inter_intra) return;
1028   prediction_parameters.wedge_index =
1029       reader_.ReadSymbol<kWedgeIndexSymbolCount>(
1030           symbol_decoder_context_.wedge_index_cdf[block.size]);
1031   prediction_parameters.wedge_sign = 0;
1032 }
1033 
ReadMotionMode(const Block & block,bool is_compound)1034 void Tile::ReadMotionMode(const Block& block, bool is_compound) {
1035   BlockParameters& bp = *block.bp;
1036   PredictionParameters& prediction_parameters =
1037       *block.bp->prediction_parameters;
1038   const auto global_motion_type =
1039       frame_header_.global_motion[bp.reference_frame[0]].type;
1040   if (bp.skip_mode || !frame_header_.is_motion_mode_switchable ||
1041       IsBlockDimension4(block.size) ||
1042       (frame_header_.force_integer_mv == 0 &&
1043        (bp.y_mode == kPredictionModeGlobalMv ||
1044         bp.y_mode == kPredictionModeGlobalGlobalMv) &&
1045        global_motion_type > kGlobalMotionTransformationTypeTranslation) ||
1046       is_compound || bp.reference_frame[1] == kReferenceFrameIntra ||
1047       !block.HasOverlappableCandidates()) {
1048     prediction_parameters.motion_mode = kMotionModeSimple;
1049     return;
1050   }
1051   prediction_parameters.num_warp_samples = 0;
1052   int num_samples_scanned = 0;
1053   memset(prediction_parameters.warp_estimate_candidates, 0,
1054          sizeof(prediction_parameters.warp_estimate_candidates));
1055   FindWarpSamples(block, &prediction_parameters.num_warp_samples,
1056                   &num_samples_scanned,
1057                   prediction_parameters.warp_estimate_candidates);
1058   if (frame_header_.force_integer_mv != 0 ||
1059       prediction_parameters.num_warp_samples == 0 ||
1060       !frame_header_.allow_warped_motion || IsScaled(bp.reference_frame[0])) {
1061     prediction_parameters.motion_mode =
1062         reader_.ReadSymbol(symbol_decoder_context_.use_obmc_cdf[block.size])
1063             ? kMotionModeObmc
1064             : kMotionModeSimple;
1065     return;
1066   }
1067   prediction_parameters.motion_mode =
1068       static_cast<MotionMode>(reader_.ReadSymbol<kNumMotionModes>(
1069           symbol_decoder_context_.motion_mode_cdf[block.size]));
1070 }
1071 
GetIsExplicitCompoundTypeCdf(const Block & block)1072 uint16_t* Tile::GetIsExplicitCompoundTypeCdf(const Block& block) {
1073   int context = 0;
1074   if (block.top_available[kPlaneY]) {
1075     if (!block.IsTopSingle()) {
1076       context += static_cast<int>(block.bp_top->is_explicit_compound_type);
1077     } else if (block.TopReference(0) == kReferenceFrameAlternate) {
1078       context += 3;
1079     }
1080   }
1081   if (block.left_available[kPlaneY]) {
1082     if (!block.IsLeftSingle()) {
1083       context += static_cast<int>(block.bp_left->is_explicit_compound_type);
1084     } else if (block.LeftReference(0) == kReferenceFrameAlternate) {
1085       context += 3;
1086     }
1087   }
1088   return symbol_decoder_context_.is_explicit_compound_type_cdf[std::min(
1089       context, kIsExplicitCompoundTypeContexts - 1)];
1090 }
1091 
GetIsCompoundTypeAverageCdf(const Block & block)1092 uint16_t* Tile::GetIsCompoundTypeAverageCdf(const Block& block) {
1093   const BlockParameters& bp = *block.bp;
1094   const ReferenceInfo& reference_info = *current_frame_.reference_info();
1095   const int forward =
1096       std::abs(reference_info.relative_distance_from[bp.reference_frame[0]]);
1097   const int backward =
1098       std::abs(reference_info.relative_distance_from[bp.reference_frame[1]]);
1099   int context = (forward == backward) ? 3 : 0;
1100   if (block.top_available[kPlaneY]) {
1101     if (!block.IsTopSingle()) {
1102       context += static_cast<int>(block.bp_top->is_compound_type_average);
1103     } else if (block.TopReference(0) == kReferenceFrameAlternate) {
1104       ++context;
1105     }
1106   }
1107   if (block.left_available[kPlaneY]) {
1108     if (!block.IsLeftSingle()) {
1109       context += static_cast<int>(block.bp_left->is_compound_type_average);
1110     } else if (block.LeftReference(0) == kReferenceFrameAlternate) {
1111       ++context;
1112     }
1113   }
1114   return symbol_decoder_context_.is_compound_type_average_cdf[context];
1115 }
1116 
ReadCompoundType(const Block & block,bool is_compound)1117 void Tile::ReadCompoundType(const Block& block, bool is_compound) {
1118   BlockParameters& bp = *block.bp;
1119   bp.is_explicit_compound_type = false;
1120   bp.is_compound_type_average = true;
1121   PredictionParameters& prediction_parameters =
1122       *block.bp->prediction_parameters;
1123   if (bp.skip_mode) {
1124     prediction_parameters.compound_prediction_type =
1125         kCompoundPredictionTypeAverage;
1126     return;
1127   }
1128   if (is_compound) {
1129     if (sequence_header_.enable_masked_compound) {
1130       bp.is_explicit_compound_type =
1131           reader_.ReadSymbol(GetIsExplicitCompoundTypeCdf(block));
1132     }
1133     if (bp.is_explicit_compound_type) {
1134       if (kIsWedgeCompoundModeAllowed.Contains(block.size)) {
1135         // Only kCompoundPredictionTypeWedge and
1136         // kCompoundPredictionTypeDiffWeighted are signaled explicitly.
1137         prediction_parameters.compound_prediction_type =
1138             static_cast<CompoundPredictionType>(reader_.ReadSymbol(
1139                 symbol_decoder_context_.compound_type_cdf[block.size]));
1140       } else {
1141         prediction_parameters.compound_prediction_type =
1142             kCompoundPredictionTypeDiffWeighted;
1143       }
1144     } else {
1145       if (sequence_header_.enable_jnt_comp) {
1146         bp.is_compound_type_average =
1147             reader_.ReadSymbol(GetIsCompoundTypeAverageCdf(block));
1148         prediction_parameters.compound_prediction_type =
1149             bp.is_compound_type_average ? kCompoundPredictionTypeAverage
1150                                         : kCompoundPredictionTypeDistance;
1151       } else {
1152         prediction_parameters.compound_prediction_type =
1153             kCompoundPredictionTypeAverage;
1154         return;
1155       }
1156     }
1157     if (prediction_parameters.compound_prediction_type ==
1158         kCompoundPredictionTypeWedge) {
1159       prediction_parameters.wedge_index =
1160           reader_.ReadSymbol<kWedgeIndexSymbolCount>(
1161               symbol_decoder_context_.wedge_index_cdf[block.size]);
1162       prediction_parameters.wedge_sign = static_cast<int>(reader_.ReadBit());
1163     } else if (prediction_parameters.compound_prediction_type ==
1164                kCompoundPredictionTypeDiffWeighted) {
1165       prediction_parameters.mask_is_inverse =
1166           static_cast<bool>(reader_.ReadBit());
1167     }
1168     return;
1169   }
1170   if (prediction_parameters.inter_intra_mode != kNumInterIntraModes) {
1171     prediction_parameters.compound_prediction_type =
1172         prediction_parameters.is_wedge_inter_intra
1173             ? kCompoundPredictionTypeWedge
1174             : kCompoundPredictionTypeIntra;
1175     return;
1176   }
1177   prediction_parameters.compound_prediction_type =
1178       kCompoundPredictionTypeAverage;
1179 }
1180 
GetInterpolationFilterCdf(const Block & block,int direction)1181 uint16_t* Tile::GetInterpolationFilterCdf(const Block& block, int direction) {
1182   const BlockParameters& bp = *block.bp;
1183   int context = MultiplyBy8(direction) +
1184                 MultiplyBy4(static_cast<int>(bp.reference_frame[1] >
1185                                              kReferenceFrameIntra));
1186   int top_type = kNumExplicitInterpolationFilters;
1187   if (block.top_available[kPlaneY]) {
1188     if (block.bp_top->reference_frame[0] == bp.reference_frame[0] ||
1189         block.bp_top->reference_frame[1] == bp.reference_frame[0]) {
1190       top_type = block.bp_top->interpolation_filter[direction];
1191     }
1192   }
1193   int left_type = kNumExplicitInterpolationFilters;
1194   if (block.left_available[kPlaneY]) {
1195     if (block.bp_left->reference_frame[0] == bp.reference_frame[0] ||
1196         block.bp_left->reference_frame[1] == bp.reference_frame[0]) {
1197       left_type = block.bp_left->interpolation_filter[direction];
1198     }
1199   }
1200   if (left_type == top_type) {
1201     context += left_type;
1202   } else if (left_type == kNumExplicitInterpolationFilters) {
1203     context += top_type;
1204   } else if (top_type == kNumExplicitInterpolationFilters) {
1205     context += left_type;
1206   } else {
1207     context += kNumExplicitInterpolationFilters;
1208   }
1209   return symbol_decoder_context_.interpolation_filter_cdf[context];
1210 }
1211 
ReadInterpolationFilter(const Block & block)1212 void Tile::ReadInterpolationFilter(const Block& block) {
1213   BlockParameters& bp = *block.bp;
1214   if (frame_header_.interpolation_filter != kInterpolationFilterSwitchable) {
1215     static_assert(
1216         sizeof(bp.interpolation_filter) / sizeof(bp.interpolation_filter[0]) ==
1217             2,
1218         "Interpolation filter array size is not 2");
1219     for (auto& interpolation_filter : bp.interpolation_filter) {
1220       interpolation_filter = frame_header_.interpolation_filter;
1221     }
1222     return;
1223   }
1224   bool interpolation_filter_present = true;
1225   if (bp.skip_mode ||
1226       block.bp->prediction_parameters->motion_mode == kMotionModeLocalWarp) {
1227     interpolation_filter_present = false;
1228   } else if (!IsBlockDimension4(block.size) &&
1229              bp.y_mode == kPredictionModeGlobalMv) {
1230     interpolation_filter_present =
1231         frame_header_.global_motion[bp.reference_frame[0]].type ==
1232         kGlobalMotionTransformationTypeTranslation;
1233   } else if (!IsBlockDimension4(block.size) &&
1234              bp.y_mode == kPredictionModeGlobalGlobalMv) {
1235     interpolation_filter_present =
1236         frame_header_.global_motion[bp.reference_frame[0]].type ==
1237             kGlobalMotionTransformationTypeTranslation ||
1238         frame_header_.global_motion[bp.reference_frame[1]].type ==
1239             kGlobalMotionTransformationTypeTranslation;
1240   }
1241   for (int i = 0; i < (sequence_header_.enable_dual_filter ? 2 : 1); ++i) {
1242     bp.interpolation_filter[i] =
1243         interpolation_filter_present
1244             ? static_cast<InterpolationFilter>(
1245                   reader_.ReadSymbol<kNumExplicitInterpolationFilters>(
1246                       GetInterpolationFilterCdf(block, i)))
1247             : kInterpolationFilterEightTap;
1248   }
1249   if (!sequence_header_.enable_dual_filter) {
1250     bp.interpolation_filter[1] = bp.interpolation_filter[0];
1251   }
1252 }
1253 
ReadInterBlockModeInfo(const Block & block)1254 bool Tile::ReadInterBlockModeInfo(const Block& block) {
1255   BlockParameters& bp = *block.bp;
1256   bp.palette_mode_info.size[kPlaneTypeY] = 0;
1257   bp.palette_mode_info.size[kPlaneTypeUV] = 0;
1258   ReadReferenceFrames(block);
1259   const bool is_compound = bp.reference_frame[1] > kReferenceFrameIntra;
1260   MvContexts mode_contexts;
1261   FindMvStack(block, is_compound, &mode_contexts);
1262   ReadInterPredictionModeY(block, mode_contexts);
1263   ReadRefMvIndex(block);
1264   if (!AssignInterMv(block, is_compound)) return false;
1265   ReadInterIntraMode(block, is_compound);
1266   ReadMotionMode(block, is_compound);
1267   ReadCompoundType(block, is_compound);
1268   ReadInterpolationFilter(block);
1269   return true;
1270 }
1271 
DecodeInterModeInfo(const Block & block)1272 bool Tile::DecodeInterModeInfo(const Block& block) {
1273   BlockParameters& bp = *block.bp;
1274   block.bp->prediction_parameters->use_intra_block_copy = false;
1275   bp.skip = false;
1276   if (!ReadInterSegmentId(block, /*pre_skip=*/true)) return false;
1277   ReadSkipMode(block);
1278   if (bp.skip_mode) {
1279     bp.skip = true;
1280   } else {
1281     ReadSkip(block);
1282   }
1283   if (!frame_header_.segmentation.segment_id_pre_skip &&
1284       !ReadInterSegmentId(block, /*pre_skip=*/false)) {
1285     return false;
1286   }
1287   ReadCdef(block);
1288   if (read_deltas_) {
1289     ReadQuantizerIndexDelta(block);
1290     ReadLoopFilterDelta(block);
1291     read_deltas_ = false;
1292   }
1293   ReadIsInter(block);
1294   return bp.is_inter ? ReadInterBlockModeInfo(block)
1295                      : ReadIntraBlockModeInfo(block, /*intra_y_mode=*/false);
1296 }
1297 
DecodeModeInfo(const Block & block)1298 bool Tile::DecodeModeInfo(const Block& block) {
1299   return IsIntraFrame(frame_header_.frame_type) ? DecodeIntraModeInfo(block)
1300                                                 : DecodeInterModeInfo(block);
1301 }
1302 
1303 }  // namespace libgav1
1304