1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/websockets/websocket_inflater.h"
11
12 #include <string.h>
13
14 #include <algorithm>
15 #include <vector>
16
17 #include "base/check.h"
18 #include "base/check_op.h"
19 #include "net/base/io_buffer.h"
20 #include "third_party/zlib/zlib.h"
21
22 namespace net {
23
24 namespace {
25
26 class ShrinkableIOBufferWithSize : public IOBufferWithSize {
27 public:
ShrinkableIOBufferWithSize(size_t size)28 explicit ShrinkableIOBufferWithSize(size_t size) : IOBufferWithSize(size) {}
29
Shrink(int new_size)30 void Shrink(int new_size) {
31 CHECK_GE(new_size, 0);
32 CHECK_LE(new_size, size_);
33 size_ = new_size;
34 }
35
36 private:
37 ~ShrinkableIOBufferWithSize() override = default;
38 };
39
40 } // namespace
41
WebSocketInflater()42 WebSocketInflater::WebSocketInflater()
43 : input_queue_(kDefaultInputIOBufferCapacity),
44 output_buffer_(kDefaultBufferCapacity) {}
45
WebSocketInflater(size_t input_queue_capacity,size_t output_buffer_capacity)46 WebSocketInflater::WebSocketInflater(size_t input_queue_capacity,
47 size_t output_buffer_capacity)
48 : input_queue_(input_queue_capacity),
49 output_buffer_(output_buffer_capacity) {
50 DCHECK_GT(input_queue_capacity, 0u);
51 DCHECK_GT(output_buffer_capacity, 0u);
52 }
53
Initialize(int window_bits)54 bool WebSocketInflater::Initialize(int window_bits) {
55 DCHECK_LE(8, window_bits);
56 DCHECK_GE(15, window_bits);
57 stream_ = std::make_unique<z_stream>();
58 memset(stream_.get(), 0, sizeof(*stream_));
59 int result = inflateInit2(stream_.get(), -window_bits);
60 if (result != Z_OK) {
61 inflateEnd(stream_.get());
62 stream_.reset();
63 return false;
64 }
65 return true;
66 }
67
~WebSocketInflater()68 WebSocketInflater::~WebSocketInflater() {
69 if (stream_) {
70 inflateEnd(stream_.get());
71 stream_.reset();
72 }
73 }
74
AddBytes(const char * data,size_t size)75 bool WebSocketInflater::AddBytes(const char* data, size_t size) {
76 if (!size)
77 return true;
78
79 if (!input_queue_.IsEmpty()) {
80 // choked
81 input_queue_.Push(data, size);
82 return true;
83 }
84
85 int result = InflateWithFlush(data, size);
86 if (stream_->avail_in > 0)
87 input_queue_.Push(&data[size - stream_->avail_in], stream_->avail_in);
88
89 return result == Z_OK || result == Z_BUF_ERROR;
90 }
91
Finish()92 bool WebSocketInflater::Finish() {
93 return AddBytes("\x00\x00\xff\xff", 4);
94 }
95
GetOutput(size_t size)96 scoped_refptr<IOBufferWithSize> WebSocketInflater::GetOutput(size_t size) {
97 auto buffer = base::MakeRefCounted<ShrinkableIOBufferWithSize>(size);
98 size_t num_bytes_copied = 0;
99
100 while (num_bytes_copied < size && output_buffer_.Size() > 0) {
101 size_t num_bytes_to_copy =
102 std::min(output_buffer_.Size(), size - num_bytes_copied);
103 output_buffer_.Read(&buffer->data()[num_bytes_copied], num_bytes_to_copy);
104 num_bytes_copied += num_bytes_to_copy;
105 int result = InflateChokedInput();
106 if (result != Z_OK && result != Z_BUF_ERROR)
107 return nullptr;
108 }
109 buffer->Shrink(num_bytes_copied);
110 return buffer;
111 }
112
InflateWithFlush(const char * next_in,size_t avail_in)113 int WebSocketInflater::InflateWithFlush(const char* next_in, size_t avail_in) {
114 int result = Inflate(next_in, avail_in, Z_NO_FLUSH);
115 if (result != Z_OK && result != Z_BUF_ERROR)
116 return result;
117
118 if (CurrentOutputSize() > 0)
119 return result;
120 // CurrentOutputSize() == 0 means there is no data to be output,
121 // so we should make sure it by using Z_SYNC_FLUSH.
122 return Inflate(reinterpret_cast<const char*>(stream_->next_in),
123 stream_->avail_in,
124 Z_SYNC_FLUSH);
125 }
126
Inflate(const char * next_in,size_t avail_in,int flush)127 int WebSocketInflater::Inflate(const char* next_in,
128 size_t avail_in,
129 int flush) {
130 stream_->next_in = reinterpret_cast<Bytef*>(const_cast<char*>(next_in));
131 stream_->avail_in = avail_in;
132
133 int result = Z_BUF_ERROR;
134 do {
135 std::pair<char*, size_t> tail = output_buffer_.GetTail();
136 if (!tail.second)
137 break;
138
139 stream_->next_out = reinterpret_cast<Bytef*>(tail.first);
140 stream_->avail_out = tail.second;
141 result = inflate(stream_.get(), flush);
142 output_buffer_.AdvanceTail(tail.second - stream_->avail_out);
143 if (result == Z_STREAM_END) {
144 // Received a block with BFINAL set to 1. Reset the decompression state.
145 result = inflateReset(stream_.get());
146 } else if (tail.second == stream_->avail_out) {
147 break;
148 }
149 } while (result == Z_OK || result == Z_BUF_ERROR);
150 return result;
151 }
152
InflateChokedInput()153 int WebSocketInflater::InflateChokedInput() {
154 if (input_queue_.IsEmpty())
155 return InflateWithFlush(nullptr, 0);
156
157 int result = Z_BUF_ERROR;
158 while (!input_queue_.IsEmpty()) {
159 std::pair<char*, size_t> top = input_queue_.Top();
160
161 result = InflateWithFlush(top.first, top.second);
162 input_queue_.Consume(top.second - stream_->avail_in);
163
164 if (result != Z_OK && result != Z_BUF_ERROR)
165 return result;
166
167 if (stream_->avail_in > 0) {
168 // There are some data which are not consumed.
169 break;
170 }
171 }
172 return result;
173 }
174
OutputBuffer(size_t capacity)175 WebSocketInflater::OutputBuffer::OutputBuffer(size_t capacity)
176 : capacity_(capacity),
177 buffer_(capacity_ + 1) // 1 for sentinel
178 {}
179
180 WebSocketInflater::OutputBuffer::~OutputBuffer() = default;
181
Size() const182 size_t WebSocketInflater::OutputBuffer::Size() const {
183 return (tail_ + buffer_.size() - head_) % buffer_.size();
184 }
185
GetTail()186 std::pair<char*, size_t> WebSocketInflater::OutputBuffer::GetTail() {
187 DCHECK_LT(tail_, buffer_.size());
188 return std::pair(&buffer_[tail_],
189 std::min(capacity_ - Size(), buffer_.size() - tail_));
190 }
191
Read(char * dest,size_t size)192 void WebSocketInflater::OutputBuffer::Read(char* dest, size_t size) {
193 DCHECK_LE(size, Size());
194
195 size_t num_bytes_copied = 0;
196 if (tail_ < head_) {
197 size_t num_bytes_to_copy = std::min(size, buffer_.size() - head_);
198 DCHECK_LT(head_, buffer_.size());
199 memcpy(&dest[num_bytes_copied], &buffer_[head_], num_bytes_to_copy);
200 AdvanceHead(num_bytes_to_copy);
201 num_bytes_copied += num_bytes_to_copy;
202 }
203
204 if (num_bytes_copied == size)
205 return;
206 DCHECK_LE(head_, tail_);
207 size_t num_bytes_to_copy = size - num_bytes_copied;
208 DCHECK_LE(num_bytes_to_copy, tail_ - head_);
209 DCHECK_LT(head_, buffer_.size());
210 memcpy(&dest[num_bytes_copied], &buffer_[head_], num_bytes_to_copy);
211 AdvanceHead(num_bytes_to_copy);
212 num_bytes_copied += num_bytes_to_copy;
213 DCHECK_EQ(size, num_bytes_copied);
214 return;
215 }
216
AdvanceHead(size_t advance)217 void WebSocketInflater::OutputBuffer::AdvanceHead(size_t advance) {
218 DCHECK_LE(advance, Size());
219 head_ = (head_ + advance) % buffer_.size();
220 }
221
AdvanceTail(size_t advance)222 void WebSocketInflater::OutputBuffer::AdvanceTail(size_t advance) {
223 DCHECK_LE(advance + Size(), capacity_);
224 tail_ = (tail_ + advance) % buffer_.size();
225 }
226
InputQueue(size_t capacity)227 WebSocketInflater::InputQueue::InputQueue(size_t capacity)
228 : capacity_(capacity) {}
229
230 WebSocketInflater::InputQueue::~InputQueue() = default;
231
Top()232 std::pair<char*, size_t> WebSocketInflater::InputQueue::Top() {
233 DCHECK(!IsEmpty());
234 if (buffers_.size() == 1) {
235 return std::pair(&buffers_.front()->data()[head_of_first_buffer_],
236 tail_of_last_buffer_ - head_of_first_buffer_);
237 }
238 return std::pair(&buffers_.front()->data()[head_of_first_buffer_],
239 capacity_ - head_of_first_buffer_);
240 }
241
Push(const char * data,size_t size)242 void WebSocketInflater::InputQueue::Push(const char* data, size_t size) {
243 if (!size)
244 return;
245
246 size_t num_copied_bytes = 0;
247 if (!IsEmpty())
248 num_copied_bytes += PushToLastBuffer(data, size);
249
250 while (num_copied_bytes < size) {
251 DCHECK(IsEmpty() || tail_of_last_buffer_ == capacity_);
252
253 buffers_.push_back(base::MakeRefCounted<IOBufferWithSize>(capacity_));
254 tail_of_last_buffer_ = 0;
255 num_copied_bytes +=
256 PushToLastBuffer(&data[num_copied_bytes], size - num_copied_bytes);
257 }
258 }
259
Consume(size_t size)260 void WebSocketInflater::InputQueue::Consume(size_t size) {
261 DCHECK(!IsEmpty());
262 DCHECK_LE(size + head_of_first_buffer_, capacity_);
263
264 head_of_first_buffer_ += size;
265 if (head_of_first_buffer_ == capacity_) {
266 buffers_.pop_front();
267 head_of_first_buffer_ = 0;
268 }
269 if (buffers_.size() == 1 && head_of_first_buffer_ == tail_of_last_buffer_) {
270 buffers_.pop_front();
271 head_of_first_buffer_ = 0;
272 tail_of_last_buffer_ = 0;
273 }
274 }
275
PushToLastBuffer(const char * data,size_t size)276 size_t WebSocketInflater::InputQueue::PushToLastBuffer(const char* data,
277 size_t size) {
278 DCHECK(!IsEmpty());
279 size_t num_bytes_to_copy = std::min(size, capacity_ - tail_of_last_buffer_);
280 if (!num_bytes_to_copy)
281 return 0;
282 IOBufferWithSize* buffer = buffers_.back().get();
283 memcpy(&buffer->data()[tail_of_last_buffer_], data, num_bytes_to_copy);
284 tail_of_last_buffer_ += num_bytes_to_copy;
285 return num_bytes_to_copy;
286 }
287
288 } // namespace net
289