1 package fi.iki.elonen; 2 3 import java.util.HashMap; 4 import java.util.Map; 5 6 import org.junit.Before; 7 import org.junit.Test; 8 9 import fi.iki.elonen.NanoHTTPD.IHTTPSession; 10 import fi.iki.elonen.NanoHTTPD.Response; 11 import org.junit.runner.RunWith; 12 import org.mockito.ArgumentCaptor; 13 import org.mockito.Captor; 14 import org.mockito.Mock; 15 import org.mockito.runners.MockitoJUnitRunner; 16 17 import static junit.framework.Assert.*; 18 import static org.mockito.Matchers.any; 19 import static org.mockito.Mockito.atLeast; 20 import static org.mockito.Mockito.verify; 21 import static org.mockito.Mockito.when; 22 23 @RunWith(MockitoJUnitRunner.class) 24 public class WebSocketResponseHandlerTest { 25 26 @Mock 27 private IHTTPSession session; 28 @Mock 29 private WebSocket webSocket; 30 @Mock 31 private IWebSocketFactory webSocketFactory; 32 @Mock 33 private Response response; 34 @Captor 35 private ArgumentCaptor<String> headerNameCaptor; 36 @Captor 37 private ArgumentCaptor<String> headerCaptor; 38 39 private Map<String, String> headers; 40 41 private WebSocketResponseHandler responseHandler; 42 43 @Before setUp()44 public void setUp() { 45 headers = new HashMap<String, String>(); 46 headers.put("upgrade", "websocket"); 47 headers.put("connection", "Upgrade"); 48 headers.put("sec-websocket-key", "x3JJHMbDL1EzLkh9GBhXDw=="); 49 headers.put("sec-websocket-protocol", "chat, superchat"); 50 headers.put("sec-websocket-version", "13"); 51 52 when(session.getHeaders()).thenReturn(headers); 53 when(webSocketFactory.openWebSocket(any(IHTTPSession.class))).thenReturn(webSocket); 54 when(webSocket.getHandshakeResponse()).thenReturn(response); 55 56 responseHandler = new WebSocketResponseHandler(webSocketFactory); 57 } 58 59 @Test testHandshakeReturnsResponseWithExpectedHeaders()60 public void testHandshakeReturnsResponseWithExpectedHeaders() { 61 Response handshakeResponse = responseHandler.serve(session); 62 63 verify(webSocket).getHandshakeResponse(); 64 assertNotNull(handshakeResponse); 65 assertSame(response, handshakeResponse); 66 67 verify(response, atLeast(1)).addHeader(headerNameCaptor.capture(), headerCaptor.capture()); 68 assertHeader(0, "sec-websocket-accept", "HSmrc0sMlYUkAGmm5OPpG2HaGWk="); 69 assertHeader(1, "sec-websocket-protocol", "chat"); 70 } 71 72 @Test testWrongWebsocketVersionReturnsErrorResponse()73 public void testWrongWebsocketVersionReturnsErrorResponse() { 74 headers.put("sec-websocket-version", "12"); 75 76 Response handshakeResponse = responseHandler.serve(session); 77 78 assertNotNull(handshakeResponse); 79 assertEquals(Response.Status.BAD_REQUEST, handshakeResponse.getStatus()); 80 } 81 82 @Test testMissingKeyReturnsErrorResponse()83 public void testMissingKeyReturnsErrorResponse() { 84 headers.remove("sec-websocket-key"); 85 86 Response handshakeResponse = responseHandler.serve(session); 87 88 assertNotNull(handshakeResponse); 89 assertEquals(Response.Status.BAD_REQUEST, handshakeResponse.getStatus()); 90 } 91 92 @Test testWrongUpgradeHeaderReturnsNullResponse()93 public void testWrongUpgradeHeaderReturnsNullResponse() { 94 headers.put("upgrade", "not a websocket"); 95 Response handshakeResponse = responseHandler.serve(session); 96 assertNull(handshakeResponse); 97 } 98 99 @Test testWrongConnectionHeaderReturnsNullResponse()100 public void testWrongConnectionHeaderReturnsNullResponse() { 101 headers.put("connection", "Junk"); 102 Response handshakeResponse = responseHandler.serve(session); 103 assertNull(handshakeResponse); 104 } 105 106 @Test testConnectionHeaderHandlesKeepAlive_FixingFirefoxConnectIssue()107 public void testConnectionHeaderHandlesKeepAlive_FixingFirefoxConnectIssue() { 108 headers.put("connection", "keep-alive, Upgrade"); 109 Response handshakeResponse = responseHandler.serve(session); 110 111 verify(webSocket).getHandshakeResponse(); 112 assertNotNull(handshakeResponse); 113 assertSame(response, handshakeResponse); 114 } 115 assertHeader(int index, String name, String value)116 private void assertHeader(int index, String name, String value) { 117 assertEquals(name, headerNameCaptor.getAllValues().get(index)); 118 assertEquals(value, headerCaptor.getAllValues().get(index)); 119 } 120 } 121