1 // Copyright 2020 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/read_buffering_stream_socket.h"
6
7 #include <algorithm>
8
9 #include "base/check_op.h"
10 #include "base/notreached.h"
11 #include "net/base/io_buffer.h"
12
13 namespace net {
14
ReadBufferingStreamSocket(std::unique_ptr<StreamSocket> transport)15 ReadBufferingStreamSocket::ReadBufferingStreamSocket(
16 std::unique_ptr<StreamSocket> transport)
17 : WrappedStreamSocket(std::move(transport)) {}
18
19 ReadBufferingStreamSocket::~ReadBufferingStreamSocket() = default;
20
BufferNextRead(int size)21 void ReadBufferingStreamSocket::BufferNextRead(int size) {
22 DCHECK(!user_read_buf_);
23 read_buffer_ = base::MakeRefCounted<GrowableIOBuffer>();
24 read_buffer_->SetCapacity(size);
25 buffer_full_ = false;
26 }
27
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)28 int ReadBufferingStreamSocket::Read(IOBuffer* buf,
29 int buf_len,
30 CompletionOnceCallback callback) {
31 DCHECK(!user_read_buf_);
32 if (!read_buffer_)
33 return transport_->Read(buf, buf_len, std::move(callback));
34 int rv = ReadIfReady(buf, buf_len, std::move(callback));
35 if (rv == ERR_IO_PENDING) {
36 user_read_buf_ = buf;
37 user_read_buf_len_ = buf_len;
38 }
39 return rv;
40 }
41
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)42 int ReadBufferingStreamSocket::ReadIfReady(IOBuffer* buf,
43 int buf_len,
44 CompletionOnceCallback callback) {
45 DCHECK(!user_read_buf_);
46 if (!read_buffer_)
47 return transport_->ReadIfReady(buf, buf_len, std::move(callback));
48
49 if (buffer_full_)
50 return CopyToCaller(buf, buf_len);
51
52 state_ = STATE_READ;
53 int rv = DoLoop(OK);
54 if (rv == OK) {
55 rv = CopyToCaller(buf, buf_len);
56 } else if (rv == ERR_IO_PENDING) {
57 user_read_callback_ = std::move(callback);
58 }
59 return rv;
60 }
61
DoLoop(int result)62 int ReadBufferingStreamSocket::DoLoop(int result) {
63 int rv = result;
64 do {
65 State current_state = state_;
66 state_ = STATE_NONE;
67 switch (current_state) {
68 case STATE_READ:
69 rv = DoRead();
70 break;
71 case STATE_READ_COMPLETE:
72 rv = DoReadComplete(rv);
73 break;
74 case STATE_NONE:
75 default:
76 NOTREACHED() << "Unexpected state: " << current_state;
77 }
78 } while (rv != ERR_IO_PENDING && state_ != STATE_NONE);
79 return rv;
80 }
81
DoRead()82 int ReadBufferingStreamSocket::DoRead() {
83 DCHECK(read_buffer_);
84 DCHECK(!buffer_full_);
85
86 state_ = STATE_READ_COMPLETE;
87 return transport_->Read(
88 read_buffer_.get(), read_buffer_->RemainingCapacity(),
89 base::BindOnce(&ReadBufferingStreamSocket::OnReadCompleted,
90 base::Unretained(this)));
91 }
92
DoReadComplete(int result)93 int ReadBufferingStreamSocket::DoReadComplete(int result) {
94 state_ = STATE_NONE;
95
96 if (result <= 0)
97 return result;
98
99 read_buffer_->set_offset(read_buffer_->offset() + result);
100 if (read_buffer_->RemainingCapacity() > 0) {
101 // Keep reading until |read_buffer_| is full.
102 state_ = STATE_READ;
103 } else {
104 read_buffer_->set_offset(0);
105 buffer_full_ = true;
106 }
107 return OK;
108 }
109
OnReadCompleted(int result)110 void ReadBufferingStreamSocket::OnReadCompleted(int result) {
111 DCHECK_NE(ERR_IO_PENDING, result);
112 DCHECK(user_read_callback_);
113
114 result = DoLoop(result);
115 if (result == ERR_IO_PENDING)
116 return;
117 if (result == OK && user_read_buf_) {
118 // If the user called Read(), return the data to the caller.
119 result = CopyToCaller(user_read_buf_.get(), user_read_buf_len_);
120 user_read_buf_ = nullptr;
121 user_read_buf_len_ = 0;
122 }
123 std::move(user_read_callback_).Run(result);
124 }
125
CopyToCaller(IOBuffer * buf,int buf_len)126 int ReadBufferingStreamSocket::CopyToCaller(IOBuffer* buf, int buf_len) {
127 DCHECK(read_buffer_);
128 DCHECK(buffer_full_);
129
130 buf_len = std::min(buf_len, read_buffer_->RemainingCapacity());
131 memcpy(buf->data(), read_buffer_->data(), buf_len);
132 read_buffer_->set_offset(read_buffer_->offset() + buf_len);
133 if (read_buffer_->RemainingCapacity() == 0) {
134 read_buffer_ = nullptr;
135 buffer_full_ = false;
136 }
137 return buf_len;
138 }
139
140 } // namespace net
141