• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::io;
6 use std::os::fd::AsRawFd;
7 use std::os::fd::OwnedFd;
8 use std::os::fd::RawFd;
9 use std::sync::Arc;
10 
11 use base::add_fd_flags;
12 use base::clone_descriptor;
13 use base::linux::fallocate;
14 use base::linux::FallocateMode;
15 use base::AsRawDescriptor;
16 use base::VolatileSlice;
17 use tokio::io::unix::AsyncFd;
18 
19 use crate::mem::MemRegion;
20 use crate::AsyncError;
21 use crate::AsyncResult;
22 use crate::BackingMemory;
23 
24 #[derive(Debug, thiserror::Error)]
25 pub enum Error {
26     #[error("Failed to copy the FD for the polling context: '{0}'")]
27     DuplicatingFd(base::Error),
28     #[error("Failed to punch hole in file: '{0}'.")]
29     Fallocate(base::Error),
30     #[error("Failed to fdatasync: '{0}'")]
31     Fdatasync(io::Error),
32     #[error("Failed to fsync: '{0}'")]
33     Fsync(io::Error),
34     #[error("Failed to join task: '{0}'")]
35     Join(tokio::task::JoinError),
36     #[error("Cannot wait on file descriptor")]
37     NonWaitable,
38     #[error("Failed to read: '{0}'")]
39     Read(io::Error),
40     #[error("Failed to set nonblocking: '{0}'")]
41     SettingNonBlocking(base::Error),
42     #[error("Tokio Async FD error: '{0}'")]
43     TokioAsyncFd(io::Error),
44     #[error("Failed to write: '{0}'")]
45     Write(io::Error),
46 }
47 
48 impl From<Error> for io::Error {
from(e: Error) -> Self49     fn from(e: Error) -> Self {
50         use Error::*;
51         match e {
52             DuplicatingFd(e) => e.into(),
53             Fallocate(e) => e.into(),
54             Fdatasync(e) => e,
55             Fsync(e) => e,
56             Join(e) => io::Error::new(io::ErrorKind::Other, e),
57             NonWaitable => io::Error::new(io::ErrorKind::Other, e),
58             Read(e) => e,
59             SettingNonBlocking(e) => e.into(),
60             TokioAsyncFd(e) => e,
61             Write(e) => e,
62         }
63     }
64 }
65 
66 enum FdType {
67     Async(AsyncFd<Arc<OwnedFd>>),
68     Blocking(Arc<OwnedFd>),
69 }
70 
71 impl AsRawFd for FdType {
as_raw_fd(&self) -> RawFd72     fn as_raw_fd(&self) -> RawFd {
73         match self {
74             FdType::Async(async_fd) => async_fd.as_raw_fd(),
75             FdType::Blocking(blocking) => blocking.as_raw_fd(),
76         }
77     }
78 }
79 
80 impl From<Error> for AsyncError {
from(e: Error) -> AsyncError81     fn from(e: Error) -> AsyncError {
82         AsyncError::SysVariants(e.into())
83     }
84 }
85 
do_fdatasync(raw: Arc<OwnedFd>) -> io::Result<()>86 fn do_fdatasync(raw: Arc<OwnedFd>) -> io::Result<()> {
87     let fd = raw.as_raw_fd();
88     // SAFETY: we partially own `raw`
89     match unsafe { libc::fdatasync(fd) } {
90         0 => Ok(()),
91         _ => Err(io::Error::last_os_error()),
92     }
93 }
94 
do_fsync(raw: Arc<OwnedFd>) -> io::Result<()>95 fn do_fsync(raw: Arc<OwnedFd>) -> io::Result<()> {
96     let fd = raw.as_raw_fd();
97     // SAFETY: we partially own `raw`
98     match unsafe { libc::fsync(fd) } {
99         0 => Ok(()),
100         _ => Err(io::Error::last_os_error()),
101     }
102 }
103 
do_read_to_mem( raw: Arc<OwnedFd>, file_offset: Option<u64>, io_vecs: &Vec<VolatileSlice>, ) -> io::Result<usize>104 fn do_read_to_mem(
105     raw: Arc<OwnedFd>,
106     file_offset: Option<u64>,
107     io_vecs: &Vec<VolatileSlice>,
108 ) -> io::Result<usize> {
109     let ptr = io_vecs.as_ptr() as *const libc::iovec;
110     let len = io_vecs.len() as i32;
111     let fd = raw.as_raw_fd();
112     let res = match file_offset {
113         // SAFETY: we partially own `raw`, `io_vecs` is validated
114         Some(off) => unsafe { libc::preadv64(fd, ptr, len, off as libc::off64_t) },
115         // SAFETY: we partially own `raw`, `io_vecs` is validated
116         None => unsafe { libc::readv(fd, ptr, len) },
117     };
118     match res {
119         r if r >= 0 => Ok(res as usize),
120         _ => Err(io::Error::last_os_error()),
121     }
122 }
do_read_to_vec( raw: Arc<OwnedFd>, file_offset: Option<u64>, vec: &mut Vec<u8>, ) -> io::Result<usize>123 fn do_read_to_vec(
124     raw: Arc<OwnedFd>,
125     file_offset: Option<u64>,
126     vec: &mut Vec<u8>,
127 ) -> io::Result<usize> {
128     let fd = raw.as_raw_fd();
129     let ptr = vec.as_mut_ptr() as *mut libc::c_void;
130     let res = match file_offset {
131         // SAFETY: we partially own `raw`, `ptr` has space up to vec.len()
132         Some(off) => unsafe { libc::pread64(fd, ptr, vec.len(), off as libc::off64_t) },
133         // SAFETY: we partially own `raw`, `ptr` has space up to vec.len()
134         None => unsafe { libc::read(fd, ptr, vec.len()) },
135     };
136     match res {
137         r if r >= 0 => Ok(res as usize),
138         _ => Err(io::Error::last_os_error()),
139     }
140 }
141 
do_write_from_vec( raw: Arc<OwnedFd>, file_offset: Option<u64>, vec: &Vec<u8>, ) -> io::Result<usize>142 fn do_write_from_vec(
143     raw: Arc<OwnedFd>,
144     file_offset: Option<u64>,
145     vec: &Vec<u8>,
146 ) -> io::Result<usize> {
147     let fd = raw.as_raw_fd();
148     let ptr = vec.as_ptr() as *const libc::c_void;
149     let res = match file_offset {
150         // SAFETY: we partially own `raw`, `ptr` has data up to vec.len()
151         Some(off) => unsafe { libc::pwrite64(fd, ptr, vec.len(), off as libc::off64_t) },
152         // SAFETY: we partially own `raw`, `ptr` has data up to vec.len()
153         None => unsafe { libc::write(fd, ptr, vec.len()) },
154     };
155     match res {
156         r if r >= 0 => Ok(res as usize),
157         _ => Err(io::Error::last_os_error()),
158     }
159 }
160 
do_write_from_mem( raw: Arc<OwnedFd>, file_offset: Option<u64>, io_vecs: &Vec<VolatileSlice>, ) -> io::Result<usize>161 fn do_write_from_mem(
162     raw: Arc<OwnedFd>,
163     file_offset: Option<u64>,
164     io_vecs: &Vec<VolatileSlice>,
165 ) -> io::Result<usize> {
166     let ptr = io_vecs.as_ptr() as *const libc::iovec;
167     let len = io_vecs.len() as i32;
168     let fd = raw.as_raw_fd();
169     let res = match file_offset {
170         // SAFETY: we partially own `raw`, `io_vecs` is validated
171         Some(off) => unsafe { libc::pwritev64(fd, ptr, len, off as libc::off64_t) },
172         // SAFETY: we partially own `raw`, `io_vecs` is validated
173         None => unsafe { libc::writev(fd, ptr, len) },
174     };
175     match res {
176         r if r >= 0 => Ok(res as usize),
177         _ => Err(io::Error::last_os_error()),
178     }
179 }
180 
181 pub struct TokioSource<T> {
182     fd: FdType,
183     inner: T,
184     runtime: tokio::runtime::Handle,
185 }
186 impl<T: AsRawDescriptor> TokioSource<T> {
new(inner: T, runtime: tokio::runtime::Handle) -> Result<TokioSource<T>, Error>187     pub fn new(inner: T, runtime: tokio::runtime::Handle) -> Result<TokioSource<T>, Error> {
188         let _guard = runtime.enter(); // Required for AsyncFd
189         let safe_fd = clone_descriptor(&inner).map_err(Error::DuplicatingFd)?;
190         let fd_arc: Arc<OwnedFd> = Arc::new(safe_fd.into());
191         let fd = match AsyncFd::new(fd_arc.clone()) {
192             Ok(async_fd) => {
193                 add_fd_flags(async_fd.get_ref().as_raw_descriptor(), libc::O_NONBLOCK)
194                     .map_err(Error::SettingNonBlocking)?;
195                 FdType::Async(async_fd)
196             }
197             Err(e) if e.kind() == io::ErrorKind::PermissionDenied => FdType::Blocking(fd_arc),
198             Err(e) => return Err(Error::TokioAsyncFd(e)),
199         };
200         Ok(TokioSource { fd, inner, runtime })
201     }
202 
as_source(&self) -> &T203     pub fn as_source(&self) -> &T {
204         &self.inner
205     }
206 
as_source_mut(&mut self) -> &mut T207     pub fn as_source_mut(&mut self) -> &mut T {
208         &mut self.inner
209     }
210 
clone_fd(&self) -> Arc<OwnedFd>211     fn clone_fd(&self) -> Arc<OwnedFd> {
212         match &self.fd {
213             FdType::Async(async_fd) => async_fd.get_ref().clone(),
214             FdType::Blocking(blocking) => blocking.clone(),
215         }
216     }
217 
fdatasync(&self) -> AsyncResult<()>218     pub async fn fdatasync(&self) -> AsyncResult<()> {
219         let fd = self.clone_fd();
220         Ok(self
221             .runtime
222             .spawn_blocking(move || do_fdatasync(fd))
223             .await
224             .map_err(Error::Join)?
225             .map_err(Error::Fdatasync)?)
226     }
227 
fsync(&self) -> AsyncResult<()>228     pub async fn fsync(&self) -> AsyncResult<()> {
229         let fd = self.clone_fd();
230         Ok(self
231             .runtime
232             .spawn_blocking(move || do_fsync(fd))
233             .await
234             .map_err(Error::Join)?
235             .map_err(Error::Fsync)?)
236     }
237 
into_source(self) -> T238     pub fn into_source(self) -> T {
239         self.inner
240     }
241 
read_to_vec( &self, file_offset: Option<u64>, mut vec: Vec<u8>, ) -> AsyncResult<(usize, Vec<u8>)>242     pub async fn read_to_vec(
243         &self,
244         file_offset: Option<u64>,
245         mut vec: Vec<u8>,
246     ) -> AsyncResult<(usize, Vec<u8>)> {
247         Ok(match &self.fd {
248             FdType::Async(async_fd) => {
249                 let res = async_fd
250                     .async_io(tokio::io::Interest::READABLE, |fd| {
251                         do_read_to_vec(fd.clone(), file_offset, &mut vec)
252                     })
253                     .await
254                     .map_err(AsyncError::Io)?;
255                 (res, vec)
256             }
257             FdType::Blocking(blocking) => {
258                 let fd = blocking.clone();
259                 self.runtime
260                     .spawn_blocking(move || {
261                         let size = do_read_to_vec(fd, file_offset, &mut vec)?;
262                         Ok((size, vec))
263                     })
264                     .await
265                     .map_err(Error::Join)?
266                     .map_err(Error::Read)?
267             }
268         })
269     }
270 
read_to_mem( &self, file_offset: Option<u64>, mem: Arc<dyn BackingMemory + Send + Sync>, mem_offsets: impl IntoIterator<Item = MemRegion>, ) -> AsyncResult<usize>271     pub async fn read_to_mem(
272         &self,
273         file_offset: Option<u64>,
274         mem: Arc<dyn BackingMemory + Send + Sync>,
275         mem_offsets: impl IntoIterator<Item = MemRegion>,
276     ) -> AsyncResult<usize> {
277         let mem_offsets_vec: Vec<MemRegion> = mem_offsets.into_iter().collect();
278         Ok(match &self.fd {
279             FdType::Async(async_fd) => {
280                 let iovecs = mem_offsets_vec
281                     .into_iter()
282                     .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
283                     .collect::<Vec<VolatileSlice>>();
284                 async_fd
285                     .async_io(tokio::io::Interest::READABLE, |fd| {
286                         do_read_to_mem(fd.clone(), file_offset, &iovecs)
287                     })
288                     .await
289                     .map_err(AsyncError::Io)?
290             }
291             FdType::Blocking(blocking) => {
292                 let fd = blocking.clone();
293                 self.runtime
294                     .spawn_blocking(move || {
295                         let iovecs = mem_offsets_vec
296                             .into_iter()
297                             .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
298                             .collect::<Vec<VolatileSlice>>();
299                         do_read_to_mem(fd, file_offset, &iovecs)
300                     })
301                     .await
302                     .map_err(Error::Join)?
303                     .map_err(Error::Read)?
304             }
305         })
306     }
307 
punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()>308     pub async fn punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
309         let fd = self.clone_fd();
310         Ok(self
311             .runtime
312             .spawn_blocking(move || fallocate(&*fd, FallocateMode::PunchHole, file_offset, len))
313             .await
314             .map_err(Error::Join)?
315             .map_err(Error::Fallocate)?)
316     }
317 
wait_readable(&self) -> AsyncResult<()>318     pub async fn wait_readable(&self) -> AsyncResult<()> {
319         match &self.fd {
320             FdType::Async(async_fd) => async_fd
321                 .readable()
322                 .await
323                 .map_err(crate::AsyncError::Io)?
324                 .retain_ready(),
325             FdType::Blocking(_) => return Err(Error::NonWaitable.into()),
326         }
327         Ok(())
328     }
329 
write_from_mem( &self, file_offset: Option<u64>, mem: Arc<dyn BackingMemory + Send + Sync>, mem_offsets: impl IntoIterator<Item = MemRegion>, ) -> AsyncResult<usize>330     pub async fn write_from_mem(
331         &self,
332         file_offset: Option<u64>,
333         mem: Arc<dyn BackingMemory + Send + Sync>,
334         mem_offsets: impl IntoIterator<Item = MemRegion>,
335     ) -> AsyncResult<usize> {
336         let mem_offsets_vec: Vec<MemRegion> = mem_offsets.into_iter().collect();
337         Ok(match &self.fd {
338             FdType::Async(async_fd) => {
339                 let iovecs = mem_offsets_vec
340                     .into_iter()
341                     .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
342                     .collect::<Vec<VolatileSlice>>();
343                 async_fd
344                     .async_io(tokio::io::Interest::WRITABLE, |fd| {
345                         do_write_from_mem(fd.clone(), file_offset, &iovecs)
346                     })
347                     .await
348                     .map_err(AsyncError::Io)?
349             }
350             FdType::Blocking(blocking) => {
351                 let fd = blocking.clone();
352                 self.runtime
353                     .spawn_blocking(move || {
354                         let iovecs = mem_offsets_vec
355                             .into_iter()
356                             .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
357                             .collect::<Vec<VolatileSlice>>();
358                         do_write_from_mem(fd, file_offset, &iovecs.clone())
359                     })
360                     .await
361                     .map_err(Error::Join)?
362                     .map_err(Error::Read)?
363             }
364         })
365     }
366 
write_from_vec( &self, file_offset: Option<u64>, vec: Vec<u8>, ) -> AsyncResult<(usize, Vec<u8>)>367     pub async fn write_from_vec(
368         &self,
369         file_offset: Option<u64>,
370         vec: Vec<u8>,
371     ) -> AsyncResult<(usize, Vec<u8>)> {
372         Ok(match &self.fd {
373             FdType::Async(async_fd) => {
374                 let res = async_fd
375                     .async_io(tokio::io::Interest::WRITABLE, |fd| {
376                         do_write_from_vec(fd.clone(), file_offset, &vec)
377                     })
378                     .await
379                     .map_err(AsyncError::Io)?;
380                 (res, vec)
381             }
382             FdType::Blocking(blocking) => {
383                 let fd = blocking.clone();
384                 self.runtime
385                     .spawn_blocking(move || {
386                         let size = do_write_from_vec(fd.clone(), file_offset, &vec)?;
387                         Ok((size, vec))
388                     })
389                     .await
390                     .map_err(Error::Join)?
391                     .map_err(Error::Read)?
392             }
393         })
394     }
395 
write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()>396     pub async fn write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
397         let fd = self.clone_fd();
398         Ok(self
399             .runtime
400             .spawn_blocking(move || fallocate(&*fd, FallocateMode::ZeroRange, file_offset, len))
401             .await
402             .map_err(Error::Join)?
403             .map_err(Error::Fallocate)?)
404     }
405 }
406