1 // Copyright 2024 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 use std::fmt; 5 use std::io::Write; 6 7 use crate::bitstream_utils::BitWriter; 8 use crate::bitstream_utils::BitWriterError; 9 10 /// Internal wrapper over [`std::io::Write`] for possible emulation prevention 11 struct EmulationPrevention<W: Write> { 12 out: W, 13 prev_bytes: [Option<u8>; 2], 14 15 /// Emulation prevention enabled. 16 ep_enabled: bool, 17 } 18 19 impl<W: Write> EmulationPrevention<W> { new(writer: W, ep_enabled: bool) -> Self20 fn new(writer: W, ep_enabled: bool) -> Self { 21 Self { out: writer, prev_bytes: [None; 2], ep_enabled } 22 } 23 write_byte(&mut self, curr_byte: u8) -> std::io::Result<()>24 fn write_byte(&mut self, curr_byte: u8) -> std::io::Result<()> { 25 if self.prev_bytes[1] == Some(0x00) && self.prev_bytes[0] == Some(0x00) && curr_byte <= 0x03 26 { 27 self.out.write_all(&[0x00, 0x00, 0x03, curr_byte])?; 28 self.prev_bytes = [None; 2]; 29 } else { 30 if let Some(byte) = self.prev_bytes[1] { 31 self.out.write_all(&[byte])?; 32 } 33 34 self.prev_bytes[1] = self.prev_bytes[0]; 35 self.prev_bytes[0] = Some(curr_byte); 36 } 37 38 Ok(()) 39 } 40 41 /// Writes a H.264 NALU header. write_header(&mut self, idc: u8, type_: u8) -> NaluWriterResult<()>42 fn write_header(&mut self, idc: u8, type_: u8) -> NaluWriterResult<()> { 43 self.out.write_all(&[0x00, 0x00, 0x00, 0x01, (idc & 0b11) << 5 | (type_ & 0b11111)])?; 44 45 Ok(()) 46 } 47 has_data_pending(&self) -> bool48 fn has_data_pending(&self) -> bool { 49 self.prev_bytes[0].is_some() || self.prev_bytes[1].is_some() 50 } 51 } 52 53 impl<W: Write> Write for EmulationPrevention<W> { write(&mut self, buf: &[u8]) -> std::io::Result<usize>54 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { 55 if !self.ep_enabled { 56 self.out.write_all(buf)?; 57 return Ok(buf.len()); 58 } 59 60 for byte in buf { 61 self.write_byte(*byte)?; 62 } 63 64 Ok(buf.len()) 65 } 66 flush(&mut self) -> std::io::Result<()>67 fn flush(&mut self) -> std::io::Result<()> { 68 if let Some(byte) = self.prev_bytes[1].take() { 69 self.out.write_all(&[byte])?; 70 } 71 72 if let Some(byte) = self.prev_bytes[0].take() { 73 self.out.write_all(&[byte])?; 74 } 75 76 self.out.flush() 77 } 78 } 79 80 impl<W: Write> Drop for EmulationPrevention<W> { drop(&mut self)81 fn drop(&mut self) { 82 if let Err(e) = self.flush() { 83 log::error!("Unable to flush pending bytes {e:?}"); 84 } 85 } 86 } 87 88 #[derive(Debug)] 89 pub enum NaluWriterError { 90 Overflow, 91 Io(std::io::Error), 92 BitWriterError(BitWriterError), 93 } 94 95 impl fmt::Display for NaluWriterError { fmt(&self, f: &mut fmt::Formatter) -> fmt::Result96 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 97 match self { 98 NaluWriterError::Overflow => write!(f, "value increment caused value overflow"), 99 NaluWriterError::Io(x) => write!(f, "{}", x.to_string()), 100 NaluWriterError::BitWriterError(x) => write!(f, "{}", x.to_string()), 101 } 102 } 103 } 104 105 impl From<std::io::Error> for NaluWriterError { from(err: std::io::Error) -> Self106 fn from(err: std::io::Error) -> Self { 107 NaluWriterError::Io(err) 108 } 109 } 110 111 impl From<BitWriterError> for NaluWriterError { from(err: BitWriterError) -> Self112 fn from(err: BitWriterError) -> Self { 113 NaluWriterError::BitWriterError(err) 114 } 115 } 116 117 pub type NaluWriterResult<T> = std::result::Result<T, NaluWriterError>; 118 119 /// A writer for H.264 bitstream. It is capable of outputing bitstream with 120 /// emulation-prevention. 121 pub struct NaluWriter<W: Write>(BitWriter<EmulationPrevention<W>>); 122 123 impl<W: Write> NaluWriter<W> { new(writer: W, ep_enabled: bool) -> Self124 pub fn new(writer: W, ep_enabled: bool) -> Self { 125 Self(BitWriter::new(EmulationPrevention::new(writer, ep_enabled))) 126 } 127 128 /// Writes fixed bit size integer (up to 32 bit) output with emulation 129 /// prevention if enabled. Corresponds to `f(n)` in H.264 spec. write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize>130 pub fn write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize> { 131 self.0.write_f(bits, value).map_err(NaluWriterError::BitWriterError) 132 } 133 134 /// An alias to [`Self::write_f`] Corresponds to `n(n)` in H.264 spec. write_u<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize>135 pub fn write_u<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize> { 136 self.write_f(bits, value) 137 } 138 139 /// Writes a number in exponential golumb format. write_exp_golumb(&mut self, value: u32) -> NaluWriterResult<()>140 pub fn write_exp_golumb(&mut self, value: u32) -> NaluWriterResult<()> { 141 let value = value.checked_add(1).ok_or(NaluWriterError::Overflow)?; 142 let bits = 32 - value.leading_zeros() as usize; 143 let zeros = bits - 1; 144 145 self.write_f(zeros, 0u32)?; 146 self.write_f(bits, value)?; 147 148 Ok(()) 149 } 150 151 /// Writes a unsigned integer in exponential golumb format. 152 /// Coresponds to `ue(v)` in H.264 spec. write_ue<T: Into<u32>>(&mut self, value: T) -> NaluWriterResult<()>153 pub fn write_ue<T: Into<u32>>(&mut self, value: T) -> NaluWriterResult<()> { 154 let value = value.into(); 155 156 self.write_exp_golumb(value) 157 } 158 159 /// Writes a signed integer in exponential golumb format. 160 /// Coresponds to `se(v)` in H.264 spec. write_se<T: Into<i32>>(&mut self, value: T) -> NaluWriterResult<()>161 pub fn write_se<T: Into<i32>>(&mut self, value: T) -> NaluWriterResult<()> { 162 let value: i32 = value.into(); 163 let abs_value: u32 = value.unsigned_abs(); 164 165 if value <= 0 { 166 self.write_ue(2 * abs_value) 167 } else { 168 self.write_ue(2 * abs_value - 1) 169 } 170 } 171 172 /// Returns `true` if ['Self`] hold data that wasn't written to [`std::io::Write`] has_data_pending(&self) -> bool173 pub fn has_data_pending(&self) -> bool { 174 self.0.has_data_pending() || self.0.inner().has_data_pending() 175 } 176 177 /// Writes a H.264 NALU header. write_header(&mut self, idc: u8, _type: u8) -> NaluWriterResult<()>178 pub fn write_header(&mut self, idc: u8, _type: u8) -> NaluWriterResult<()> { 179 self.0.flush()?; 180 self.0.inner_mut().write_header(idc, _type)?; 181 Ok(()) 182 } 183 184 /// Returns `true` if next bits will be aligned to 8 aligned(&self) -> bool185 pub fn aligned(&self) -> bool { 186 !self.0.has_data_pending() 187 } 188 } 189 190 #[cfg(test)] 191 mod tests { 192 use super::*; 193 use crate::bitstream_utils::BitReader; 194 195 #[test] simple_bits()196 fn simple_bits() { 197 let mut buf = Vec::<u8>::new(); 198 { 199 let mut writer = NaluWriter::new(&mut buf, false); 200 writer.write_f(1, true).unwrap(); 201 writer.write_f(1, false).unwrap(); 202 writer.write_f(1, false).unwrap(); 203 writer.write_f(1, false).unwrap(); 204 writer.write_f(1, true).unwrap(); 205 writer.write_f(1, true).unwrap(); 206 writer.write_f(1, true).unwrap(); 207 writer.write_f(1, true).unwrap(); 208 } 209 assert_eq!(buf, vec![0b10001111u8]); 210 } 211 212 #[test] simple_first_few_ue()213 fn simple_first_few_ue() { 214 fn single_ue(value: u32) -> Vec<u8> { 215 let mut buf = Vec::<u8>::new(); 216 { 217 let mut writer = NaluWriter::new(&mut buf, false); 218 writer.write_ue(value).unwrap(); 219 } 220 buf 221 } 222 223 assert_eq!(single_ue(0), vec![0b10000000u8]); 224 assert_eq!(single_ue(1), vec![0b01000000u8]); 225 assert_eq!(single_ue(2), vec![0b01100000u8]); 226 assert_eq!(single_ue(3), vec![0b00100000u8]); 227 assert_eq!(single_ue(4), vec![0b00101000u8]); 228 assert_eq!(single_ue(5), vec![0b00110000u8]); 229 assert_eq!(single_ue(6), vec![0b00111000u8]); 230 assert_eq!(single_ue(7), vec![0b00010000u8]); 231 assert_eq!(single_ue(8), vec![0b00010010u8]); 232 assert_eq!(single_ue(9), vec![0b00010100u8]); 233 } 234 235 #[test] writer_reader()236 fn writer_reader() { 237 let mut buf = Vec::<u8>::new(); 238 { 239 let mut writer = NaluWriter::new(&mut buf, false); 240 writer.write_ue(10u32).unwrap(); 241 writer.write_se(-42).unwrap(); 242 writer.write_se(3).unwrap(); 243 writer.write_ue(5u32).unwrap(); 244 } 245 246 let mut reader = BitReader::new(&buf, true); 247 248 assert_eq!(reader.read_ue::<u32>().unwrap(), 10); 249 assert_eq!(reader.read_se::<i32>().unwrap(), -42); 250 assert_eq!(reader.read_se::<i32>().unwrap(), 3); 251 assert_eq!(reader.read_ue::<u32>().unwrap(), 5); 252 253 let mut buf = Vec::<u8>::new(); 254 { 255 let mut writer = NaluWriter::new(&mut buf, false); 256 writer.write_se(30).unwrap(); 257 writer.write_ue(100u32).unwrap(); 258 writer.write_se(-402).unwrap(); 259 writer.write_ue(50u32).unwrap(); 260 } 261 262 let mut reader = BitReader::new(&buf, true); 263 264 assert_eq!(reader.read_se::<i32>().unwrap(), 30); 265 assert_eq!(reader.read_ue::<u32>().unwrap(), 100); 266 assert_eq!(reader.read_se::<i32>().unwrap(), -402); 267 assert_eq!(reader.read_ue::<u32>().unwrap(), 50); 268 } 269 270 #[test] writer_emulation_prevention()271 fn writer_emulation_prevention() { 272 fn test(input: &[u8], bitstream: &[u8]) { 273 let mut buf = Vec::<u8>::new(); 274 { 275 let mut writer = NaluWriter::new(&mut buf, true); 276 for byte in input { 277 writer.write_f(8, *byte).unwrap(); 278 } 279 } 280 assert_eq!(buf, bitstream); 281 { 282 let mut reader = BitReader::new(&buf, true); 283 for byte in input { 284 assert_eq!(*byte, reader.read_bits::<u8>(8).unwrap()); 285 } 286 } 287 } 288 289 test(&[0x00, 0x00, 0x00], &[0x00, 0x00, 0x03, 0x00]); 290 test(&[0x00, 0x00, 0x01], &[0x00, 0x00, 0x03, 0x01]); 291 test(&[0x00, 0x00, 0x02], &[0x00, 0x00, 0x03, 0x02]); 292 test(&[0x00, 0x00, 0x03], &[0x00, 0x00, 0x03, 0x03]); 293 294 test(&[0x00, 0x00, 0x00, 0x00], &[0x00, 0x00, 0x03, 0x00, 0x00]); 295 test(&[0x00, 0x00, 0x00, 0x01], &[0x00, 0x00, 0x03, 0x00, 0x01]); 296 test(&[0x00, 0x00, 0x00, 0x02], &[0x00, 0x00, 0x03, 0x00, 0x02]); 297 test(&[0x00, 0x00, 0x00, 0x03], &[0x00, 0x00, 0x03, 0x00, 0x03]); 298 } 299 } 300