• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2015 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <brillo/streams/stream_utils.h>
6 
7 #include <algorithm>
8 #include <limits>
9 #include <memory>
10 #include <utility>
11 #include <vector>
12 
13 #include <base/bind.h>
14 #include <brillo/message_loops/message_loop.h>
15 #include <brillo/streams/stream_errors.h>
16 
17 namespace brillo {
18 namespace stream_utils {
19 
20 namespace {
21 
22 // Status of asynchronous CopyData operation.
23 struct CopyDataState {
24   brillo::StreamPtr in_stream;
25   brillo::StreamPtr out_stream;
26   std::vector<uint8_t> buffer;
27   uint64_t remaining_to_copy;
28   uint64_t size_copied;
29   CopyDataSuccessCallback success_callback;
30   CopyDataErrorCallback error_callback;
31 };
32 
33 // Async CopyData I/O error callback.
OnCopyDataError(const std::shared_ptr<CopyDataState> & state,const brillo::Error * error)34 void OnCopyDataError(const std::shared_ptr<CopyDataState>& state,
35                      const brillo::Error* error) {
36   state->error_callback.Run(std::move(state->in_stream),
37                             std::move(state->out_stream), error);
38 }
39 
40 // Forward declaration.
41 void PerformRead(const std::shared_ptr<CopyDataState>& state);
42 
43 // Callback from read operation for CopyData. Writes the read data to the output
44 // stream and invokes PerformRead when done to restart the copy cycle.
PerformWrite(const std::shared_ptr<CopyDataState> & state,size_t size)45 void PerformWrite(const std::shared_ptr<CopyDataState>& state, size_t size) {
46   if (size == 0) {
47     state->success_callback.Run(std::move(state->in_stream),
48                                 std::move(state->out_stream),
49                                 state->size_copied);
50     return;
51   }
52   state->size_copied += size;
53   CHECK_GE(state->remaining_to_copy, size);
54   state->remaining_to_copy -= size;
55 
56   brillo::ErrorPtr error;
57   bool success = state->out_stream->WriteAllAsync(
58       state->buffer.data(), size, base::Bind(&PerformRead, state),
59       base::Bind(&OnCopyDataError, state), &error);
60 
61   if (!success)
62     OnCopyDataError(state, error.get());
63 }
64 
65 // Performs the read part of asynchronous CopyData operation. Reads the data
66 // from input stream and invokes PerformWrite when done to write the data to
67 // the output stream.
PerformRead(const std::shared_ptr<CopyDataState> & state)68 void PerformRead(const std::shared_ptr<CopyDataState>& state) {
69   brillo::ErrorPtr error;
70   const uint64_t buffer_size = state->buffer.size();
71   // |buffer_size| is guaranteed to fit in size_t, so |size_to_read| value will
72   // also not overflow size_t, so the static_cast below is safe.
73   size_t size_to_read =
74       static_cast<size_t>(std::min(buffer_size, state->remaining_to_copy));
75   if (size_to_read == 0)
76     return PerformWrite(state, 0);  // Nothing more to read. Finish operation.
77   bool success = state->in_stream->ReadAsync(
78       state->buffer.data(), size_to_read, base::Bind(PerformWrite, state),
79       base::Bind(OnCopyDataError, state), &error);
80 
81   if (!success)
82     OnCopyDataError(state, error.get());
83 }
84 
85 }  // anonymous namespace
86 
ErrorStreamClosed(const base::Location & location,ErrorPtr * error)87 bool ErrorStreamClosed(const base::Location& location,
88                        ErrorPtr* error) {
89   Error::AddTo(error,
90                location,
91                errors::stream::kDomain,
92                errors::stream::kStreamClosed,
93                "Stream is closed");
94   return false;
95 }
96 
ErrorOperationNotSupported(const base::Location & location,ErrorPtr * error)97 bool ErrorOperationNotSupported(const base::Location& location,
98                                 ErrorPtr* error) {
99   Error::AddTo(error,
100                location,
101                errors::stream::kDomain,
102                errors::stream::kOperationNotSupported,
103                "Stream operation not supported");
104   return false;
105 }
106 
ErrorReadPastEndOfStream(const base::Location & location,ErrorPtr * error)107 bool ErrorReadPastEndOfStream(const base::Location& location,
108                               ErrorPtr* error) {
109   Error::AddTo(error,
110                location,
111                errors::stream::kDomain,
112                errors::stream::kPartialData,
113                "Reading past the end of stream");
114   return false;
115 }
116 
ErrorOperationTimeout(const base::Location & location,ErrorPtr * error)117 bool ErrorOperationTimeout(const base::Location& location,
118                            ErrorPtr* error) {
119   Error::AddTo(error,
120                location,
121                errors::stream::kDomain,
122                errors::stream::kTimeout,
123                "Operation timed out");
124   return false;
125 }
126 
CheckInt64Overflow(const base::Location & location,uint64_t position,int64_t offset,ErrorPtr * error)127 bool CheckInt64Overflow(const base::Location& location,
128                         uint64_t position,
129                         int64_t offset,
130                         ErrorPtr* error) {
131   if (offset < 0) {
132     // Subtracting the offset. Make sure we do not underflow.
133     uint64_t unsigned_offset = static_cast<uint64_t>(-offset);
134     if (position >= unsigned_offset)
135       return true;
136   } else {
137     // Adding the offset. Make sure we do not overflow unsigned 64 bits first.
138     if (position <= std::numeric_limits<uint64_t>::max() - offset) {
139       // We definitely will not overflow the unsigned 64 bit integer.
140       // Now check that we end up within the limits of signed 64 bit integer.
141       uint64_t new_position = position + offset;
142       uint64_t max = std::numeric_limits<int64_t>::max();
143       if (new_position <= max)
144         return true;
145     }
146   }
147   Error::AddTo(error,
148                location,
149                errors::stream::kDomain,
150                errors::stream::kInvalidParameter,
151                "The stream offset value is out of range");
152   return false;
153 }
154 
CalculateStreamPosition(const base::Location & location,int64_t offset,Stream::Whence whence,uint64_t current_position,uint64_t stream_size,uint64_t * new_position,ErrorPtr * error)155 bool CalculateStreamPosition(const base::Location& location,
156                              int64_t offset,
157                              Stream::Whence whence,
158                              uint64_t current_position,
159                              uint64_t stream_size,
160                              uint64_t* new_position,
161                              ErrorPtr* error) {
162   uint64_t pos = 0;
163   switch (whence) {
164     case Stream::Whence::FROM_BEGIN:
165       pos = 0;
166       break;
167 
168     case Stream::Whence::FROM_CURRENT:
169       pos = current_position;
170       break;
171 
172     case Stream::Whence::FROM_END:
173       pos = stream_size;
174       break;
175 
176     default:
177       Error::AddTo(error,
178                    location,
179                    errors::stream::kDomain,
180                    errors::stream::kInvalidParameter,
181                    "Invalid stream position whence");
182       return false;
183   }
184 
185   if (!CheckInt64Overflow(location, pos, offset, error))
186     return false;
187 
188   *new_position = static_cast<uint64_t>(pos + offset);
189   return true;
190 }
191 
CopyData(StreamPtr in_stream,StreamPtr out_stream,const CopyDataSuccessCallback & success_callback,const CopyDataErrorCallback & error_callback)192 void CopyData(StreamPtr in_stream,
193               StreamPtr out_stream,
194               const CopyDataSuccessCallback& success_callback,
195               const CopyDataErrorCallback& error_callback) {
196   CopyData(std::move(in_stream), std::move(out_stream),
197            std::numeric_limits<uint64_t>::max(), 4096, success_callback,
198            error_callback);
199 }
200 
CopyData(StreamPtr in_stream,StreamPtr out_stream,uint64_t max_size_to_copy,size_t buffer_size,const CopyDataSuccessCallback & success_callback,const CopyDataErrorCallback & error_callback)201 void CopyData(StreamPtr in_stream,
202               StreamPtr out_stream,
203               uint64_t max_size_to_copy,
204               size_t buffer_size,
205               const CopyDataSuccessCallback& success_callback,
206               const CopyDataErrorCallback& error_callback) {
207   auto state = std::make_shared<CopyDataState>();
208   state->in_stream = std::move(in_stream);
209   state->out_stream = std::move(out_stream);
210   state->buffer.resize(buffer_size);
211   state->remaining_to_copy = max_size_to_copy;
212   state->size_copied = 0;
213   state->success_callback = success_callback;
214   state->error_callback = error_callback;
215   brillo::MessageLoop::current()->PostTask(FROM_HERE,
216                                            base::BindOnce(&PerformRead, state));
217 }
218 
219 }  // namespace stream_utils
220 }  // namespace brillo
221