1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "google_apis/gcm/base/socket_stream.h"
6
7 #include "base/bind.h"
8 #include "base/callback.h"
9 #include "net/base/io_buffer.h"
10 #include "net/socket/stream_socket.h"
11
12 namespace gcm {
13
14 namespace {
15
16 // TODO(zea): consider having dynamically-sized buffers if this becomes too
17 // expensive.
18 const uint32 kDefaultBufferSize = 8*1024;
19
20 } // namespace
21
SocketInputStream(net::StreamSocket * socket)22 SocketInputStream::SocketInputStream(net::StreamSocket* socket)
23 : socket_(socket),
24 io_buffer_(new net::IOBuffer(kDefaultBufferSize)),
25 read_buffer_(new net::DrainableIOBuffer(io_buffer_.get(),
26 kDefaultBufferSize)),
27 next_pos_(0),
28 last_error_(net::OK),
29 weak_ptr_factory_(this) {
30 DCHECK(socket->IsConnected());
31 }
32
~SocketInputStream()33 SocketInputStream::~SocketInputStream() {
34 }
35
Next(const void ** data,int * size)36 bool SocketInputStream::Next(const void** data, int* size) {
37 if (GetState() != EMPTY && GetState() != READY) {
38 NOTREACHED() << "Invalid input stream read attempt.";
39 return false;
40 }
41
42 if (GetState() == EMPTY) {
43 DVLOG(1) << "No unread data remaining, ending read.";
44 return false;
45 }
46
47 DCHECK_EQ(GetState(), READY)
48 << " Input stream must have pending data before reading.";
49 DCHECK_LT(next_pos_, read_buffer_->BytesConsumed());
50 *data = io_buffer_->data() + next_pos_;
51 *size = UnreadByteCount();
52 next_pos_ = read_buffer_->BytesConsumed();
53 DVLOG(1) << "Consuming " << *size << " bytes in input buffer.";
54 return true;
55 }
56
BackUp(int count)57 void SocketInputStream::BackUp(int count) {
58 DCHECK(GetState() == READY || GetState() == EMPTY);
59 DCHECK_GT(count, 0);
60 DCHECK_LE(count, next_pos_);
61
62 next_pos_ -= count;
63 DVLOG(1) << "Backing up " << count << " bytes in input buffer. "
64 << "Current position now at " << next_pos_
65 << " of " << read_buffer_->BytesConsumed();
66 }
67
Skip(int count)68 bool SocketInputStream::Skip(int count) {
69 NOTIMPLEMENTED();
70 return false;
71 }
72
ByteCount() const73 int64 SocketInputStream::ByteCount() const {
74 DCHECK_NE(GetState(), CLOSED);
75 DCHECK_NE(GetState(), READING);
76 return next_pos_;
77 }
78
UnreadByteCount() const79 size_t SocketInputStream::UnreadByteCount() const {
80 DCHECK_NE(GetState(), CLOSED);
81 DCHECK_NE(GetState(), READING);
82 return read_buffer_->BytesConsumed() - next_pos_;
83 }
84
Refresh(const base::Closure & callback,int byte_limit)85 net::Error SocketInputStream::Refresh(const base::Closure& callback,
86 int byte_limit) {
87 DCHECK_NE(GetState(), CLOSED);
88 DCHECK_NE(GetState(), READING);
89 DCHECK_GT(byte_limit, 0);
90
91 if (byte_limit > read_buffer_->BytesRemaining()) {
92 LOG(ERROR) << "Out of buffer space, closing input stream.";
93 CloseStream(net::ERR_FILE_TOO_BIG, base::Closure());
94 return net::OK;
95 }
96
97 if (!socket_->IsConnected()) {
98 LOG(ERROR) << "Socket was disconnected, closing input stream";
99 CloseStream(net::ERR_CONNECTION_CLOSED, base::Closure());
100 return net::OK;
101 }
102
103 DVLOG(1) << "Refreshing input stream, limit of " << byte_limit << " bytes.";
104 int result = socket_->Read(
105 read_buffer_,
106 byte_limit,
107 base::Bind(&SocketInputStream::RefreshCompletionCallback,
108 weak_ptr_factory_.GetWeakPtr(),
109 callback));
110 DVLOG(1) << "Read returned " << result;
111 if (result == net::ERR_IO_PENDING) {
112 last_error_ = net::ERR_IO_PENDING;
113 return net::ERR_IO_PENDING;
114 }
115
116 RefreshCompletionCallback(base::Closure(), result);
117 return net::OK;
118 }
119
RebuildBuffer()120 void SocketInputStream::RebuildBuffer() {
121 DVLOG(1) << "Rebuilding input stream, consumed "
122 << next_pos_ << " bytes.";
123 DCHECK_NE(GetState(), READING);
124 DCHECK_NE(GetState(), CLOSED);
125
126 int unread_data_size = 0;
127 const void* unread_data_ptr = NULL;
128 Next(&unread_data_ptr, &unread_data_size);
129 ResetInternal();
130
131 if (unread_data_ptr != io_buffer_->data()) {
132 DVLOG(1) << "Have " << unread_data_size
133 << " unread bytes remaining, shifting.";
134 // Move any remaining unread data to the start of the buffer;
135 std::memmove(io_buffer_->data(), unread_data_ptr, unread_data_size);
136 } else {
137 DVLOG(1) << "Have " << unread_data_size << " unread bytes remaining.";
138 }
139 read_buffer_->DidConsume(unread_data_size);
140 }
141
last_error() const142 net::Error SocketInputStream::last_error() const {
143 return last_error_;
144 }
145
GetState() const146 SocketInputStream::State SocketInputStream::GetState() const {
147 if (last_error_ < net::ERR_IO_PENDING)
148 return CLOSED;
149
150 if (last_error_ == net::ERR_IO_PENDING)
151 return READING;
152
153 DCHECK_EQ(last_error_, net::OK);
154 if (read_buffer_->BytesConsumed() == next_pos_)
155 return EMPTY;
156
157 return READY;
158 }
159
RefreshCompletionCallback(const base::Closure & callback,int result)160 void SocketInputStream::RefreshCompletionCallback(
161 const base::Closure& callback, int result) {
162 // If an error occurred before the completion callback could complete, ignore
163 // the result.
164 if (GetState() == CLOSED)
165 return;
166
167 // Result == 0 implies EOF, which is treated as an error.
168 if (result == 0)
169 result = net::ERR_CONNECTION_CLOSED;
170
171 DCHECK_NE(result, net::ERR_IO_PENDING);
172
173 if (result < net::OK) {
174 DVLOG(1) << "Failed to refresh socket: " << result;
175 CloseStream(static_cast<net::Error>(result), callback);
176 return;
177 }
178
179 DCHECK_GT(result, 0);
180 last_error_ = net::OK;
181 read_buffer_->DidConsume(result);
182
183 DVLOG(1) << "Refresh complete with " << result << " new bytes. "
184 << "Current position " << next_pos_
185 << " of " << read_buffer_->BytesConsumed() << ".";
186
187 if (!callback.is_null())
188 callback.Run();
189 }
190
ResetInternal()191 void SocketInputStream::ResetInternal() {
192 read_buffer_->SetOffset(0);
193 next_pos_ = 0;
194 last_error_ = net::OK;
195 weak_ptr_factory_.InvalidateWeakPtrs(); // Invalidate any callbacks.
196 }
197
CloseStream(net::Error error,const base::Closure & callback)198 void SocketInputStream::CloseStream(net::Error error,
199 const base::Closure& callback) {
200 DCHECK_LT(error, net::ERR_IO_PENDING);
201 ResetInternal();
202 last_error_ = error;
203 LOG(ERROR) << "Closing stream with result " << error;
204 if (!callback.is_null())
205 callback.Run();
206 }
207
SocketOutputStream(net::StreamSocket * socket)208 SocketOutputStream::SocketOutputStream(net::StreamSocket* socket)
209 : socket_(socket),
210 io_buffer_(new net::IOBuffer(kDefaultBufferSize)),
211 write_buffer_(new net::DrainableIOBuffer(io_buffer_.get(),
212 kDefaultBufferSize)),
213 next_pos_(0),
214 last_error_(net::OK),
215 weak_ptr_factory_(this) {
216 DCHECK(socket->IsConnected());
217 }
218
~SocketOutputStream()219 SocketOutputStream::~SocketOutputStream() {
220 }
221
Next(void ** data,int * size)222 bool SocketOutputStream::Next(void** data, int* size) {
223 DCHECK_NE(GetState(), CLOSED);
224 DCHECK_NE(GetState(), FLUSHING);
225 if (next_pos_ == write_buffer_->size())
226 return false;
227
228 *data = write_buffer_->data() + next_pos_;
229 *size = write_buffer_->size() - next_pos_;
230 next_pos_ = write_buffer_->size();
231 return true;
232 }
233
BackUp(int count)234 void SocketOutputStream::BackUp(int count) {
235 DCHECK_GE(count, 0);
236 if (count > next_pos_)
237 next_pos_ = 0;
238 next_pos_ -= count;
239 DVLOG(1) << "Backing up " << count << " bytes in output buffer. "
240 << next_pos_ << " bytes used.";
241 }
242
ByteCount() const243 int64 SocketOutputStream::ByteCount() const {
244 DCHECK_NE(GetState(), CLOSED);
245 DCHECK_NE(GetState(), FLUSHING);
246 return next_pos_;
247 }
248
Flush(const base::Closure & callback)249 net::Error SocketOutputStream::Flush(const base::Closure& callback) {
250 DCHECK_EQ(GetState(), READY);
251
252 if (!socket_->IsConnected()) {
253 LOG(ERROR) << "Socket was disconnected, closing output stream";
254 last_error_ = net::ERR_CONNECTION_CLOSED;
255 return net::OK;
256 }
257
258 DVLOG(1) << "Flushing " << next_pos_ << " bytes into socket.";
259 int result = socket_->Write(
260 write_buffer_,
261 next_pos_,
262 base::Bind(&SocketOutputStream::FlushCompletionCallback,
263 weak_ptr_factory_.GetWeakPtr(),
264 callback));
265 DVLOG(1) << "Write returned " << result;
266 if (result == net::ERR_IO_PENDING) {
267 last_error_ = net::ERR_IO_PENDING;
268 return net::ERR_IO_PENDING;
269 }
270
271 FlushCompletionCallback(base::Closure(), result);
272 return net::OK;
273 }
274
GetState() const275 SocketOutputStream::State SocketOutputStream::GetState() const{
276 if (last_error_ < net::ERR_IO_PENDING)
277 return CLOSED;
278
279 if (last_error_ == net::ERR_IO_PENDING)
280 return FLUSHING;
281
282 DCHECK_EQ(last_error_, net::OK);
283 if (next_pos_ == 0)
284 return EMPTY;
285
286 return READY;
287 }
288
last_error() const289 net::Error SocketOutputStream::last_error() const {
290 return last_error_;
291 }
292
FlushCompletionCallback(const base::Closure & callback,int result)293 void SocketOutputStream::FlushCompletionCallback(
294 const base::Closure& callback, int result) {
295 // If an error occurred before the completion callback could complete, ignore
296 // the result.
297 if (GetState() == CLOSED)
298 return;
299
300 // Result == 0 implies EOF, which is treated as an error.
301 if (result == 0)
302 result = net::ERR_CONNECTION_CLOSED;
303
304 DCHECK_NE(result, net::ERR_IO_PENDING);
305
306 if (result < net::OK) {
307 LOG(ERROR) << "Failed to flush socket.";
308 last_error_ = static_cast<net::Error>(result);
309 if (!callback.is_null())
310 callback.Run();
311 return;
312 }
313
314 DCHECK_GT(result, net::OK);
315 last_error_ = net::OK;
316
317 if (write_buffer_->BytesConsumed() + result < next_pos_) {
318 DVLOG(1) << "Partial flush complete. Retrying.";
319 // Only a partial write was completed. Flush again to finish the write.
320 write_buffer_->DidConsume(result);
321 Flush(callback);
322 return;
323 }
324
325 DVLOG(1) << "Socket flush complete.";
326 write_buffer_->SetOffset(0);
327 next_pos_ = 0;
328 if (!callback.is_null())
329 callback.Run();
330 }
331
332 } // namespace gcm
333