• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_deflate_stream.h"
6 
7 #include <stdint.h>
8 
9 #include <algorithm>
10 #include <ostream>
11 #include <string>
12 #include <utility>
13 #include <vector>
14 
15 #include "base/check.h"
16 #include "base/check_op.h"
17 #include "base/containers/span.h"
18 #include "base/functional/bind.h"
19 #include "base/functional/callback.h"
20 #include "base/logging.h"
21 #include "base/memory/scoped_refptr.h"
22 #include "base/notreached.h"
23 #include "net/base/io_buffer.h"
24 #include "net/base/net_errors.h"
25 #include "net/websockets/websocket_deflate_parameters.h"
26 #include "net/websockets/websocket_deflate_predictor.h"
27 #include "net/websockets/websocket_deflater.h"
28 #include "net/websockets/websocket_frame.h"
29 #include "net/websockets/websocket_inflater.h"
30 #include "net/websockets/websocket_stream.h"
31 
32 namespace net {
33 class NetLogWithSource;
34 
namespace {

// Window size, in bits, that the inflater is always initialized with. It is
// also the deflater's window size when the negotiated parameters do not
// specify client_max_window_bits.
constexpr int kWindowBits = 15;

// Chunk size in bytes (4 KiB). Buffered deflater output is flushed into a new
// frame once at least this much is pending, and inflated output is split into
// frames of at most this size. It is also passed to the inflater's
// constructor.
constexpr size_t kChunkSize = 4096;

}  // namespace
41 
WebSocketDeflateStream(std::unique_ptr<WebSocketStream> stream,const WebSocketDeflateParameters & params,std::unique_ptr<WebSocketDeflatePredictor> predictor)42 WebSocketDeflateStream::WebSocketDeflateStream(
43     std::unique_ptr<WebSocketStream> stream,
44     const WebSocketDeflateParameters& params,
45     std::unique_ptr<WebSocketDeflatePredictor> predictor)
46     : stream_(std::move(stream)),
47       deflater_(params.client_context_take_over_mode()),
48       inflater_(kChunkSize, kChunkSize),
49       predictor_(std::move(predictor)) {
50   DCHECK(stream_);
51   DCHECK(params.IsValidAsResponse());
52   int client_max_window_bits = 15;
53   if (params.is_client_max_window_bits_specified()) {
54     DCHECK(params.has_client_max_window_bits_value());
55     client_max_window_bits = params.client_max_window_bits();
56   }
57   deflater_.Initialize(client_max_window_bits);
58   inflater_.Initialize(kWindowBits);
59 }
60 
61 WebSocketDeflateStream::~WebSocketDeflateStream() = default;
62 
ReadFrames(std::vector<std::unique_ptr<WebSocketFrame>> * frames,CompletionOnceCallback callback)63 int WebSocketDeflateStream::ReadFrames(
64     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
65     CompletionOnceCallback callback) {
66   read_callback_ = std::move(callback);
67   inflater_outputs_.clear();
68   int result = stream_->ReadFrames(
69       frames, base::BindOnce(&WebSocketDeflateStream::OnReadComplete,
70                              base::Unretained(this), base::Unretained(frames)));
71   if (result < 0)
72     return result;
73   DCHECK_EQ(OK, result);
74   DCHECK(!frames->empty());
75 
76   return InflateAndReadIfNecessary(frames);
77 }
78 
WriteFrames(std::vector<std::unique_ptr<WebSocketFrame>> * frames,CompletionOnceCallback callback)79 int WebSocketDeflateStream::WriteFrames(
80     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
81     CompletionOnceCallback callback) {
82   deflater_outputs_.clear();
83   int result = Deflate(frames);
84   if (result != OK)
85     return result;
86   if (frames->empty())
87     return OK;
88   return stream_->WriteFrames(frames, std::move(callback));
89 }
90 
Close()91 void WebSocketDeflateStream::Close() { stream_->Close(); }
92 
GetSubProtocol() const93 std::string WebSocketDeflateStream::GetSubProtocol() const {
94   return stream_->GetSubProtocol();
95 }
96 
GetExtensions() const97 std::string WebSocketDeflateStream::GetExtensions() const {
98   return stream_->GetExtensions();
99 }
100 
GetNetLogWithSource() const101 const NetLogWithSource& WebSocketDeflateStream::GetNetLogWithSource() const {
102   return stream_->GetNetLogWithSource();
103 }
104 
OnReadComplete(std::vector<std::unique_ptr<WebSocketFrame>> * frames,int result)105 void WebSocketDeflateStream::OnReadComplete(
106     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
107     int result) {
108   if (result != OK) {
109     frames->clear();
110     std::move(read_callback_).Run(result);
111     return;
112   }
113 
114   int r = InflateAndReadIfNecessary(frames);
115   if (r != ERR_IO_PENDING)
116     std::move(read_callback_).Run(r);
117 }
118 
Deflate(std::vector<std::unique_ptr<WebSocketFrame>> * frames)119 int WebSocketDeflateStream::Deflate(
120     std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
121   std::vector<std::unique_ptr<WebSocketFrame>> frames_to_write;
122   // Store frames of the currently processed message if writing_state_ equals to
123   // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
124   std::vector<std::unique_ptr<WebSocketFrame>> frames_of_message;
125   for (size_t i = 0; i < frames->size(); ++i) {
126     DCHECK(!(*frames)[i]->header.reserved1);
127     if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
128       frames_to_write.push_back(std::move((*frames)[i]));
129       continue;
130     }
131     if (writing_state_ == NOT_WRITING)
132       OnMessageStart(*frames, i);
133 
134     std::unique_ptr<WebSocketFrame> frame(std::move((*frames)[i]));
135     predictor_->RecordInputDataFrame(frame.get());
136 
137     if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
138       if (frame->header.final)
139         writing_state_ = NOT_WRITING;
140       predictor_->RecordWrittenDataFrame(frame.get());
141       frames_to_write.push_back(std::move(frame));
142       current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
143     } else {
144       if (!frame->payload.empty() &&
145           !deflater_.AddBytes(base::as_chars(frame->payload).data(),
146                               frame->payload.size())) {
147         DVLOG(1) << "WebSocket protocol error. "
148                  << "deflater_.AddBytes() returns an error.";
149         return ERR_WS_PROTOCOL_ERROR;
150       }
151       if (frame->header.final && !deflater_.Finish()) {
152         DVLOG(1) << "WebSocket protocol error. "
153                  << "deflater_.Finish() returns an error.";
154         return ERR_WS_PROTOCOL_ERROR;
155       }
156 
157       if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
158         if (deflater_.CurrentOutputSize() >= kChunkSize ||
159             frame->header.final) {
160           int result = AppendCompressedFrame(frame->header, &frames_to_write);
161           if (result != OK)
162             return result;
163         }
164         if (frame->header.final)
165           writing_state_ = NOT_WRITING;
166       } else {
167         DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
168         bool final = frame->header.final;
169         frames_of_message.push_back(std::move(frame));
170         if (final) {
171           int result = AppendPossiblyCompressedMessage(&frames_of_message,
172                                                        &frames_to_write);
173           if (result != OK)
174             return result;
175           frames_of_message.clear();
176           writing_state_ = NOT_WRITING;
177         }
178       }
179     }
180   }
181   DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
182   frames->swap(frames_to_write);
183   return OK;
184 }
185 
OnMessageStart(const std::vector<std::unique_ptr<WebSocketFrame>> & frames,size_t index)186 void WebSocketDeflateStream::OnMessageStart(
187     const std::vector<std::unique_ptr<WebSocketFrame>>& frames,
188     size_t index) {
189   WebSocketFrame* frame = frames[index].get();
190   current_writing_opcode_ = frame->header.opcode;
191   DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
192          current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
193   WebSocketDeflatePredictor::Result prediction =
194       predictor_->Predict(frames, index);
195 
196   switch (prediction) {
197     case WebSocketDeflatePredictor::DEFLATE:
198       writing_state_ = WRITING_COMPRESSED_MESSAGE;
199       return;
200     case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
201       writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
202       return;
203     case WebSocketDeflatePredictor::TRY_DEFLATE:
204       writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
205       return;
206   }
207   NOTREACHED();
208 }
209 
AppendCompressedFrame(const WebSocketFrameHeader & header,std::vector<std::unique_ptr<WebSocketFrame>> * frames_to_write)210 int WebSocketDeflateStream::AppendCompressedFrame(
211     const WebSocketFrameHeader& header,
212     std::vector<std::unique_ptr<WebSocketFrame>>* frames_to_write) {
213   const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
214   scoped_refptr<IOBufferWithSize> compressed_payload =
215       deflater_.GetOutput(deflater_.CurrentOutputSize());
216   if (!compressed_payload.get()) {
217     DVLOG(1) << "WebSocket protocol error. "
218              << "deflater_.GetOutput() returns an error.";
219     return ERR_WS_PROTOCOL_ERROR;
220   }
221   deflater_outputs_.push_back(compressed_payload);
222   auto compressed = std::make_unique<WebSocketFrame>(opcode);
223   compressed->header.CopyFrom(header);
224   compressed->header.opcode = opcode;
225   compressed->header.final = header.final;
226   compressed->header.reserved1 =
227       (opcode != WebSocketFrameHeader::kOpCodeContinuation);
228   compressed->payload = compressed_payload->span();
229   compressed->header.payload_length = compressed_payload->size();
230 
231   current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
232   predictor_->RecordWrittenDataFrame(compressed.get());
233   frames_to_write->push_back(std::move(compressed));
234   return OK;
235 }
236 
AppendPossiblyCompressedMessage(std::vector<std::unique_ptr<WebSocketFrame>> * frames,std::vector<std::unique_ptr<WebSocketFrame>> * frames_to_write)237 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
238     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
239     std::vector<std::unique_ptr<WebSocketFrame>>* frames_to_write) {
240   DCHECK(!frames->empty());
241 
242   const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
243   scoped_refptr<IOBufferWithSize> compressed_payload =
244       deflater_.GetOutput(deflater_.CurrentOutputSize());
245   if (!compressed_payload.get()) {
246     DVLOG(1) << "WebSocket protocol error. "
247              << "deflater_.GetOutput() returns an error.";
248     return ERR_WS_PROTOCOL_ERROR;
249   }
250   deflater_outputs_.push_back(compressed_payload);
251 
252   uint64_t original_payload_length = 0;
253   for (size_t i = 0; i < frames->size(); ++i) {
254     WebSocketFrame* frame = (*frames)[i].get();
255     // Asserts checking that frames represent one whole data message.
256     DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
257     DCHECK_EQ(i == 0,
258               WebSocketFrameHeader::kOpCodeContinuation !=
259               frame->header.opcode);
260     DCHECK_EQ(i == frames->size() - 1, frame->header.final);
261     original_payload_length += frame->header.payload_length;
262   }
263   if (original_payload_length <=
264       static_cast<uint64_t>(compressed_payload->size())) {
265     // Compression is not effective. Use the original frames.
266     for (auto& frame : *frames) {
267       predictor_->RecordWrittenDataFrame(frame.get());
268       frames_to_write->push_back(std::move(frame));
269     }
270     frames->clear();
271     return OK;
272   }
273   auto compressed = std::make_unique<WebSocketFrame>(opcode);
274   compressed->header.CopyFrom((*frames)[0]->header);
275   compressed->header.opcode = opcode;
276   compressed->header.final = true;
277   compressed->header.reserved1 = true;
278   compressed->payload = compressed_payload->span();
279   compressed->header.payload_length = compressed_payload->size();
280 
281   predictor_->RecordWrittenDataFrame(compressed.get());
282   frames_to_write->push_back(std::move(compressed));
283   return OK;
284 }
285 
Inflate(std::vector<std::unique_ptr<WebSocketFrame>> * frames)286 int WebSocketDeflateStream::Inflate(
287     std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
288   std::vector<std::unique_ptr<WebSocketFrame>> frames_to_output;
289   std::vector<std::unique_ptr<WebSocketFrame>> frames_passed;
290   frames->swap(frames_passed);
291   for (auto& frame_passed : frames_passed) {
292     std::unique_ptr<WebSocketFrame> frame(std::move(frame_passed));
293     frame_passed = nullptr;
294     DVLOG(3) << "Input frame: opcode=" << frame->header.opcode
295              << " final=" << frame->header.final
296              << " reserved1=" << frame->header.reserved1
297              << " payload_length=" << frame->header.payload_length;
298 
299     if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
300       frames_to_output.push_back(std::move(frame));
301       continue;
302     }
303 
304     if (reading_state_ == NOT_READING) {
305       if (frame->header.reserved1)
306         reading_state_ = READING_COMPRESSED_MESSAGE;
307       else
308         reading_state_ = READING_UNCOMPRESSED_MESSAGE;
309       current_reading_opcode_ = frame->header.opcode;
310     } else {
311       if (frame->header.reserved1) {
312         DVLOG(1) << "WebSocket protocol error. "
313                  << "Receiving a non-first frame with RSV1 flag set.";
314         return ERR_WS_PROTOCOL_ERROR;
315       }
316     }
317 
318     if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) {
319       if (frame->header.final)
320         reading_state_ = NOT_READING;
321       current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
322       frames_to_output.push_back(std::move(frame));
323     } else {
324       DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE);
325       if (!frame->payload.empty() &&
326           !inflater_.AddBytes(base::as_chars(frame->payload).data(),
327                               frame->payload.size())) {
328         DVLOG(1) << "WebSocket protocol error. "
329                  << "inflater_.AddBytes() returns an error.";
330         return ERR_WS_PROTOCOL_ERROR;
331       }
332       if (frame->header.final) {
333         if (!inflater_.Finish()) {
334           DVLOG(1) << "WebSocket protocol error. "
335                    << "inflater_.Finish() returns an error.";
336           return ERR_WS_PROTOCOL_ERROR;
337         }
338       }
339       // TODO(yhirano): Many frames can be generated by the inflater and
340       // memory consumption can grow.
341       // We could avoid it, but avoiding it makes this class much more
342       // complicated.
343       while (inflater_.CurrentOutputSize() >= kChunkSize ||
344              frame->header.final) {
345         size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize());
346         auto inflated =
347             std::make_unique<WebSocketFrame>(WebSocketFrameHeader::kOpCodeText);
348         scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size);
349         inflater_outputs_.push_back(data);
350         bool is_final = !inflater_.CurrentOutputSize() && frame->header.final;
351         if (!data.get()) {
352           DVLOG(1) << "WebSocket protocol error. "
353                    << "inflater_.GetOutput() returns an error.";
354           return ERR_WS_PROTOCOL_ERROR;
355         }
356         inflated->header.CopyFrom(frame->header);
357         inflated->header.opcode = current_reading_opcode_;
358         inflated->header.final = is_final;
359         inflated->header.reserved1 = false;
360         inflated->payload = data->span();
361         inflated->header.payload_length = data->size();
362         DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode
363                  << " final=" << inflated->header.final
364                  << " reserved1=" << inflated->header.reserved1
365                  << " payload_length=" << inflated->header.payload_length;
366         frames_to_output.push_back(std::move(inflated));
367         current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
368         if (is_final)
369           break;
370       }
371       if (frame->header.final)
372         reading_state_ = NOT_READING;
373     }
374   }
375   frames->swap(frames_to_output);
376   return frames->empty() ? ERR_IO_PENDING : OK;
377 }
378 
InflateAndReadIfNecessary(std::vector<std::unique_ptr<WebSocketFrame>> * frames)379 int WebSocketDeflateStream::InflateAndReadIfNecessary(
380     std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
381   int result = Inflate(frames);
382   while (result == ERR_IO_PENDING) {
383     DCHECK(frames->empty());
384 
385     result = stream_->ReadFrames(
386         frames,
387         base::BindOnce(&WebSocketDeflateStream::OnReadComplete,
388                        base::Unretained(this), base::Unretained(frames)));
389     if (result < 0)
390       break;
391     DCHECK_EQ(OK, result);
392     DCHECK(!frames->empty());
393 
394     result = Inflate(frames);
395   }
396   if (result < 0)
397     frames->clear();
398   return result;
399 }
400 
401 }  // namespace net
402