• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_deflate_stream.h"
6 
7 #include <algorithm>
8 #include <string>
9 
10 #include "base/bind.h"
11 #include "base/logging.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/memory/scoped_vector.h"
15 #include "net/base/completion_callback.h"
16 #include "net/base/io_buffer.h"
17 #include "net/base/net_errors.h"
18 #include "net/websockets/websocket_deflate_predictor.h"
19 #include "net/websockets/websocket_deflater.h"
20 #include "net/websockets/websocket_errors.h"
21 #include "net/websockets/websocket_frame.h"
22 #include "net/websockets/websocket_inflater.h"
23 #include "net/websockets/websocket_stream.h"
24 
25 class GURL;
26 
27 namespace net {
28 
namespace {

// Window bits handed to inflater_.Initialize() in the constructor. Unlike the
// deflater, whose window bits come from the caller, the inflater always uses
// this fixed value.
const int kWindowBits = 15;

// Granularity, in bytes, of (de)compressed output. It is passed as both size
// arguments of the inflater's constructor. Buffered deflater/inflater output
// is flushed into a frame once it reaches this size.
const size_t kChunkSize = 4 * 1024;

}  // namespace
35 
WebSocketDeflateStream(scoped_ptr<WebSocketStream> stream,WebSocketDeflater::ContextTakeOverMode mode,int client_window_bits,scoped_ptr<WebSocketDeflatePredictor> predictor)36 WebSocketDeflateStream::WebSocketDeflateStream(
37     scoped_ptr<WebSocketStream> stream,
38     WebSocketDeflater::ContextTakeOverMode mode,
39     int client_window_bits,
40     scoped_ptr<WebSocketDeflatePredictor> predictor)
41     : stream_(stream.Pass()),
42       deflater_(mode),
43       inflater_(kChunkSize, kChunkSize),
44       reading_state_(NOT_READING),
45       writing_state_(NOT_WRITING),
46       current_reading_opcode_(WebSocketFrameHeader::kOpCodeText),
47       current_writing_opcode_(WebSocketFrameHeader::kOpCodeText),
48       predictor_(predictor.Pass()) {
49   DCHECK(stream_);
50   DCHECK_GE(client_window_bits, 8);
51   DCHECK_LE(client_window_bits, 15);
52   deflater_.Initialize(client_window_bits);
53   inflater_.Initialize(kWindowBits);
54 }
55 
~WebSocketDeflateStream()56 WebSocketDeflateStream::~WebSocketDeflateStream() {}
57 
ReadFrames(ScopedVector<WebSocketFrame> * frames,const CompletionCallback & callback)58 int WebSocketDeflateStream::ReadFrames(ScopedVector<WebSocketFrame>* frames,
59                                        const CompletionCallback& callback) {
60   int result = stream_->ReadFrames(
61       frames,
62       base::Bind(&WebSocketDeflateStream::OnReadComplete,
63                  base::Unretained(this),
64                  base::Unretained(frames),
65                  callback));
66   if (result < 0)
67     return result;
68   DCHECK_EQ(OK, result);
69   DCHECK(!frames->empty());
70 
71   return InflateAndReadIfNecessary(frames, callback);
72 }
73 
WriteFrames(ScopedVector<WebSocketFrame> * frames,const CompletionCallback & callback)74 int WebSocketDeflateStream::WriteFrames(ScopedVector<WebSocketFrame>* frames,
75                                         const CompletionCallback& callback) {
76   int result = Deflate(frames);
77   if (result != OK)
78     return result;
79   if (frames->empty())
80     return OK;
81   return stream_->WriteFrames(frames, callback);
82 }
83 
Close()84 void WebSocketDeflateStream::Close() { stream_->Close(); }
85 
GetSubProtocol() const86 std::string WebSocketDeflateStream::GetSubProtocol() const {
87   return stream_->GetSubProtocol();
88 }
89 
GetExtensions() const90 std::string WebSocketDeflateStream::GetExtensions() const {
91   return stream_->GetExtensions();
92 }
93 
OnReadComplete(ScopedVector<WebSocketFrame> * frames,const CompletionCallback & callback,int result)94 void WebSocketDeflateStream::OnReadComplete(
95     ScopedVector<WebSocketFrame>* frames,
96     const CompletionCallback& callback,
97     int result) {
98   if (result != OK) {
99     frames->clear();
100     callback.Run(result);
101     return;
102   }
103 
104   int r = InflateAndReadIfNecessary(frames, callback);
105   if (r != ERR_IO_PENDING)
106     callback.Run(r);
107 }
108 
Deflate(ScopedVector<WebSocketFrame> * frames)109 int WebSocketDeflateStream::Deflate(ScopedVector<WebSocketFrame>* frames) {
110   ScopedVector<WebSocketFrame> frames_to_write;
111   // Store frames of the currently processed message if writing_state_ equals to
112   // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
113   ScopedVector<WebSocketFrame> frames_of_message;
114   for (size_t i = 0; i < frames->size(); ++i) {
115     DCHECK(!(*frames)[i]->header.reserved1);
116     if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
117       frames_to_write.push_back((*frames)[i]);
118       (*frames)[i] = NULL;
119       continue;
120     }
121     if (writing_state_ == NOT_WRITING)
122       OnMessageStart(*frames, i);
123 
124     scoped_ptr<WebSocketFrame> frame((*frames)[i]);
125     (*frames)[i] = NULL;
126     predictor_->RecordInputDataFrame(frame.get());
127 
128     if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
129       if (frame->header.final)
130         writing_state_ = NOT_WRITING;
131       predictor_->RecordWrittenDataFrame(frame.get());
132       frames_to_write.push_back(frame.release());
133       current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
134     } else {
135       if (frame->data.get() &&
136           !deflater_.AddBytes(frame->data->data(),
137                               frame->header.payload_length)) {
138         DVLOG(1) << "WebSocket protocol error. "
139                  << "deflater_.AddBytes() returns an error.";
140         return ERR_WS_PROTOCOL_ERROR;
141       }
142       if (frame->header.final && !deflater_.Finish()) {
143         DVLOG(1) << "WebSocket protocol error. "
144                  << "deflater_.Finish() returns an error.";
145         return ERR_WS_PROTOCOL_ERROR;
146       }
147 
148       if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
149         if (deflater_.CurrentOutputSize() >= kChunkSize ||
150             frame->header.final) {
151           int result = AppendCompressedFrame(frame->header, &frames_to_write);
152           if (result != OK)
153             return result;
154         }
155         if (frame->header.final)
156           writing_state_ = NOT_WRITING;
157       } else {
158         DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
159         bool final = frame->header.final;
160         frames_of_message.push_back(frame.release());
161         if (final) {
162           int result = AppendPossiblyCompressedMessage(&frames_of_message,
163                                                        &frames_to_write);
164           if (result != OK)
165             return result;
166           frames_of_message.clear();
167           writing_state_ = NOT_WRITING;
168         }
169       }
170     }
171   }
172   DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
173   frames->swap(frames_to_write);
174   return OK;
175 }
176 
OnMessageStart(const ScopedVector<WebSocketFrame> & frames,size_t index)177 void WebSocketDeflateStream::OnMessageStart(
178     const ScopedVector<WebSocketFrame>& frames, size_t index) {
179   WebSocketFrame* frame = frames[index];
180   current_writing_opcode_ = frame->header.opcode;
181   DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
182          current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
183   WebSocketDeflatePredictor::Result prediction =
184       predictor_->Predict(frames, index);
185 
186   switch (prediction) {
187     case WebSocketDeflatePredictor::DEFLATE:
188       writing_state_ = WRITING_COMPRESSED_MESSAGE;
189       return;
190     case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
191       writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
192       return;
193     case WebSocketDeflatePredictor::TRY_DEFLATE:
194       writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
195       return;
196   }
197   NOTREACHED();
198 }
199 
AppendCompressedFrame(const WebSocketFrameHeader & header,ScopedVector<WebSocketFrame> * frames_to_write)200 int WebSocketDeflateStream::AppendCompressedFrame(
201     const WebSocketFrameHeader& header,
202     ScopedVector<WebSocketFrame>* frames_to_write) {
203   const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
204   scoped_refptr<IOBufferWithSize> compressed_payload =
205       deflater_.GetOutput(deflater_.CurrentOutputSize());
206   if (!compressed_payload.get()) {
207     DVLOG(1) << "WebSocket protocol error. "
208              << "deflater_.GetOutput() returns an error.";
209     return ERR_WS_PROTOCOL_ERROR;
210   }
211   scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
212   compressed->header.CopyFrom(header);
213   compressed->header.opcode = opcode;
214   compressed->header.final = header.final;
215   compressed->header.reserved1 =
216       (opcode != WebSocketFrameHeader::kOpCodeContinuation);
217   compressed->data = compressed_payload;
218   compressed->header.payload_length = compressed_payload->size();
219 
220   current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
221   predictor_->RecordWrittenDataFrame(compressed.get());
222   frames_to_write->push_back(compressed.release());
223   return OK;
224 }
225 
AppendPossiblyCompressedMessage(ScopedVector<WebSocketFrame> * frames,ScopedVector<WebSocketFrame> * frames_to_write)226 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
227     ScopedVector<WebSocketFrame>* frames,
228     ScopedVector<WebSocketFrame>* frames_to_write) {
229   DCHECK(!frames->empty());
230 
231   const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
232   scoped_refptr<IOBufferWithSize> compressed_payload =
233       deflater_.GetOutput(deflater_.CurrentOutputSize());
234   if (!compressed_payload.get()) {
235     DVLOG(1) << "WebSocket protocol error. "
236              << "deflater_.GetOutput() returns an error.";
237     return ERR_WS_PROTOCOL_ERROR;
238   }
239 
240   uint64 original_payload_length = 0;
241   for (size_t i = 0; i < frames->size(); ++i) {
242     WebSocketFrame* frame = (*frames)[i];
243     // Asserts checking that frames represent one whole data message.
244     DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
245     DCHECK_EQ(i == 0,
246               WebSocketFrameHeader::kOpCodeContinuation !=
247               frame->header.opcode);
248     DCHECK_EQ(i == frames->size() - 1, frame->header.final);
249     original_payload_length += frame->header.payload_length;
250   }
251   if (original_payload_length <=
252       static_cast<uint64>(compressed_payload->size())) {
253     // Compression is not effective. Use the original frames.
254     for (size_t i = 0; i < frames->size(); ++i) {
255       WebSocketFrame* frame = (*frames)[i];
256       frames_to_write->push_back(frame);
257       predictor_->RecordWrittenDataFrame(frame);
258       (*frames)[i] = NULL;
259     }
260     frames->weak_clear();
261     return OK;
262   }
263   scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
264   compressed->header.CopyFrom((*frames)[0]->header);
265   compressed->header.opcode = opcode;
266   compressed->header.final = true;
267   compressed->header.reserved1 = true;
268   compressed->data = compressed_payload;
269   compressed->header.payload_length = compressed_payload->size();
270 
271   predictor_->RecordWrittenDataFrame(compressed.get());
272   frames_to_write->push_back(compressed.release());
273   return OK;
274 }
275 
Inflate(ScopedVector<WebSocketFrame> * frames)276 int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) {
277   ScopedVector<WebSocketFrame> frames_to_output;
278   ScopedVector<WebSocketFrame> frames_passed;
279   frames->swap(frames_passed);
280   for (size_t i = 0; i < frames_passed.size(); ++i) {
281     scoped_ptr<WebSocketFrame> frame(frames_passed[i]);
282     frames_passed[i] = NULL;
283     DVLOG(3) << "Input frame: opcode=" << frame->header.opcode
284              << " final=" << frame->header.final
285              << " reserved1=" << frame->header.reserved1
286              << " payload_length=" << frame->header.payload_length;
287 
288     if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
289       frames_to_output.push_back(frame.release());
290       continue;
291     }
292 
293     if (reading_state_ == NOT_READING) {
294       if (frame->header.reserved1)
295         reading_state_ = READING_COMPRESSED_MESSAGE;
296       else
297         reading_state_ = READING_UNCOMPRESSED_MESSAGE;
298       current_reading_opcode_ = frame->header.opcode;
299     } else {
300       if (frame->header.reserved1) {
301         DVLOG(1) << "WebSocket protocol error. "
302                  << "Receiving a non-first frame with RSV1 flag set.";
303         return ERR_WS_PROTOCOL_ERROR;
304       }
305     }
306 
307     if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) {
308       if (frame->header.final)
309         reading_state_ = NOT_READING;
310       current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
311       frames_to_output.push_back(frame.release());
312     } else {
313       DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE);
314       if (frame->data.get() &&
315           !inflater_.AddBytes(frame->data->data(),
316                               frame->header.payload_length)) {
317         DVLOG(1) << "WebSocket protocol error. "
318                  << "inflater_.AddBytes() returns an error.";
319         return ERR_WS_PROTOCOL_ERROR;
320       }
321       if (frame->header.final) {
322         if (!inflater_.Finish()) {
323           DVLOG(1) << "WebSocket protocol error. "
324                    << "inflater_.Finish() returns an error.";
325           return ERR_WS_PROTOCOL_ERROR;
326         }
327       }
328       // TODO(yhirano): Many frames can be generated by the inflater and
329       // memory consumption can grow.
330       // We could avoid it, but avoiding it makes this class much more
331       // complicated.
332       while (inflater_.CurrentOutputSize() >= kChunkSize ||
333              frame->header.final) {
334         size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize());
335         scoped_ptr<WebSocketFrame> inflated(
336             new WebSocketFrame(WebSocketFrameHeader::kOpCodeText));
337         scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size);
338         bool is_final = !inflater_.CurrentOutputSize() && frame->header.final;
339         if (!data.get()) {
340           DVLOG(1) << "WebSocket protocol error. "
341                    << "inflater_.GetOutput() returns an error.";
342           return ERR_WS_PROTOCOL_ERROR;
343         }
344         inflated->header.CopyFrom(frame->header);
345         inflated->header.opcode = current_reading_opcode_;
346         inflated->header.final = is_final;
347         inflated->header.reserved1 = false;
348         inflated->data = data;
349         inflated->header.payload_length = data->size();
350         DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode
351                  << " final=" << inflated->header.final
352                  << " reserved1=" << inflated->header.reserved1
353                  << " payload_length=" << inflated->header.payload_length;
354         frames_to_output.push_back(inflated.release());
355         current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
356         if (is_final)
357           break;
358       }
359       if (frame->header.final)
360         reading_state_ = NOT_READING;
361     }
362   }
363   frames->swap(frames_to_output);
364   return frames->empty() ? ERR_IO_PENDING : OK;
365 }
366 
InflateAndReadIfNecessary(ScopedVector<WebSocketFrame> * frames,const CompletionCallback & callback)367 int WebSocketDeflateStream::InflateAndReadIfNecessary(
368     ScopedVector<WebSocketFrame>* frames,
369     const CompletionCallback& callback) {
370   int result = Inflate(frames);
371   while (result == ERR_IO_PENDING) {
372     DCHECK(frames->empty());
373 
374     result = stream_->ReadFrames(
375         frames,
376         base::Bind(&WebSocketDeflateStream::OnReadComplete,
377                    base::Unretained(this),
378                    base::Unretained(frames),
379                    callback));
380     if (result < 0)
381       break;
382     DCHECK_EQ(OK, result);
383     DCHECK(!frames->empty());
384 
385     result = Inflate(frames);
386   }
387   if (result < 0)
388     frames->clear();
389   return result;
390 }
391 
392 }  // namespace net
393