• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2014 Square, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.squareup.okhttp.internal.ws;
17 
18 import com.squareup.okhttp.MediaType;
19 import com.squareup.okhttp.ResponseBody;
20 import com.squareup.okhttp.ws.WebSocket;
21 import java.io.EOFException;
22 import java.io.IOException;
23 import java.net.ProtocolException;
24 import okio.Buffer;
25 import okio.BufferedSource;
26 import okio.Okio;
27 import okio.Source;
28 import okio.Timeout;
29 
30 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_FIN;
31 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV1;
32 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV2;
33 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV3;
34 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_MASK_OPCODE;
35 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B1_FLAG_MASK;
36 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B1_MASK_LENGTH;
37 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_BINARY;
38 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTINUATION;
39 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_CLOSE;
40 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PING;
41 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PONG;
42 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_FLAG_CONTROL;
43 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_TEXT;
44 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_LONG;
45 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_BYTE_MAX;
46 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_SHORT;
47 import static com.squareup.okhttp.internal.ws.WebSocketProtocol.toggleMask;
48 import static java.lang.Integer.toHexString;
49 
50 /**
51  * An <a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a>-compatible WebSocket frame reader.
52  */
53 public final class WebSocketReader {
54   public interface FrameCallback {
onMessage(ResponseBody body)55     void onMessage(ResponseBody body) throws IOException;
onPing(Buffer buffer)56     void onPing(Buffer buffer);
onPong(Buffer buffer)57     void onPong(Buffer buffer);
onClose(int code, String reason)58     void onClose(int code, String reason);
59   }
60 
61   private final boolean isClient;
62   private final BufferedSource source;
63   private final FrameCallback frameCallback;
64 
65   private final Source framedMessageSource = new FramedMessageSource();
66 
67   private boolean closed;
68   private boolean messageClosed;
69 
70   // Stateful data about the current frame.
71   private int opcode;
72   private long frameLength;
73   private long frameBytesRead;
74   private boolean isFinalFrame;
75   private boolean isControlFrame;
76   private boolean isMasked;
77 
78   private final byte[] maskKey = new byte[4];
79   private final byte[] maskBuffer = new byte[2048];
80 
WebSocketReader(boolean isClient, BufferedSource source, FrameCallback frameCallback)81   public WebSocketReader(boolean isClient, BufferedSource source, FrameCallback frameCallback) {
82     if (source == null) throw new NullPointerException("source == null");
83     if (frameCallback == null) throw new NullPointerException("frameCallback == null");
84     this.isClient = isClient;
85     this.source = source;
86     this.frameCallback = frameCallback;
87   }
88 
89   /**
90    * Process the next protocol frame.
91    * <ul>
92    * <li>If it is a control frame this will result in a single call to {@link FrameCallback}.</li>
93    * <li>If it is a message frame this will result in a single call to {@link
94    * FrameCallback#onMessage}. If the message spans multiple frames, each interleaved control
95    * frame will result in a corresponding call to {@link FrameCallback}.
96    * </ul>
97    */
processNextFrame()98   public void processNextFrame() throws IOException {
99     readHeader();
100     if (isControlFrame) {
101       readControlFrame();
102     } else {
103       readMessageFrame();
104     }
105   }
106 
readHeader()107   private void readHeader() throws IOException {
108     if (closed) throw new IOException("closed");
109 
110     int b0 = source.readByte() & 0xff;
111 
112     opcode = b0 & B0_MASK_OPCODE;
113     isFinalFrame = (b0 & B0_FLAG_FIN) != 0;
114     isControlFrame = (b0 & OPCODE_FLAG_CONTROL) != 0;
115 
116     // Control frames must be final frames (cannot contain continuations).
117     if (isControlFrame && !isFinalFrame) {
118       throw new ProtocolException("Control frames must be final.");
119     }
120 
121     boolean reservedFlag1 = (b0 & B0_FLAG_RSV1) != 0;
122     boolean reservedFlag2 = (b0 & B0_FLAG_RSV2) != 0;
123     boolean reservedFlag3 = (b0 & B0_FLAG_RSV3) != 0;
124     if (reservedFlag1 || reservedFlag2 || reservedFlag3) {
125       // Reserved flags are for extensions which we currently do not support.
126       throw new ProtocolException("Reserved flags are unsupported.");
127     }
128 
129     int b1 = source.readByte() & 0xff;
130 
131     isMasked = (b1 & B1_FLAG_MASK) != 0;
132     if (isMasked == isClient) {
133       // Masked payloads must be read on the server. Unmasked payloads must be read on the client.
134       throw new ProtocolException("Client-sent frames must be masked. Server sent must not.");
135     }
136 
137     // Get frame length, optionally reading from follow-up bytes if indicated by special values.
138     frameLength = b1 & B1_MASK_LENGTH;
139     if (frameLength == PAYLOAD_SHORT) {
140       frameLength = source.readShort() & 0xffffL; // Value is unsigned.
141     } else if (frameLength == PAYLOAD_LONG) {
142       frameLength = source.readLong();
143       if (frameLength < 0) {
144         throw new ProtocolException(
145             "Frame length 0x" + Long.toHexString(frameLength) + " > 0x7FFFFFFFFFFFFFFF");
146       }
147     }
148     frameBytesRead = 0;
149 
150     if (isControlFrame && frameLength > PAYLOAD_BYTE_MAX) {
151       throw new ProtocolException("Control frame must be less than " + PAYLOAD_BYTE_MAX + "B.");
152     }
153 
154     if (isMasked) {
155       // Read the masking key as bytes so that they can be used directly for unmasking.
156       source.readFully(maskKey);
157     }
158   }
159 
readControlFrame()160   private void readControlFrame() throws IOException {
161     Buffer buffer = null;
162     if (frameBytesRead < frameLength) {
163       buffer = new Buffer();
164 
165       if (isClient) {
166         source.readFully(buffer, frameLength);
167       } else {
168         while (frameBytesRead < frameLength) {
169           int toRead = (int) Math.min(frameLength - frameBytesRead, maskBuffer.length);
170           int read = source.read(maskBuffer, 0, toRead);
171           if (read == -1) throw new EOFException();
172           toggleMask(maskBuffer, read, maskKey, frameBytesRead);
173           buffer.write(maskBuffer, 0, read);
174           frameBytesRead += read;
175         }
176       }
177     }
178 
179     switch (opcode) {
180       case OPCODE_CONTROL_PING:
181         frameCallback.onPing(buffer);
182         break;
183       case OPCODE_CONTROL_PONG:
184         frameCallback.onPong(buffer);
185         break;
186       case OPCODE_CONTROL_CLOSE:
187         int code = 1000;
188         String reason = "";
189         if (buffer != null) {
190           long bufferSize = buffer.size();
191           if (bufferSize == 1) {
192             throw new ProtocolException("Malformed close payload length of 1.");
193           } else if (bufferSize != 0) {
194             code = buffer.readShort();
195             if (code < 1000 || code >= 5000) {
196               throw new ProtocolException("Code must be in range [1000,5000): " + code);
197             }
198             if ((code >= 1004 && code <= 1006) || (code >= 1012 && code <= 2999)) {
199               throw new ProtocolException("Code " + code + " is reserved and may not be used.");
200             }
201 
202             reason = buffer.readUtf8();
203           }
204         }
205         frameCallback.onClose(code, reason);
206         closed = true;
207         break;
208       default:
209         throw new ProtocolException("Unknown control opcode: " + toHexString(opcode));
210     }
211   }
212 
readMessageFrame()213   private void readMessageFrame() throws IOException {
214     final MediaType type;
215     switch (opcode) {
216       case OPCODE_TEXT:
217         type = WebSocket.TEXT;
218         break;
219       case OPCODE_BINARY:
220         type = WebSocket.BINARY;
221         break;
222       default:
223         throw new ProtocolException("Unknown opcode: " + toHexString(opcode));
224     }
225 
226     final BufferedSource source = Okio.buffer(framedMessageSource);
227     ResponseBody body = new ResponseBody() {
228       @Override public MediaType contentType() {
229         return type;
230       }
231 
232       @Override public long contentLength() throws IOException {
233         return -1;
234       }
235 
236       @Override public BufferedSource source() throws IOException {
237         return source;
238       }
239     };
240 
241     messageClosed = false;
242     frameCallback.onMessage(body);
243     if (!messageClosed) {
244       throw new IllegalStateException("Listener failed to call close on message payload.");
245     }
246   }
247 
248   /** Read headers and process any control frames until we reach a non-control frame. */
readUntilNonControlFrame()249   private void readUntilNonControlFrame() throws IOException {
250     while (!closed) {
251       readHeader();
252       if (!isControlFrame) {
253         break;
254       }
255       readControlFrame();
256     }
257   }
258 
259   /**
260    * A special source which knows how to read a message body across one or more frames. Control
261    * frames that occur between fragments will be processed. If the message payload is masked this
262    * will unmask as it's being processed.
263    */
264   private final class FramedMessageSource implements Source {
read(Buffer sink, long byteCount)265     @Override public long read(Buffer sink, long byteCount) throws IOException {
266       if (closed) throw new IOException("closed");
267       if (messageClosed) throw new IllegalStateException("closed");
268 
269       if (frameBytesRead == frameLength) {
270         if (isFinalFrame) return -1; // We are exhausted and have no continuations.
271 
272         readUntilNonControlFrame();
273         if (opcode != OPCODE_CONTINUATION) {
274           throw new ProtocolException("Expected continuation opcode. Got: " + toHexString(opcode));
275         }
276         if (isFinalFrame && frameLength == 0) {
277           return -1; // Fast-path for empty final frame.
278         }
279       }
280 
281       long toRead = Math.min(byteCount, frameLength - frameBytesRead);
282 
283       long read;
284       if (isMasked) {
285         toRead = Math.min(toRead, maskBuffer.length);
286         read = source.read(maskBuffer, 0, (int) toRead);
287         if (read == -1) throw new EOFException();
288         toggleMask(maskBuffer, read, maskKey, frameBytesRead);
289         sink.write(maskBuffer, 0, (int) read);
290       } else {
291         read = source.read(sink, toRead);
292         if (read == -1) throw new EOFException();
293       }
294 
295       frameBytesRead += read;
296       return read;
297     }
298 
timeout()299     @Override public Timeout timeout() {
300       return source.timeout();
301     }
302 
close()303     @Override public void close() throws IOException {
304       if (messageClosed) return;
305       messageClosed = true;
306       if (closed) return;
307 
308       // Exhaust the remainder of the message, if any.
309       source.skip(frameLength - frameBytesRead);
310       while (!isFinalFrame) {
311         readUntilNonControlFrame();
312         source.skip(frameLength);
313       }
314     }
315   }
316 }
317