• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright © 2024 Collabora, Ltd.
2 // SPDX-License-Identifier: MIT
3 
4 use std::io;
5 use std::marker::PhantomPinned;
6 use std::pin::Pin;
7 
8 use crate::bindings;
9 
10 struct MemStreamImpl {
11     stream: bindings::u_memstream,
12     buffer: *mut u8,
13     buffer_size: usize,
14     _pin: PhantomPinned,
15 }
16 
/// A Rust memstream abstraction. Useful when interacting with C code that
/// expects a `FILE*` pointer.
///
/// The size of the buffer is managed by the C code automatically.
///
/// The state lives in a pinned box because `u_memstream_open` is given
/// pointers into it, and the C side updates them later (on flush/close).
/// Dropping a `MemStream` closes the underlying `FILE*` and frees the buffer.
pub struct MemStream(Pin<Box<MemStreamImpl>>);
22 
23 impl MemStream {
new() -> io::Result<Self>24     pub fn new() -> io::Result<Self> {
25         let mut stream_impl = Box::pin(MemStreamImpl {
26             stream: unsafe { std::mem::zeroed() },
27             buffer: std::ptr::null_mut(),
28             buffer_size: 0,
29             _pin: PhantomPinned,
30         });
31 
32         unsafe {
33             let stream_impl = stream_impl.as_mut().get_unchecked_mut();
34             if !bindings::u_memstream_open(
35                 &mut stream_impl.stream,
36                 (&mut stream_impl.buffer as *mut *mut u8).cast(),
37                 &mut stream_impl.buffer_size,
38             ) {
39                 return Err(io::Error::last_os_error());
40             }
41             if bindings::u_memstream_flush(&mut stream_impl.stream) != 0 {
42                 return Err(io::Error::last_os_error());
43             }
44         }
45 
46         Ok(Self(stream_impl))
47     }
48 
49     // Safety: caller must ensure that inner is not moved through the returned
50     // reference.
inner_mut(&mut self) -> &mut MemStreamImpl51     unsafe fn inner_mut(&mut self) -> &mut MemStreamImpl {
52         unsafe { self.0.as_mut().get_unchecked_mut() }
53     }
54 
55     /// Flushes the stream so written data appears in the stream
flush(&mut self) -> io::Result<()>56     pub fn flush(&mut self) -> io::Result<()> {
57         unsafe {
58             let stream = self.inner_mut();
59             if bindings::u_memstream_flush(&mut stream.stream) != 0 {
60                 return Err(io::Error::last_os_error());
61             }
62         }
63 
64         Ok(())
65     }
66 
67     /// Resets the MemStream
reset(&mut self) -> io::Result<()>68     pub fn reset(&mut self) -> io::Result<()> {
69         *self = Self::new()?;
70         Ok(())
71     }
72 
73     /// Resets the MemStream and returns its contents
take(&mut self) -> io::Result<Vec<u8>>74     pub fn take(&mut self) -> io::Result<Vec<u8>> {
75         let mut vec = Vec::new();
76         vec.extend_from_slice(self.as_slice()?);
77         self.reset()?;
78         Ok(vec)
79     }
80 
81     /// Resets the MemStream and returns its contents as a UTF-8 string
take_utf8_string_lossy(&mut self) -> io::Result<String>82     pub fn take_utf8_string_lossy(&mut self) -> io::Result<String> {
83         let string = String::from_utf8_lossy(self.as_slice()?).into_owned();
84         self.reset()?;
85         Ok(string)
86     }
87 
88     /// Returns the current position in the stream.
position(&self) -> usize89     pub fn position(&self) -> usize {
90         unsafe { bindings::compiler_rs_ftell(self.c_file()) as usize }
91     }
92 
93     /// Seek to a position relative to the start of the stream.
seek(&mut self, offset: u64) -> io::Result<()>94     pub fn seek(&mut self, offset: u64) -> io::Result<()> {
95         let offset = offset.try_into().map_err(|_| {
96             io::Error::new(io::ErrorKind::InvalidInput, "offset too large")
97         })?;
98 
99         unsafe {
100             if bindings::compiler_rs_fseek(self.c_file(), offset, 0) != 0 {
101                 Err(io::Error::last_os_error())
102             } else {
103                 Ok(())
104             }
105         }
106     }
107 
108     /// Returns the underlying C file.
109     ///
110     /// # Safety
111     ///
112     /// The memstream abstraction assumes that the file is valid throughout its
113     /// lifetime.
c_file(&self) -> *mut bindings::FILE114     pub unsafe fn c_file(&self) -> *mut bindings::FILE {
115         self.0.stream.f
116     }
117 
118     /// Returns a slice view into the memstream
119     ///
120     /// This is only safe with respect to other safe Rust methods.  Even though
121     /// this takes a reference to the stream there is nothing preventing you
122     /// from modifying the stream through the FILE with unsafe C code.
123     ///
124     /// This is conceptually the same as `AsRef`, but it flushes the stream
125     /// first, which means it takes &mut self as a receiver.
as_slice(&mut self) -> io::Result<&[u8]>126     fn as_slice(&mut self) -> io::Result<&[u8]> {
127         // Make sure we have the most up-to-date data before returning a slice.
128         self.flush()?;
129         let pos = self.position();
130 
131         if pos == 0 {
132             Ok(&[])
133         } else {
134             // SAFETY: this does not move the stream and we know that
135             // self.position() cannot exceed the stream size as per the
136             // open_memstream() API.
137             Ok(unsafe { std::slice::from_raw_parts(self.0.buffer, pos) })
138         }
139     }
140 }
141 
142 impl Drop for MemStream {
drop(&mut self)143     fn drop(&mut self) {
144         // SAFETY: this does not move the stream.
145         unsafe {
146             bindings::u_memstream_close(&mut self.inner_mut().stream);
147             bindings::compiler_rs_free(self.0.buffer as *mut std::ffi::c_void);
148         }
149     }
150 }
151 
#[test]
fn test_memstream() {
    let mut s = MemStream::new().unwrap();
    let test_str = "Test string";
    let test_bytes = test_str.as_bytes();

    // SAFETY: `test_bytes` is valid for `test_bytes.len()` bytes, and the FILE
    // returned by `c_file()` stays valid while `s` is alive.
    let written = unsafe {
        bindings::compiler_rs_fwrite(
            test_bytes.as_ptr().cast(),
            1,
            test_bytes.len(),
            s.c_file(),
        )
    };
    // Without this check, a short write would only show up as a confusing
    // string mismatch below.
    assert_eq!(written as usize, test_bytes.len());

    let res = s.take_utf8_string_lossy().unwrap();
    assert_eq!(res, test_str);

    // `take_*` resets the stream, so a second take yields nothing.
    assert!(s.take().unwrap().is_empty());
}
172