• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2013 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4 
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9 
#include "net/websockets/websocket_inflater.h"

#include <string.h>

#include <algorithm>
#include <limits>
#include <vector>

#include "base/check.h"
#include "base/check_op.h"
#include "net/base/io_buffer.h"
#include "third_party/zlib/zlib.h"
21 
22 namespace net {
23 
24 namespace {
25 
26 class ShrinkableIOBufferWithSize : public IOBufferWithSize {
27  public:
ShrinkableIOBufferWithSize(size_t size)28   explicit ShrinkableIOBufferWithSize(size_t size) : IOBufferWithSize(size) {}
29 
Shrink(int new_size)30   void Shrink(int new_size) {
31     CHECK_GE(new_size, 0);
32     CHECK_LE(new_size, size_);
33     size_ = new_size;
34   }
35 
36  private:
37   ~ShrinkableIOBufferWithSize() override = default;
38 };
39 
40 }  // namespace
41 
WebSocketInflater()42 WebSocketInflater::WebSocketInflater()
43     : input_queue_(kDefaultInputIOBufferCapacity),
44       output_buffer_(kDefaultBufferCapacity) {}
45 
WebSocketInflater(size_t input_queue_capacity,size_t output_buffer_capacity)46 WebSocketInflater::WebSocketInflater(size_t input_queue_capacity,
47                                      size_t output_buffer_capacity)
48     : input_queue_(input_queue_capacity),
49       output_buffer_(output_buffer_capacity) {
50   DCHECK_GT(input_queue_capacity, 0u);
51   DCHECK_GT(output_buffer_capacity, 0u);
52 }
53 
Initialize(int window_bits)54 bool WebSocketInflater::Initialize(int window_bits) {
55   DCHECK_LE(8, window_bits);
56   DCHECK_GE(15, window_bits);
57   stream_ = std::make_unique<z_stream>();
58   memset(stream_.get(), 0, sizeof(*stream_));
59   int result = inflateInit2(stream_.get(), -window_bits);
60   if (result != Z_OK) {
61     inflateEnd(stream_.get());
62     stream_.reset();
63     return false;
64   }
65   return true;
66 }
67 
~WebSocketInflater()68 WebSocketInflater::~WebSocketInflater() {
69   if (stream_) {
70     inflateEnd(stream_.get());
71     stream_.reset();
72   }
73 }
74 
AddBytes(const char * data,size_t size)75 bool WebSocketInflater::AddBytes(const char* data, size_t size) {
76   if (!size)
77     return true;
78 
79   if (!input_queue_.IsEmpty()) {
80     // choked
81     input_queue_.Push(data, size);
82     return true;
83   }
84 
85   int result = InflateWithFlush(data, size);
86   if (stream_->avail_in > 0)
87     input_queue_.Push(&data[size - stream_->avail_in], stream_->avail_in);
88 
89   return result == Z_OK || result == Z_BUF_ERROR;
90 }
91 
Finish()92 bool WebSocketInflater::Finish() {
93   return AddBytes("\x00\x00\xff\xff", 4);
94 }
95 
GetOutput(size_t size)96 scoped_refptr<IOBufferWithSize> WebSocketInflater::GetOutput(size_t size) {
97   auto buffer = base::MakeRefCounted<ShrinkableIOBufferWithSize>(size);
98   size_t num_bytes_copied = 0;
99 
100   while (num_bytes_copied < size && output_buffer_.Size() > 0) {
101     size_t num_bytes_to_copy =
102         std::min(output_buffer_.Size(), size - num_bytes_copied);
103     output_buffer_.Read(&buffer->data()[num_bytes_copied], num_bytes_to_copy);
104     num_bytes_copied += num_bytes_to_copy;
105     int result = InflateChokedInput();
106     if (result != Z_OK && result != Z_BUF_ERROR)
107       return nullptr;
108   }
109   buffer->Shrink(num_bytes_copied);
110   return buffer;
111 }
112 
InflateWithFlush(const char * next_in,size_t avail_in)113 int WebSocketInflater::InflateWithFlush(const char* next_in, size_t avail_in) {
114   int result = Inflate(next_in, avail_in, Z_NO_FLUSH);
115   if (result != Z_OK && result != Z_BUF_ERROR)
116     return result;
117 
118   if (CurrentOutputSize() > 0)
119     return result;
120   // CurrentOutputSize() == 0 means there is no data to be output,
121   // so we should make sure it by using Z_SYNC_FLUSH.
122   return Inflate(reinterpret_cast<const char*>(stream_->next_in),
123                  stream_->avail_in,
124                  Z_SYNC_FLUSH);
125 }
126 
Inflate(const char * next_in,size_t avail_in,int flush)127 int WebSocketInflater::Inflate(const char* next_in,
128                                size_t avail_in,
129                                int flush) {
130   stream_->next_in = reinterpret_cast<Bytef*>(const_cast<char*>(next_in));
131   stream_->avail_in = avail_in;
132 
133   int result = Z_BUF_ERROR;
134   do {
135     std::pair<char*, size_t> tail = output_buffer_.GetTail();
136     if (!tail.second)
137       break;
138 
139     stream_->next_out = reinterpret_cast<Bytef*>(tail.first);
140     stream_->avail_out = tail.second;
141     result = inflate(stream_.get(), flush);
142     output_buffer_.AdvanceTail(tail.second - stream_->avail_out);
143     if (result == Z_STREAM_END) {
144       // Received a block with BFINAL set to 1. Reset the decompression state.
145       result = inflateReset(stream_.get());
146     } else if (tail.second == stream_->avail_out) {
147       break;
148     }
149   } while (result == Z_OK || result == Z_BUF_ERROR);
150   return result;
151 }
152 
InflateChokedInput()153 int WebSocketInflater::InflateChokedInput() {
154   if (input_queue_.IsEmpty())
155     return InflateWithFlush(nullptr, 0);
156 
157   int result = Z_BUF_ERROR;
158   while (!input_queue_.IsEmpty()) {
159     std::pair<char*, size_t> top = input_queue_.Top();
160 
161     result = InflateWithFlush(top.first, top.second);
162     input_queue_.Consume(top.second - stream_->avail_in);
163 
164     if (result != Z_OK && result != Z_BUF_ERROR)
165       return result;
166 
167     if (stream_->avail_in > 0) {
168       // There are some data which are not consumed.
169       break;
170     }
171   }
172   return result;
173 }
174 
OutputBuffer(size_t capacity)175 WebSocketInflater::OutputBuffer::OutputBuffer(size_t capacity)
176     : capacity_(capacity),
177       buffer_(capacity_ + 1)  // 1 for sentinel
178 {}
179 
180 WebSocketInflater::OutputBuffer::~OutputBuffer() = default;
181 
Size() const182 size_t WebSocketInflater::OutputBuffer::Size() const {
183   return (tail_ + buffer_.size() - head_) % buffer_.size();
184 }
185 
GetTail()186 std::pair<char*, size_t> WebSocketInflater::OutputBuffer::GetTail() {
187   DCHECK_LT(tail_, buffer_.size());
188   return std::pair(&buffer_[tail_],
189                    std::min(capacity_ - Size(), buffer_.size() - tail_));
190 }
191 
Read(char * dest,size_t size)192 void WebSocketInflater::OutputBuffer::Read(char* dest, size_t size) {
193   DCHECK_LE(size, Size());
194 
195   size_t num_bytes_copied = 0;
196   if (tail_ < head_) {
197     size_t num_bytes_to_copy = std::min(size, buffer_.size() - head_);
198     DCHECK_LT(head_, buffer_.size());
199     memcpy(&dest[num_bytes_copied], &buffer_[head_], num_bytes_to_copy);
200     AdvanceHead(num_bytes_to_copy);
201     num_bytes_copied += num_bytes_to_copy;
202   }
203 
204   if (num_bytes_copied == size)
205     return;
206   DCHECK_LE(head_, tail_);
207   size_t num_bytes_to_copy = size - num_bytes_copied;
208   DCHECK_LE(num_bytes_to_copy, tail_ - head_);
209   DCHECK_LT(head_, buffer_.size());
210   memcpy(&dest[num_bytes_copied], &buffer_[head_], num_bytes_to_copy);
211   AdvanceHead(num_bytes_to_copy);
212   num_bytes_copied += num_bytes_to_copy;
213   DCHECK_EQ(size, num_bytes_copied);
214   return;
215 }
216 
AdvanceHead(size_t advance)217 void WebSocketInflater::OutputBuffer::AdvanceHead(size_t advance) {
218   DCHECK_LE(advance, Size());
219   head_ = (head_ + advance) % buffer_.size();
220 }
221 
AdvanceTail(size_t advance)222 void WebSocketInflater::OutputBuffer::AdvanceTail(size_t advance) {
223   DCHECK_LE(advance + Size(), capacity_);
224   tail_ = (tail_ + advance) % buffer_.size();
225 }
226 
InputQueue(size_t capacity)227 WebSocketInflater::InputQueue::InputQueue(size_t capacity)
228     : capacity_(capacity) {}
229 
230 WebSocketInflater::InputQueue::~InputQueue() = default;
231 
Top()232 std::pair<char*, size_t> WebSocketInflater::InputQueue::Top() {
233   DCHECK(!IsEmpty());
234   if (buffers_.size() == 1) {
235     return std::pair(&buffers_.front()->data()[head_of_first_buffer_],
236                      tail_of_last_buffer_ - head_of_first_buffer_);
237   }
238   return std::pair(&buffers_.front()->data()[head_of_first_buffer_],
239                    capacity_ - head_of_first_buffer_);
240 }
241 
Push(const char * data,size_t size)242 void WebSocketInflater::InputQueue::Push(const char* data, size_t size) {
243   if (!size)
244     return;
245 
246   size_t num_copied_bytes = 0;
247   if (!IsEmpty())
248     num_copied_bytes += PushToLastBuffer(data, size);
249 
250   while (num_copied_bytes < size) {
251     DCHECK(IsEmpty() || tail_of_last_buffer_ == capacity_);
252 
253     buffers_.push_back(base::MakeRefCounted<IOBufferWithSize>(capacity_));
254     tail_of_last_buffer_ = 0;
255     num_copied_bytes +=
256         PushToLastBuffer(&data[num_copied_bytes], size - num_copied_bytes);
257   }
258 }
259 
Consume(size_t size)260 void WebSocketInflater::InputQueue::Consume(size_t size) {
261   DCHECK(!IsEmpty());
262   DCHECK_LE(size + head_of_first_buffer_, capacity_);
263 
264   head_of_first_buffer_ += size;
265   if (head_of_first_buffer_ == capacity_) {
266     buffers_.pop_front();
267     head_of_first_buffer_ = 0;
268   }
269   if (buffers_.size() == 1 && head_of_first_buffer_ == tail_of_last_buffer_) {
270     buffers_.pop_front();
271     head_of_first_buffer_ = 0;
272     tail_of_last_buffer_ = 0;
273   }
274 }
275 
PushToLastBuffer(const char * data,size_t size)276 size_t WebSocketInflater::InputQueue::PushToLastBuffer(const char* data,
277                                                        size_t size) {
278   DCHECK(!IsEmpty());
279   size_t num_bytes_to_copy = std::min(size, capacity_ - tail_of_last_buffer_);
280   if (!num_bytes_to_copy)
281     return 0;
282   IOBufferWithSize* buffer = buffers_.back().get();
283   memcpy(&buffer->data()[tail_of_last_buffer_], data, num_bytes_to_copy);
284   tail_of_last_buffer_ += num_bytes_to_copy;
285   return num_bytes_to_copy;
286 }
287 
288 }  // namespace net
289