1 /*
2 * Copyright (C) 2013 Google Inc. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions are
6 * met:
7 *
8 * * Redistributions of source code must retain the above copyright
9 * notice, this list of conditions and the following disclaimer.
10 * * Redistributions in binary form must reproduce the above
11 * copyright notice, this list of conditions and the following disclaimer
12 * in the documentation and/or other materials provided with the
13 * distribution.
14 * * Neither the name of Google Inc. nor the names of its
15 * contributors may be used to endorse or promote products derived from
16 * this software without specific prior written permission.
17 *
18 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 */
30
31 #include "config.h"
32
33 #include "modules/websockets/WebSocketPerMessageDeflate.h"
34
35 #include "modules/websockets/WebSocketExtensionParser.h"
36 #include "public/platform/Platform.h"
37 #include "wtf/HashMap.h"
38 #include "wtf/text/CString.h"
39 #include "wtf/text/StringHash.h"
40 #include "wtf/text/WTFString.h"
41
42 namespace blink {
43
44 class CompressionMessageExtensionProcessor FINAL : public WebSocketExtensionProcessor {
45 WTF_MAKE_FAST_ALLOCATED;
46 WTF_MAKE_NONCOPYABLE(CompressionMessageExtensionProcessor);
47 public:
create(WebSocketPerMessageDeflate & compress)48 static PassOwnPtr<CompressionMessageExtensionProcessor> create(WebSocketPerMessageDeflate& compress)
49 {
50 return adoptPtr(new CompressionMessageExtensionProcessor(compress));
51 }
~CompressionMessageExtensionProcessor()52 virtual ~CompressionMessageExtensionProcessor() { }
53
54 virtual String handshakeString() OVERRIDE;
55 virtual bool processResponse(const HashMap<String, String>&) OVERRIDE;
failureReason()56 virtual String failureReason() OVERRIDE { return m_failureReason; }
57
58 private:
59 explicit CompressionMessageExtensionProcessor(WebSocketPerMessageDeflate&);
60
61 WebSocketPerMessageDeflate& m_compress;
62 bool m_responseProcessed;
63 String m_failureReason;
64 };
65
CompressionMessageExtensionProcessor(WebSocketPerMessageDeflate & compress)66 CompressionMessageExtensionProcessor::CompressionMessageExtensionProcessor(WebSocketPerMessageDeflate& compress)
67 : WebSocketExtensionProcessor("permessage-deflate")
68 , m_compress(compress)
69 , m_responseProcessed(false)
70 {
71 }
72
handshakeString()73 String CompressionMessageExtensionProcessor::handshakeString()
74 {
75 return "permessage-deflate; client_max_window_bits";
76 }
77
processResponse(const HashMap<String,String> & parameters)78 bool CompressionMessageExtensionProcessor::processResponse(const HashMap<String, String>& parameters)
79 {
80 unsigned numProcessedParameters = 0;
81 if (m_responseProcessed) {
82 m_failureReason = "Received duplicate permessage-deflate response";
83 return false;
84 }
85 m_responseProcessed = true;
86 WebSocketDeflater::ContextTakeOverMode mode = WebSocketDeflater::TakeOverContext;
87 int windowBits = 15;
88
89 HashMap<String, String>::const_iterator clientNoContextTakeover = parameters.find("client_no_context_takeover");
90 HashMap<String, String>::const_iterator clientMaxWindowBits = parameters.find("client_max_window_bits");
91 HashMap<String, String>::const_iterator serverNoContextTakeover = parameters.find("server_no_context_takeover");
92 HashMap<String, String>::const_iterator serverMaxWindowBits = parameters.find("server_max_window_bits");
93
94 if (clientNoContextTakeover != parameters.end()) {
95 if (!clientNoContextTakeover->value.isNull()) {
96 m_failureReason = "Received invalid client_no_context_takeover parameter";
97 return false;
98 }
99 mode = WebSocketDeflater::DoNotTakeOverContext;
100 ++numProcessedParameters;
101 }
102 if (clientMaxWindowBits != parameters.end()) {
103 if (!clientMaxWindowBits->value.length()) {
104 m_failureReason = "client_max_window_bits parameter must have value";
105 return false;
106 }
107 bool ok = false;
108 windowBits = clientMaxWindowBits->value.toIntStrict(&ok);
109 if (!ok || windowBits < 8 || windowBits > 15 || clientMaxWindowBits->value[0] == '+' || clientMaxWindowBits->value[0] == '0') {
110 m_failureReason = "Received invalid client_max_window_bits parameter";
111 return false;
112 }
113 ++numProcessedParameters;
114 }
115 if (serverNoContextTakeover != parameters.end()) {
116 if (!serverNoContextTakeover->value.isNull()) {
117 m_failureReason = "Received invalid server_no_context_takeover parameter";
118 return false;
119 }
120 ++numProcessedParameters;
121 }
122 if (serverMaxWindowBits != parameters.end()) {
123 if (!serverMaxWindowBits->value.length()) {
124 m_failureReason = "server_max_window_bits parameter must have value";
125 return false;
126 }
127 bool ok = false;
128 int bits = serverMaxWindowBits->value.toIntStrict(&ok);
129 if (!ok || bits < 8 || bits > 15 || serverMaxWindowBits->value[0] == '+' || serverMaxWindowBits->value[0] == '0') {
130 m_failureReason = "Received invalid server_max_window_bits parameter";
131 return false;
132 }
133 ++numProcessedParameters;
134 }
135
136 if (numProcessedParameters != parameters.size()) {
137 m_failureReason = "Received an unexpected permessage-deflate extension parameter";
138 return false;
139 }
140 Platform::current()->histogramEnumeration("WebCore.WebSocket.PerMessageDeflateContextTakeOverMode", mode, WebSocketDeflater::ContextTakeOverModeMax);
141 m_compress.enable(windowBits, mode);
142 // Since we don't request server_no_context_takeover and server_max_window_bits, they should be ignored.
143 return true;
144 }
145
WebSocketPerMessageDeflate()146 WebSocketPerMessageDeflate::WebSocketPerMessageDeflate()
147 : m_enabled(false)
148 , m_deflateOngoing(false)
149 , m_inflateOngoing(false)
150 {
151 }
152
createExtensionProcessor()153 PassOwnPtr<WebSocketExtensionProcessor> WebSocketPerMessageDeflate::createExtensionProcessor()
154 {
155 return CompressionMessageExtensionProcessor::create(*this);
156 }
157
enable(int windowBits,WebSocketDeflater::ContextTakeOverMode mode)158 void WebSocketPerMessageDeflate::enable(int windowBits, WebSocketDeflater::ContextTakeOverMode mode)
159 {
160 m_deflater = WebSocketDeflater::create(windowBits, mode);
161 m_inflater = WebSocketInflater::create();
162 if (!m_deflater->initialize() || !m_inflater->initialize()) {
163 m_deflater.clear();
164 m_inflater.clear();
165 return;
166 }
167 m_enabled = true;
168 m_deflateOngoing = false;
169 m_inflateOngoing = false;
170 }
171
deflate(WebSocketFrame & frame)172 bool WebSocketPerMessageDeflate::deflate(WebSocketFrame& frame)
173 {
174 if (!enabled())
175 return true;
176 if (frame.compress) {
177 m_failureReason = "Some extension already uses the compress bit.";
178 return false;
179 }
180 if (!WebSocketFrame::isNonControlOpCode(frame.opCode))
181 return true;
182
183 if ((frame.opCode == WebSocketFrame::OpCodeText || frame.opCode == WebSocketFrame::OpCodeBinary)
184 && frame.final
185 && frame.payloadLength <= 2) {
186 // A trivial optimization: If a message consists of one frame and its
187 // payload length is very short, we don't compress it.
188 return true;
189 }
190
191 if (frame.payloadLength > 0 && !m_deflater->addBytes(frame.payload, frame.payloadLength)) {
192 m_failureReason = "Failed to deflate a frame";
193 return false;
194 }
195 if (frame.final && !m_deflater->finish()) {
196 m_failureReason = "Failed to finish compression";
197 return false;
198 }
199
200 frame.compress = !m_deflateOngoing;
201 frame.payload = m_deflater->data();
202 frame.payloadLength = m_deflater->size();
203 m_deflateOngoing = !frame.final;
204 return true;
205 }
206
resetDeflateBuffer()207 void WebSocketPerMessageDeflate::resetDeflateBuffer()
208 {
209 if (m_deflater) {
210 if (m_deflateOngoing)
211 m_deflater->softReset();
212 else
213 m_deflater->reset();
214 }
215 }
216
inflate(WebSocketFrame & frame)217 bool WebSocketPerMessageDeflate::inflate(WebSocketFrame& frame)
218 {
219 if (!enabled())
220 return true;
221 if (!WebSocketFrame::isNonControlOpCode(frame.opCode)) {
222 if (frame.compress) {
223 m_failureReason = "Received unexpected compressed frame";
224 return false;
225 }
226 return true;
227 }
228 if (frame.compress) {
229 if (m_inflateOngoing) {
230 m_failureReason = "Received a frame that sets compressed bit while another decompression is ongoing";
231 return false;
232 }
233 m_inflateOngoing = true;
234 }
235
236 if (!m_inflateOngoing)
237 return true;
238
239 if (frame.payloadLength > 0 && !m_inflater->addBytes(frame.payload, frame.payloadLength)) {
240 m_failureReason = "Failed to inflate a frame";
241 return false;
242 }
243 if (frame.final && !m_inflater->finish()) {
244 m_failureReason = "Failed to finish decompression";
245 return false;
246 }
247 frame.compress = false;
248 frame.payload = m_inflater->data();
249 frame.payloadLength = m_inflater->size();
250 m_inflateOngoing = !frame.final;
251 return true;
252 }
253
resetInflateBuffer()254 void WebSocketPerMessageDeflate::resetInflateBuffer()
255 {
256 if (m_inflater)
257 m_inflater->reset();
258 }
259
didFail()260 void WebSocketPerMessageDeflate::didFail()
261 {
262 resetDeflateBuffer();
263 resetInflateBuffer();
264 }
265
266 } // namespace blink
267