• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 use std::fmt;
5 use std::io::Write;
6 
7 use crate::bitstream_utils::BitWriter;
8 use crate::bitstream_utils::BitWriterError;
9 
/// Private adapter around a [`std::io::Write`] sink that can insert H.264
/// emulation-prevention bytes into everything it forwards.
struct EmulationPrevention<W: Write> {
    /// When `false`, bytes are passed straight through to `out`.
    ep_enabled: bool,

    /// Up to two received bytes that have not been emitted yet; slot 0 holds
    /// the most recent byte and slot 1 the one before it.
    prev_bytes: [Option<u8>; 2],

    /// Destination sink for all emitted bytes.
    out: W,
}
18 
19 impl<W: Write> EmulationPrevention<W> {
new(writer: W, ep_enabled: bool) -> Self20     fn new(writer: W, ep_enabled: bool) -> Self {
21         Self { out: writer, prev_bytes: [None; 2], ep_enabled }
22     }
23 
write_byte(&mut self, curr_byte: u8) -> std::io::Result<()>24     fn write_byte(&mut self, curr_byte: u8) -> std::io::Result<()> {
25         if self.prev_bytes[1] == Some(0x00) && self.prev_bytes[0] == Some(0x00) && curr_byte <= 0x03
26         {
27             self.out.write_all(&[0x00, 0x00, 0x03, curr_byte])?;
28             self.prev_bytes = [None; 2];
29         } else {
30             if let Some(byte) = self.prev_bytes[1] {
31                 self.out.write_all(&[byte])?;
32             }
33 
34             self.prev_bytes[1] = self.prev_bytes[0];
35             self.prev_bytes[0] = Some(curr_byte);
36         }
37 
38         Ok(())
39     }
40 
41     /// Writes a H.264 NALU header.
write_header(&mut self, idc: u8, type_: u8) -> NaluWriterResult<()>42     fn write_header(&mut self, idc: u8, type_: u8) -> NaluWriterResult<()> {
43         self.out.write_all(&[0x00, 0x00, 0x00, 0x01, (idc & 0b11) << 5 | (type_ & 0b11111)])?;
44 
45         Ok(())
46     }
47 
has_data_pending(&self) -> bool48     fn has_data_pending(&self) -> bool {
49         self.prev_bytes[0].is_some() || self.prev_bytes[1].is_some()
50     }
51 }
52 
53 impl<W: Write> Write for EmulationPrevention<W> {
write(&mut self, buf: &[u8]) -> std::io::Result<usize>54     fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
55         if !self.ep_enabled {
56             self.out.write_all(buf)?;
57             return Ok(buf.len());
58         }
59 
60         for byte in buf {
61             self.write_byte(*byte)?;
62         }
63 
64         Ok(buf.len())
65     }
66 
flush(&mut self) -> std::io::Result<()>67     fn flush(&mut self) -> std::io::Result<()> {
68         if let Some(byte) = self.prev_bytes[1].take() {
69             self.out.write_all(&[byte])?;
70         }
71 
72         if let Some(byte) = self.prev_bytes[0].take() {
73             self.out.write_all(&[byte])?;
74         }
75 
76         self.out.flush()
77     }
78 }
79 
80 impl<W: Write> Drop for EmulationPrevention<W> {
drop(&mut self)81     fn drop(&mut self) {
82         if let Err(e) = self.flush() {
83             log::error!("Unable to flush pending bytes {e:?}");
84         }
85     }
86 }
87 
88 #[derive(Debug)]
89 pub enum NaluWriterError {
90     Overflow,
91     Io(std::io::Error),
92     BitWriterError(BitWriterError),
93 }
94 
95 impl fmt::Display for NaluWriterError {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result96     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97         match self {
98             NaluWriterError::Overflow => write!(f, "value increment caused value overflow"),
99             NaluWriterError::Io(x) => write!(f, "{}", x.to_string()),
100             NaluWriterError::BitWriterError(x) => write!(f, "{}", x.to_string()),
101         }
102     }
103 }
104 
105 impl From<std::io::Error> for NaluWriterError {
from(err: std::io::Error) -> Self106     fn from(err: std::io::Error) -> Self {
107         NaluWriterError::Io(err)
108     }
109 }
110 
111 impl From<BitWriterError> for NaluWriterError {
from(err: BitWriterError) -> Self112     fn from(err: BitWriterError) -> Self {
113         NaluWriterError::BitWriterError(err)
114     }
115 }
116 
117 pub type NaluWriterResult<T> = std::result::Result<T, NaluWriterError>;
118 
119 /// A writer for H.264 bitstream. It is capable of outputing bitstream with
120 /// emulation-prevention.
121 pub struct NaluWriter<W: Write>(BitWriter<EmulationPrevention<W>>);
122 
123 impl<W: Write> NaluWriter<W> {
new(writer: W, ep_enabled: bool) -> Self124     pub fn new(writer: W, ep_enabled: bool) -> Self {
125         Self(BitWriter::new(EmulationPrevention::new(writer, ep_enabled)))
126     }
127 
128     /// Writes fixed bit size integer (up to 32 bit) output with emulation
129     /// prevention if enabled. Corresponds to `f(n)` in H.264 spec.
write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize>130     pub fn write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize> {
131         self.0.write_f(bits, value).map_err(NaluWriterError::BitWriterError)
132     }
133 
134     /// An alias to [`Self::write_f`] Corresponds to `n(n)` in H.264 spec.
write_u<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize>135     pub fn write_u<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize> {
136         self.write_f(bits, value)
137     }
138 
139     /// Writes a number in exponential golumb format.
write_exp_golumb(&mut self, value: u32) -> NaluWriterResult<()>140     pub fn write_exp_golumb(&mut self, value: u32) -> NaluWriterResult<()> {
141         let value = value.checked_add(1).ok_or(NaluWriterError::Overflow)?;
142         let bits = 32 - value.leading_zeros() as usize;
143         let zeros = bits - 1;
144 
145         self.write_f(zeros, 0u32)?;
146         self.write_f(bits, value)?;
147 
148         Ok(())
149     }
150 
151     /// Writes a unsigned integer in exponential golumb format.
152     /// Coresponds to `ue(v)` in H.264 spec.
write_ue<T: Into<u32>>(&mut self, value: T) -> NaluWriterResult<()>153     pub fn write_ue<T: Into<u32>>(&mut self, value: T) -> NaluWriterResult<()> {
154         let value = value.into();
155 
156         self.write_exp_golumb(value)
157     }
158 
159     /// Writes a signed integer in exponential golumb format.
160     /// Coresponds to `se(v)` in H.264 spec.
write_se<T: Into<i32>>(&mut self, value: T) -> NaluWriterResult<()>161     pub fn write_se<T: Into<i32>>(&mut self, value: T) -> NaluWriterResult<()> {
162         let value: i32 = value.into();
163         let abs_value: u32 = value.unsigned_abs();
164 
165         if value <= 0 {
166             self.write_ue(2 * abs_value)
167         } else {
168             self.write_ue(2 * abs_value - 1)
169         }
170     }
171 
172     /// Returns `true` if ['Self`] hold data that wasn't written to [`std::io::Write`]
has_data_pending(&self) -> bool173     pub fn has_data_pending(&self) -> bool {
174         self.0.has_data_pending() || self.0.inner().has_data_pending()
175     }
176 
177     /// Writes a H.264 NALU header.
write_header(&mut self, idc: u8, _type: u8) -> NaluWriterResult<()>178     pub fn write_header(&mut self, idc: u8, _type: u8) -> NaluWriterResult<()> {
179         self.0.flush()?;
180         self.0.inner_mut().write_header(idc, _type)?;
181         Ok(())
182     }
183 
184     /// Returns `true` if next bits will be aligned to 8
aligned(&self) -> bool185     pub fn aligned(&self) -> bool {
186         !self.0.has_data_pending()
187     }
188 }
189 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bitstream_utils::BitReader;

    /// Runs `body` against a fresh writer and returns the bytes produced once
    /// that writer has been dropped (dropping flushes any pending bytes).
    fn collect_output<F>(ep_enabled: bool, body: F) -> Vec<u8>
    where
        F: FnOnce(&mut NaluWriter<&mut Vec<u8>>),
    {
        let mut sink = Vec::<u8>::new();
        {
            let mut writer = NaluWriter::new(&mut sink, ep_enabled);
            body(&mut writer);
        }
        sink
    }

    #[test]
    fn simple_bits() {
        let pattern = [true, false, false, false, true, true, true, true];
        let out = collect_output(false, |w| {
            for bit in pattern.iter() {
                w.write_f(1, *bit).unwrap();
            }
        });
        assert_eq!(out, vec![0b10001111u8]);
    }

    #[test]
    fn simple_first_few_ue() {
        let table: [(u32, u8); 10] = [
            (0, 0b10000000),
            (1, 0b01000000),
            (2, 0b01100000),
            (3, 0b00100000),
            (4, 0b00101000),
            (5, 0b00110000),
            (6, 0b00111000),
            (7, 0b00010000),
            (8, 0b00010010),
            (9, 0b00010100),
        ];

        for (value, encoded) in table.iter() {
            let out = collect_output(false, |w| w.write_ue(*value).unwrap());
            assert_eq!(out, vec![*encoded], "ue({value})");
        }
    }

    #[test]
    fn writer_reader() {
        let first = collect_output(false, |w| {
            w.write_ue(10u32).unwrap();
            w.write_se(-42).unwrap();
            w.write_se(3).unwrap();
            w.write_ue(5u32).unwrap();
        });
        let mut reader = BitReader::new(&first, true);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 10);
        assert_eq!(reader.read_se::<i32>().unwrap(), -42);
        assert_eq!(reader.read_se::<i32>().unwrap(), 3);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 5);

        let second = collect_output(false, |w| {
            w.write_se(30).unwrap();
            w.write_ue(100u32).unwrap();
            w.write_se(-402).unwrap();
            w.write_ue(50u32).unwrap();
        });
        let mut reader = BitReader::new(&second, true);
        assert_eq!(reader.read_se::<i32>().unwrap(), 30);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 100);
        assert_eq!(reader.read_se::<i32>().unwrap(), -402);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 50);
    }

    #[test]
    fn writer_emulation_prevention() {
        let cases: &[(&[u8], &[u8])] = &[
            (&[0x00, 0x00, 0x00], &[0x00, 0x00, 0x03, 0x00]),
            (&[0x00, 0x00, 0x01], &[0x00, 0x00, 0x03, 0x01]),
            (&[0x00, 0x00, 0x02], &[0x00, 0x00, 0x03, 0x02]),
            (&[0x00, 0x00, 0x03], &[0x00, 0x00, 0x03, 0x03]),
            (&[0x00, 0x00, 0x00, 0x00], &[0x00, 0x00, 0x03, 0x00, 0x00]),
            (&[0x00, 0x00, 0x00, 0x01], &[0x00, 0x00, 0x03, 0x00, 0x01]),
            (&[0x00, 0x00, 0x00, 0x02], &[0x00, 0x00, 0x03, 0x00, 0x02]),
            (&[0x00, 0x00, 0x00, 0x03], &[0x00, 0x00, 0x03, 0x00, 0x03]),
        ];

        for (input, bitstream) in cases {
            let out = collect_output(true, |w| {
                for byte in input.iter() {
                    w.write_f(8, *byte).unwrap();
                }
            });
            assert_eq!(out, *bitstream);

            // Reading back with emulation prevention must recover the input.
            let mut reader = BitReader::new(&out, true);
            for byte in input.iter() {
                assert_eq!(*byte, reader.read_bits::<u8>(8).unwrap());
            }
        }
    }
}
300