• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_handshake_stream_create_helper.h"
6 
7 #include <string>
8 #include <vector>
9 
10 #include "net/base/completion_callback.h"
11 #include "net/base/net_errors.h"
12 #include "net/http/http_request_headers.h"
13 #include "net/http/http_request_info.h"
14 #include "net/http/http_response_headers.h"
15 #include "net/http/http_response_info.h"
16 #include "net/socket/client_socket_handle.h"
17 #include "net/socket/socket_test_util.h"
18 #include "net/websockets/websocket_basic_handshake_stream.h"
19 #include "net/websockets/websocket_stream.h"
20 #include "net/websockets/websocket_test_util.h"
21 #include "testing/gtest/include/gtest/gtest.h"
22 #include "url/gurl.h"
23 
24 namespace net {
25 namespace {
26 
27 // This class encapsulates the details of creating a mock ClientSocketHandle.
28 class MockClientSocketHandleFactory {
29  public:
MockClientSocketHandleFactory()30   MockClientSocketHandleFactory()
31       : histograms_("a"),
32         pool_(1, 1, &histograms_, socket_factory_maker_.factory()) {}
33 
34   // The created socket expects |expect_written| to be written to the socket,
35   // and will respond with |return_to_read|. The test will fail if the expected
36   // text is not written, or if all the bytes are not read.
CreateClientSocketHandle(const std::string & expect_written,const std::string & return_to_read)37   scoped_ptr<ClientSocketHandle> CreateClientSocketHandle(
38       const std::string& expect_written,
39       const std::string& return_to_read) {
40     socket_factory_maker_.SetExpectations(expect_written, return_to_read);
41     scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle);
42     socket_handle->Init(
43         "a",
44         scoped_refptr<MockTransportSocketParams>(),
45         MEDIUM,
46         CompletionCallback(),
47         &pool_,
48         BoundNetLog());
49     return socket_handle.Pass();
50   }
51 
52  private:
53   WebSocketDeterministicMockClientSocketFactoryMaker socket_factory_maker_;
54   ClientSocketPoolHistograms histograms_;
55   MockTransportClientSocketPool pool_;
56 
57   DISALLOW_COPY_AND_ASSIGN(MockClientSocketHandleFactory);
58 };
59 
60 class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
61  public:
~TestConnectDelegate()62   virtual ~TestConnectDelegate() {}
63 
OnSuccess(scoped_ptr<WebSocketStream> stream)64   virtual void OnSuccess(scoped_ptr<WebSocketStream> stream) OVERRIDE {}
OnFailure(const std::string & failure_message)65   virtual void OnFailure(const std::string& failure_message) OVERRIDE {}
OnStartOpeningHandshake(scoped_ptr<WebSocketHandshakeRequestInfo> request)66   virtual void OnStartOpeningHandshake(
67       scoped_ptr<WebSocketHandshakeRequestInfo> request) OVERRIDE {}
OnFinishOpeningHandshake(scoped_ptr<WebSocketHandshakeResponseInfo> response)68   virtual void OnFinishOpeningHandshake(
69       scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE {}
OnSSLCertificateError(scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,const SSLInfo & ssl_info,bool fatal)70   virtual void OnSSLCertificateError(
71       scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks>
72           ssl_error_callbacks,
73       const SSLInfo& ssl_info,
74       bool fatal) OVERRIDE {}
75 };
76 
77 class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test {
78  protected:
CreateAndInitializeStream(const std::string & socket_url,const std::string & socket_path,const std::vector<std::string> & sub_protocols,const std::string & origin,const std::string & extra_request_headers,const std::string & extra_response_headers)79   scoped_ptr<WebSocketStream> CreateAndInitializeStream(
80       const std::string& socket_url,
81       const std::string& socket_path,
82       const std::vector<std::string>& sub_protocols,
83       const std::string& origin,
84       const std::string& extra_request_headers,
85       const std::string& extra_response_headers) {
86     WebSocketHandshakeStreamCreateHelper create_helper(&connect_delegate_,
87                                                        sub_protocols);
88     create_helper.set_failure_message(&failure_message_);
89 
90     scoped_ptr<ClientSocketHandle> socket_handle =
91         socket_handle_factory_.CreateClientSocketHandle(
92             WebSocketStandardRequest(
93                 socket_path, origin, extra_request_headers),
94             WebSocketStandardResponse(extra_response_headers));
95 
96     scoped_ptr<WebSocketHandshakeStreamBase> handshake(
97         create_helper.CreateBasicStream(socket_handle.Pass(), false));
98 
99     // If in future the implementation type returned by CreateBasicStream()
100     // changes, this static_cast will be wrong. However, in that case the test
101     // will fail and AddressSanitizer should identify the issue.
102     static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
103         ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
104 
105     HttpRequestInfo request_info;
106     request_info.url = GURL(socket_url);
107     request_info.method = "GET";
108     request_info.load_flags = LOAD_DISABLE_CACHE | LOAD_DO_NOT_PROMPT_FOR_LOGIN;
109     int rv = handshake->InitializeStream(
110         &request_info, DEFAULT_PRIORITY, BoundNetLog(), CompletionCallback());
111     EXPECT_EQ(OK, rv);
112 
113     HttpRequestHeaders headers;
114     headers.SetHeader("Host", "localhost");
115     headers.SetHeader("Connection", "Upgrade");
116     headers.SetHeader("Pragma", "no-cache");
117     headers.SetHeader("Cache-Control", "no-cache");
118     headers.SetHeader("Upgrade", "websocket");
119     headers.SetHeader("Origin", origin);
120     headers.SetHeader("Sec-WebSocket-Version", "13");
121     headers.SetHeader("User-Agent", "");
122     headers.SetHeader("Accept-Encoding", "gzip, deflate");
123     headers.SetHeader("Accept-Language", "en-us,fr");
124 
125     HttpResponseInfo response;
126     TestCompletionCallback dummy;
127 
128     rv = handshake->SendRequest(headers, &response, dummy.callback());
129 
130     EXPECT_EQ(OK, rv);
131 
132     rv = handshake->ReadResponseHeaders(dummy.callback());
133     EXPECT_EQ(OK, rv);
134     EXPECT_EQ(101, response.headers->response_code());
135     EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
136     EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
137     return handshake->Upgrade();
138   }
139 
140   MockClientSocketHandleFactory socket_handle_factory_;
141   TestConnectDelegate connect_delegate_;
142   std::string failure_message_;
143 };
144 
145 // Confirm that the basic case works as expected.
TEST_F(WebSocketHandshakeStreamCreateHelperTest,BasicStream)146 TEST_F(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
147   scoped_ptr<WebSocketStream> stream =
148       CreateAndInitializeStream("ws://localhost/", "/",
149                                 std::vector<std::string>(), "http://localhost/",
150                                 "", "");
151   EXPECT_EQ("", stream->GetExtensions());
152   EXPECT_EQ("", stream->GetSubProtocol());
153 }
154 
155 // Verify that the sub-protocols are passed through.
TEST_F(WebSocketHandshakeStreamCreateHelperTest,SubProtocols)156 TEST_F(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
157   std::vector<std::string> sub_protocols;
158   sub_protocols.push_back("chat");
159   sub_protocols.push_back("superchat");
160   scoped_ptr<WebSocketStream> stream =
161       CreateAndInitializeStream("ws://localhost/",
162                                 "/",
163                                 sub_protocols,
164                                 "http://localhost/",
165                                 "Sec-WebSocket-Protocol: chat, superchat\r\n",
166                                 "Sec-WebSocket-Protocol: superchat\r\n");
167   EXPECT_EQ("superchat", stream->GetSubProtocol());
168 }
169 
170 // Verify that extension name is available. Bad extension names are tested in
171 // websocket_stream_test.cc.
TEST_F(WebSocketHandshakeStreamCreateHelperTest,Extensions)172 TEST_F(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
173   scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream(
174       "ws://localhost/",
175       "/",
176       std::vector<std::string>(),
177       "http://localhost/",
178       "",
179       "Sec-WebSocket-Extensions: permessage-deflate\r\n");
180   EXPECT_EQ("permessage-deflate", stream->GetExtensions());
181 }
182 
183 // Verify that extension parameters are available. Bad parameters are tested in
184 // websocket_stream_test.cc.
TEST_F(WebSocketHandshakeStreamCreateHelperTest,ExtensionParameters)185 TEST_F(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
186   scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream(
187       "ws://localhost/",
188       "/",
189       std::vector<std::string>(),
190       "http://localhost/",
191       "",
192       "Sec-WebSocket-Extensions: permessage-deflate;"
193       " client_max_window_bits=14; server_max_window_bits=14;"
194       " server_no_context_takeover; client_no_context_takeover\r\n");
195 
196   EXPECT_EQ(
197       "permessage-deflate;"
198       " client_max_window_bits=14; server_max_window_bits=14;"
199       " server_no_context_takeover; client_no_context_takeover",
200       stream->GetExtensions());
201 }
202 
203 }  // namespace
204 }  // namespace net
205