• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2014 Square, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.squareup.okhttp.ws;
17 
18 import com.squareup.okhttp.OkHttpClient;
19 import com.squareup.okhttp.Request;
20 import com.squareup.okhttp.Response;
21 import com.squareup.okhttp.internal.SslContextBuilder;
22 import com.squareup.okhttp.mockwebserver.MockResponse;
23 import com.squareup.okhttp.mockwebserver.MockWebServer;
24 import com.squareup.okhttp.testing.RecordingHostnameVerifier;
25 import java.io.IOException;
26 import java.net.ProtocolException;
27 import java.util.Random;
28 import java.util.concurrent.CountDownLatch;
29 import java.util.concurrent.TimeUnit;
30 import java.util.concurrent.atomic.AtomicReference;
31 import javax.net.ssl.SSLContext;
32 import okio.Buffer;
33 import okio.BufferedSink;
34 import okio.BufferedSource;
35 import org.junit.After;
36 import org.junit.Rule;
37 import org.junit.Test;
38 
39 import static com.squareup.okhttp.ws.WebSocket.PayloadType.TEXT;
40 
41 public final class WebSocketCallTest {
42   @Rule public final MockWebServer server = new MockWebServer();
43 
44   private final SSLContext sslContext = SslContextBuilder.localhost();
45   private final WebSocketRecorder listener = new WebSocketRecorder();
46   private final OkHttpClient client = new OkHttpClient();
47   private final Random random = new Random(0);
48 
tearDown()49   @After public void tearDown() {
50     listener.assertExhausted();
51   }
52 
clientPingPong()53   @Test public void clientPingPong() throws IOException {
54     WebSocketListener serverListener = new EmptyWebSocketListener();
55     server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));
56 
57     WebSocket webSocket = awaitWebSocket();
58     webSocket.sendPing(new Buffer().writeUtf8("Hello, WebSockets!"));
59     listener.assertPong(new Buffer().writeUtf8("Hello, WebSockets!"));
60   }
61 
clientMessage()62   @Test public void clientMessage() throws IOException {
63     WebSocketRecorder serverListener = new WebSocketRecorder();
64     server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));
65 
66     WebSocket webSocket = awaitWebSocket();
67     webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!"));
68     serverListener.assertTextMessage("Hello, WebSockets!");
69   }
70 
serverMessage()71   @Test public void serverMessage() throws IOException {
72     WebSocketListener serverListener = new EmptyWebSocketListener() {
73       @Override public void onOpen(final WebSocket webSocket, Response response) {
74         new Thread() {
75           @Override public void run() {
76             try {
77               webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!"));
78             } catch (IOException e) {
79               throw new AssertionError(e);
80             }
81           }
82         }.start();
83       }
84     };
85     server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));
86 
87     awaitWebSocket();
88     listener.assertTextMessage("Hello, WebSockets!");
89   }
90 
clientStreamingMessage()91   @Test public void clientStreamingMessage() throws IOException {
92     WebSocketRecorder serverListener = new WebSocketRecorder();
93     server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));
94 
95     WebSocket webSocket = awaitWebSocket();
96     BufferedSink sink = webSocket.newMessageSink(TEXT);
97     sink.writeUtf8("Hello, ").flush();
98     sink.writeUtf8("WebSockets!").flush();
99     sink.close();
100 
101     serverListener.assertTextMessage("Hello, WebSockets!");
102   }
103 
serverStreamingMessage()104   @Test public void serverStreamingMessage() throws IOException {
105     WebSocketListener serverListener = new EmptyWebSocketListener() {
106       @Override public void onOpen(final WebSocket webSocket, Response response) {
107         new Thread() {
108           @Override public void run() {
109             try {
110               BufferedSink sink = webSocket.newMessageSink(TEXT);
111               sink.writeUtf8("Hello, ").flush();
112               sink.writeUtf8("WebSockets!").flush();
113               sink.close();
114             } catch (IOException e) {
115               throw new AssertionError(e);
116             }
117           }
118         }.start();
119       }
120     };
121     server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));
122 
123     awaitWebSocket();
124     listener.assertTextMessage("Hello, WebSockets!");
125   }
126 
okButNotOk()127   @Test public void okButNotOk() {
128     server.enqueue(new MockResponse());
129     awaitWebSocket();
130     listener.assertFailure(ProtocolException.class, "Expected HTTP 101 response but was '200 OK'");
131   }
132 
notFound()133   @Test public void notFound() {
134     server.enqueue(new MockResponse().setStatus("HTTP/1.1 404 Not Found"));
135     awaitWebSocket();
136     listener.assertFailure(ProtocolException.class,
137         "Expected HTTP 101 response but was '404 Not Found'");
138   }
139 
missingConnectionHeader()140   @Test public void missingConnectionHeader() {
141     server.enqueue(new MockResponse()
142         .setResponseCode(101)
143         .setHeader("Upgrade", "websocket")
144         .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
145     awaitWebSocket();
146     listener.assertFailure(ProtocolException.class,
147         "Expected 'Connection' header value 'Upgrade' but was 'null'");
148   }
149 
wrongConnectionHeader()150   @Test public void wrongConnectionHeader() {
151     server.enqueue(new MockResponse()
152         .setResponseCode(101)
153         .setHeader("Upgrade", "websocket")
154         .setHeader("Connection", "Downgrade")
155         .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
156     awaitWebSocket();
157     listener.assertFailure(ProtocolException.class,
158         "Expected 'Connection' header value 'Upgrade' but was 'Downgrade'");
159   }
160 
missingUpgradeHeader()161   @Test public void missingUpgradeHeader() {
162     server.enqueue(new MockResponse()
163         .setResponseCode(101)
164         .setHeader("Connection", "Upgrade")
165         .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
166     awaitWebSocket();
167     listener.assertFailure(ProtocolException.class,
168         "Expected 'Upgrade' header value 'websocket' but was 'null'");
169   }
170 
wrongUpgradeHeader()171   @Test public void wrongUpgradeHeader() {
172     server.enqueue(new MockResponse()
173         .setResponseCode(101)
174         .setHeader("Connection", "Upgrade")
175         .setHeader("Upgrade", "Pepsi")
176         .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
177     awaitWebSocket();
178     listener.assertFailure(ProtocolException.class,
179         "Expected 'Upgrade' header value 'websocket' but was 'Pepsi'");
180   }
181 
missingMagicHeader()182   @Test public void missingMagicHeader() {
183     server.enqueue(new MockResponse()
184         .setResponseCode(101)
185         .setHeader("Connection", "Upgrade")
186         .setHeader("Upgrade", "websocket"));
187     awaitWebSocket();
188     listener.assertFailure(ProtocolException.class,
189         "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'null'");
190   }
191 
wrongMagicHeader()192   @Test public void wrongMagicHeader() {
193     server.enqueue(new MockResponse()
194         .setResponseCode(101)
195         .setHeader("Connection", "Upgrade")
196         .setHeader("Upgrade", "websocket")
197         .setHeader("Sec-WebSocket-Accept", "magic"));
198     awaitWebSocket();
199     listener.assertFailure(ProtocolException.class,
200         "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'");
201   }
202 
wsScheme()203   @Test public void wsScheme() throws IOException {
204     websocketScheme("ws");
205   }
206 
wsUppercaseScheme()207   @Test public void wsUppercaseScheme() throws IOException {
208     websocketScheme("WS");
209   }
210 
wssScheme()211   @Test public void wssScheme() throws IOException {
212     server.useHttps(sslContext.getSocketFactory(), false);
213     client.setSslSocketFactory(sslContext.getSocketFactory());
214     client.setHostnameVerifier(new RecordingHostnameVerifier());
215 
216     websocketScheme("wss");
217   }
218 
httpsScheme()219   @Test public void httpsScheme() throws IOException {
220     server.useHttps(sslContext.getSocketFactory(), false);
221     client.setSslSocketFactory(sslContext.getSocketFactory());
222     client.setHostnameVerifier(new RecordingHostnameVerifier());
223 
224     websocketScheme("https");
225   }
226 
websocketScheme(String scheme)227   private void websocketScheme(String scheme) throws IOException {
228     WebSocketRecorder serverListener = new WebSocketRecorder();
229     server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));
230 
231     Request request1 = new Request.Builder()
232         .url(scheme + "://" + server.getHostName() + ":" + server.getPort() + "/")
233         .build();
234 
235     WebSocket webSocket = awaitWebSocket(request1);
236     webSocket.sendMessage(TEXT, new Buffer().writeUtf8("abc"));
237     serverListener.assertTextMessage("abc");
238   }
239 
awaitWebSocket()240   private WebSocket awaitWebSocket() {
241     return awaitWebSocket(new Request.Builder().get().url(server.url("/")).build());
242   }
243 
awaitWebSocket(Request request)244   private WebSocket awaitWebSocket(Request request) {
245     WebSocketCall call = new WebSocketCall(client, request, random);
246 
247     final AtomicReference<Response> responseRef = new AtomicReference<>();
248     final AtomicReference<WebSocket> webSocketRef = new AtomicReference<>();
249     final AtomicReference<IOException> failureRef = new AtomicReference<>();
250     final CountDownLatch latch = new CountDownLatch(1);
251     call.enqueue(new WebSocketListener() {
252       @Override public void onOpen(WebSocket webSocket, Response response) {
253         webSocketRef.set(webSocket);
254         responseRef.set(response);
255         latch.countDown();
256       }
257 
258       @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type)
259           throws IOException {
260         listener.onMessage(payload, type);
261       }
262 
263       @Override public void onPong(Buffer payload) {
264         listener.onPong(payload);
265       }
266 
267       @Override public void onClose(int code, String reason) {
268         listener.onClose(code, reason);
269       }
270 
271       @Override public void onFailure(IOException e, Response response) {
272         listener.onFailure(e, null);
273         failureRef.set(e);
274         latch.countDown();
275       }
276     });
277 
278     try {
279       if (!latch.await(10, TimeUnit.SECONDS)) {
280         throw new AssertionError("Timed out.");
281       }
282     } catch (InterruptedException e) {
283       throw new AssertionError(e);
284     }
285 
286     return webSocketRef.get();
287   }
288 
289   private static class EmptyWebSocketListener implements WebSocketListener {
onOpen(WebSocket webSocket, Response response)290     @Override public void onOpen(WebSocket webSocket, Response response) {
291     }
292 
onMessage(BufferedSource payload, WebSocket.PayloadType type)293     @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type)
294         throws IOException {
295     }
296 
onPong(Buffer payload)297     @Override public void onPong(Buffer payload) {
298     }
299 
onClose(int code, String reason)300     @Override public void onClose(int code, String reason) {
301     }
302 
onFailure(IOException e, Response response)303     @Override public void onFailure(IOException e, Response response) {
304     }
305   }
306 }
307