1 // Copyright 2015 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <brillo/streams/stream_utils.h>
6
7 #include <algorithm>
8 #include <limits>
9 #include <memory>
10 #include <utility>
11 #include <vector>
12
13 #include <base/bind.h>
14 #include <brillo/message_loops/message_loop.h>
15 #include <brillo/streams/stream_errors.h>
16
17 namespace brillo {
18 namespace stream_utils {
19
20 namespace {
21
22 // Status of asynchronous CopyData operation.
23 struct CopyDataState {
24 brillo::StreamPtr in_stream;
25 brillo::StreamPtr out_stream;
26 std::vector<uint8_t> buffer;
27 uint64_t remaining_to_copy;
28 uint64_t size_copied;
29 CopyDataSuccessCallback success_callback;
30 CopyDataErrorCallback error_callback;
31 };
32
33 // Async CopyData I/O error callback.
OnCopyDataError(const std::shared_ptr<CopyDataState> & state,const brillo::Error * error)34 void OnCopyDataError(const std::shared_ptr<CopyDataState>& state,
35 const brillo::Error* error) {
36 state->error_callback.Run(std::move(state->in_stream),
37 std::move(state->out_stream), error);
38 }
39
40 // Forward declaration.
41 void PerformRead(const std::shared_ptr<CopyDataState>& state);
42
43 // Callback from read operation for CopyData. Writes the read data to the output
44 // stream and invokes PerformRead when done to restart the copy cycle.
PerformWrite(const std::shared_ptr<CopyDataState> & state,size_t size)45 void PerformWrite(const std::shared_ptr<CopyDataState>& state, size_t size) {
46 if (size == 0) {
47 state->success_callback.Run(std::move(state->in_stream),
48 std::move(state->out_stream),
49 state->size_copied);
50 return;
51 }
52 state->size_copied += size;
53 CHECK_GE(state->remaining_to_copy, size);
54 state->remaining_to_copy -= size;
55
56 brillo::ErrorPtr error;
57 bool success = state->out_stream->WriteAllAsync(
58 state->buffer.data(), size, base::Bind(&PerformRead, state),
59 base::Bind(&OnCopyDataError, state), &error);
60
61 if (!success)
62 OnCopyDataError(state, error.get());
63 }
64
65 // Performs the read part of asynchronous CopyData operation. Reads the data
66 // from input stream and invokes PerformWrite when done to write the data to
67 // the output stream.
PerformRead(const std::shared_ptr<CopyDataState> & state)68 void PerformRead(const std::shared_ptr<CopyDataState>& state) {
69 brillo::ErrorPtr error;
70 const uint64_t buffer_size = state->buffer.size();
71 // |buffer_size| is guaranteed to fit in size_t, so |size_to_read| value will
72 // also not overflow size_t, so the static_cast below is safe.
73 size_t size_to_read =
74 static_cast<size_t>(std::min(buffer_size, state->remaining_to_copy));
75 if (size_to_read == 0)
76 return PerformWrite(state, 0); // Nothing more to read. Finish operation.
77 bool success = state->in_stream->ReadAsync(
78 state->buffer.data(), size_to_read, base::Bind(PerformWrite, state),
79 base::Bind(OnCopyDataError, state), &error);
80
81 if (!success)
82 OnCopyDataError(state, error.get());
83 }
84
85 } // anonymous namespace
86
ErrorStreamClosed(const base::Location & location,ErrorPtr * error)87 bool ErrorStreamClosed(const base::Location& location,
88 ErrorPtr* error) {
89 Error::AddTo(error,
90 location,
91 errors::stream::kDomain,
92 errors::stream::kStreamClosed,
93 "Stream is closed");
94 return false;
95 }
96
ErrorOperationNotSupported(const base::Location & location,ErrorPtr * error)97 bool ErrorOperationNotSupported(const base::Location& location,
98 ErrorPtr* error) {
99 Error::AddTo(error,
100 location,
101 errors::stream::kDomain,
102 errors::stream::kOperationNotSupported,
103 "Stream operation not supported");
104 return false;
105 }
106
ErrorReadPastEndOfStream(const base::Location & location,ErrorPtr * error)107 bool ErrorReadPastEndOfStream(const base::Location& location,
108 ErrorPtr* error) {
109 Error::AddTo(error,
110 location,
111 errors::stream::kDomain,
112 errors::stream::kPartialData,
113 "Reading past the end of stream");
114 return false;
115 }
116
ErrorOperationTimeout(const base::Location & location,ErrorPtr * error)117 bool ErrorOperationTimeout(const base::Location& location,
118 ErrorPtr* error) {
119 Error::AddTo(error,
120 location,
121 errors::stream::kDomain,
122 errors::stream::kTimeout,
123 "Operation timed out");
124 return false;
125 }
126
CheckInt64Overflow(const base::Location & location,uint64_t position,int64_t offset,ErrorPtr * error)127 bool CheckInt64Overflow(const base::Location& location,
128 uint64_t position,
129 int64_t offset,
130 ErrorPtr* error) {
131 if (offset < 0) {
132 // Subtracting the offset. Make sure we do not underflow.
133 uint64_t unsigned_offset = static_cast<uint64_t>(-offset);
134 if (position >= unsigned_offset)
135 return true;
136 } else {
137 // Adding the offset. Make sure we do not overflow unsigned 64 bits first.
138 if (position <= std::numeric_limits<uint64_t>::max() - offset) {
139 // We definitely will not overflow the unsigned 64 bit integer.
140 // Now check that we end up within the limits of signed 64 bit integer.
141 uint64_t new_position = position + offset;
142 uint64_t max = std::numeric_limits<int64_t>::max();
143 if (new_position <= max)
144 return true;
145 }
146 }
147 Error::AddTo(error,
148 location,
149 errors::stream::kDomain,
150 errors::stream::kInvalidParameter,
151 "The stream offset value is out of range");
152 return false;
153 }
154
CalculateStreamPosition(const base::Location & location,int64_t offset,Stream::Whence whence,uint64_t current_position,uint64_t stream_size,uint64_t * new_position,ErrorPtr * error)155 bool CalculateStreamPosition(const base::Location& location,
156 int64_t offset,
157 Stream::Whence whence,
158 uint64_t current_position,
159 uint64_t stream_size,
160 uint64_t* new_position,
161 ErrorPtr* error) {
162 uint64_t pos = 0;
163 switch (whence) {
164 case Stream::Whence::FROM_BEGIN:
165 pos = 0;
166 break;
167
168 case Stream::Whence::FROM_CURRENT:
169 pos = current_position;
170 break;
171
172 case Stream::Whence::FROM_END:
173 pos = stream_size;
174 break;
175
176 default:
177 Error::AddTo(error,
178 location,
179 errors::stream::kDomain,
180 errors::stream::kInvalidParameter,
181 "Invalid stream position whence");
182 return false;
183 }
184
185 if (!CheckInt64Overflow(location, pos, offset, error))
186 return false;
187
188 *new_position = static_cast<uint64_t>(pos + offset);
189 return true;
190 }
191
CopyData(StreamPtr in_stream,StreamPtr out_stream,const CopyDataSuccessCallback & success_callback,const CopyDataErrorCallback & error_callback)192 void CopyData(StreamPtr in_stream,
193 StreamPtr out_stream,
194 const CopyDataSuccessCallback& success_callback,
195 const CopyDataErrorCallback& error_callback) {
196 CopyData(std::move(in_stream), std::move(out_stream),
197 std::numeric_limits<uint64_t>::max(), 4096, success_callback,
198 error_callback);
199 }
200
CopyData(StreamPtr in_stream,StreamPtr out_stream,uint64_t max_size_to_copy,size_t buffer_size,const CopyDataSuccessCallback & success_callback,const CopyDataErrorCallback & error_callback)201 void CopyData(StreamPtr in_stream,
202 StreamPtr out_stream,
203 uint64_t max_size_to_copy,
204 size_t buffer_size,
205 const CopyDataSuccessCallback& success_callback,
206 const CopyDataErrorCallback& error_callback) {
207 auto state = std::make_shared<CopyDataState>();
208 state->in_stream = std::move(in_stream);
209 state->out_stream = std::move(out_stream);
210 state->buffer.resize(buffer_size);
211 state->remaining_to_copy = max_size_to_copy;
212 state->size_copied = 0;
213 state->success_callback = success_callback;
214 state->error_callback = error_callback;
215 brillo::MessageLoop::current()->PostTask(FROM_HERE,
216 base::BindOnce(&PerformRead, state));
217 }
218
219 } // namespace stream_utils
220 } // namespace brillo
221