• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4 
5 #include "net/websockets/websocket_basic_handshake_stream.h"
6 
7 #include <set>
8 #include <string>
9 #include <utility>
10 #include <vector>
11 
12 #include "net/base/address_list.h"
13 #include "net/base/ip_address.h"
14 #include "net/base/ip_endpoint.h"
15 #include "net/base/net_errors.h"
16 #include "net/base/test_completion_callback.h"
17 #include "net/http/http_request_info.h"
18 #include "net/http/http_response_info.h"
19 #include "net/log/net_log_with_source.h"
20 #include "net/socket/client_socket_handle.h"
21 #include "net/socket/socket_test_util.h"
22 #include "net/socket/websocket_endpoint_lock_manager.h"
23 #include "net/traffic_annotation/network_traffic_annotation.h"
24 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
25 #include "net/websockets/websocket_test_util.h"
26 #include "testing/gmock/include/gmock/gmock.h"
27 #include "url/gurl.h"
28 #include "url/origin.h"
29 
30 namespace net {
31 namespace {
32 
TEST(WebSocketBasicHandshakeStreamTest,ConnectionClosedOnFailure)33 TEST(WebSocketBasicHandshakeStreamTest, ConnectionClosedOnFailure) {
34   std::string request = WebSocketStandardRequest(
35       "/", "www.example.org",
36       url::Origin::Create(GURL("http://origin.example.org")),
37       /*send_additional_request_headers=*/{}, /*extra_headers=*/{});
38   std::string response =
39       "HTTP/1.1 404 Not Found\r\n"
40       "Content-Length: 0\r\n"
41       "\r\n";
42   MockWrite writes[] = {MockWrite(SYNCHRONOUS, 0, request.c_str())};
43   MockRead reads[] = {MockRead(SYNCHRONOUS, 1, response.c_str()),
44                       MockRead(SYNCHRONOUS, ERR_IO_PENDING, 2)};
45   IPEndPoint end_point(IPAddress(127, 0, 0, 1), 80);
46   SequencedSocketData sequenced_socket_data(
47       MockConnect(SYNCHRONOUS, OK, end_point), reads, writes);
48   auto socket = std::make_unique<MockTCPClientSocket>(
49       AddressList(end_point), nullptr, &sequenced_socket_data);
50   const int connect_result = socket->Connect(CompletionOnceCallback());
51   EXPECT_EQ(connect_result, OK);
52   const MockTCPClientSocket* const socket_ptr = socket.get();
53   auto handle = std::make_unique<ClientSocketHandle>();
54   handle->SetSocket(std::move(socket));
55   DummyConnectDelegate delegate;
56   WebSocketEndpointLockManager endpoint_lock_manager;
57   TestWebSocketStreamRequestAPI stream_request_api;
58   std::vector<std::string> extensions = {
59       "permessage-deflate; client_max_window_bits"};
60   WebSocketBasicHandshakeStream basic_handshake_stream(
61       std::move(handle), &delegate, false, {}, extensions, &stream_request_api,
62       &endpoint_lock_manager);
63   basic_handshake_stream.SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
64   HttpRequestInfo request_info;
65   request_info.url = GURL("ws://www.example.com/");
66   request_info.method = "GET";
67   request_info.traffic_annotation =
68       MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
69   TestCompletionCallback callback1;
70   NetLogWithSource net_log;
71   basic_handshake_stream.RegisterRequest(&request_info);
72   const int result1 =
73       callback1.GetResult(basic_handshake_stream.InitializeStream(
74           true, LOWEST, net_log, callback1.callback()));
75   EXPECT_EQ(result1, OK);
76 
77   auto request_headers = WebSocketCommonTestHeaders();
78   HttpResponseInfo response_info;
79   TestCompletionCallback callback2;
80   const int result2 = callback2.GetResult(basic_handshake_stream.SendRequest(
81       request_headers, &response_info, callback2.callback()));
82   EXPECT_EQ(result2, OK);
83 
84   TestCompletionCallback callback3;
85   const int result3 = callback3.GetResult(
86       basic_handshake_stream.ReadResponseHeaders(callback2.callback()));
87   EXPECT_EQ(result3, ERR_INVALID_RESPONSE);
88 
89   EXPECT_FALSE(socket_ptr->IsConnected());
90 }
91 
TEST(WebSocketBasicHandshakeStreamTest,DnsAliasesCanBeAccessed)92 TEST(WebSocketBasicHandshakeStreamTest, DnsAliasesCanBeAccessed) {
93   std::string request = WebSocketStandardRequest(
94       "/", "www.example.org",
95       url::Origin::Create(GURL("http://origin.example.org")),
96       /*send_additional_request_headers=*/{}, /*extra_headers=*/{});
97   std::string response = WebSocketStandardResponse("");
98   MockWrite writes[] = {MockWrite(SYNCHRONOUS, 0, request.c_str())};
99   MockRead reads[] = {MockRead(SYNCHRONOUS, 1, response.c_str()),
100                       MockRead(SYNCHRONOUS, ERR_IO_PENDING, 2)};
101 
102   IPEndPoint end_point(IPAddress(127, 0, 0, 1), 80);
103   SequencedSocketData sequenced_socket_data(
104       MockConnect(SYNCHRONOUS, OK, end_point), reads, writes);
105   auto socket = std::make_unique<MockTCPClientSocket>(
106       AddressList(end_point), nullptr, &sequenced_socket_data);
107   const int connect_result = socket->Connect(CompletionOnceCallback());
108   EXPECT_EQ(connect_result, OK);
109 
110   std::set<std::string> aliases({"alias1", "alias2", "www.example.org"});
111   socket->SetDnsAliases(aliases);
112   EXPECT_THAT(
113       socket->GetDnsAliases(),
114       testing::UnorderedElementsAre("alias1", "alias2", "www.example.org"));
115 
116   const MockTCPClientSocket* const socket_ptr = socket.get();
117   auto handle = std::make_unique<ClientSocketHandle>();
118   handle->SetSocket(std::move(socket));
119   EXPECT_THAT(
120       handle->socket()->GetDnsAliases(),
121       testing::UnorderedElementsAre("alias1", "alias2", "www.example.org"));
122 
123   DummyConnectDelegate delegate;
124   WebSocketEndpointLockManager endpoint_lock_manager;
125   TestWebSocketStreamRequestAPI stream_request_api;
126   std::vector<std::string> extensions = {
127       "permessage-deflate; client_max_window_bits"};
128   WebSocketBasicHandshakeStream basic_handshake_stream(
129       std::move(handle), &delegate, false, {}, extensions, &stream_request_api,
130       &endpoint_lock_manager);
131   basic_handshake_stream.SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
132   HttpRequestInfo request_info;
133   request_info.url = GURL("ws://www.example.com/");
134   request_info.method = "GET";
135   request_info.traffic_annotation =
136       MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
137   TestCompletionCallback callback1;
138   NetLogWithSource net_log;
139   basic_handshake_stream.RegisterRequest(&request_info);
140   const int result1 =
141       callback1.GetResult(basic_handshake_stream.InitializeStream(
142           true, LOWEST, net_log, callback1.callback()));
143   EXPECT_EQ(result1, OK);
144 
145   auto request_headers = WebSocketCommonTestHeaders();
146   HttpResponseInfo response_info;
147   TestCompletionCallback callback2;
148   const int result2 = callback2.GetResult(basic_handshake_stream.SendRequest(
149       request_headers, &response_info, callback2.callback()));
150   EXPECT_EQ(result2, OK);
151 
152   TestCompletionCallback callback3;
153   const int result3 = callback3.GetResult(
154       basic_handshake_stream.ReadResponseHeaders(callback2.callback()));
155   EXPECT_EQ(result3, OK);
156 
157   EXPECT_TRUE(socket_ptr->IsConnected());
158 
159   EXPECT_THAT(basic_handshake_stream.GetDnsAliases(),
160               testing::ElementsAre("alias1", "alias2", "www.example.org"));
161 }
162 
163 }  // namespace
164 }  // namespace net
165