1 // Copyright 2018 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_basic_handshake_stream.h"
6
7 #include <set>
8 #include <string>
9 #include <utility>
10 #include <vector>
11
12 #include "net/base/address_list.h"
13 #include "net/base/ip_address.h"
14 #include "net/base/ip_endpoint.h"
15 #include "net/base/net_errors.h"
16 #include "net/base/test_completion_callback.h"
17 #include "net/http/http_request_info.h"
18 #include "net/http/http_response_info.h"
19 #include "net/log/net_log_with_source.h"
20 #include "net/socket/client_socket_handle.h"
21 #include "net/socket/socket_test_util.h"
22 #include "net/socket/websocket_endpoint_lock_manager.h"
23 #include "net/traffic_annotation/network_traffic_annotation.h"
24 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
25 #include "net/websockets/websocket_test_util.h"
26 #include "testing/gmock/include/gmock/gmock.h"
27 #include "url/gurl.h"
28 #include "url/origin.h"
29
30 namespace net {
31 namespace {
32
// Checks that when the server answers the WebSocket handshake with a
// "404 Not Found" response, ReadResponseHeaders() reports
// ERR_INVALID_RESPONSE and the underlying mock socket is no longer connected
// afterwards.
TEST(WebSocketBasicHandshakeStreamTest, ConnectionClosedOnFailure) {
  // Bytes the stream is expected to write to the socket.
  std::string request = WebSocketStandardRequest(
      "/", "www.example.org",
      url::Origin::Create(GURL("http://origin.example.org")),
      /*send_additional_request_headers=*/{}, /*extra_headers=*/{});
  // A non-upgrade response: the handshake must fail on it.
  std::string response =
      "HTTP/1.1 404 Not Found\r\n"
      "Content-Length: 0\r\n"
      "\r\n";
  MockWrite writes[] = {MockWrite(SYNCHRONOUS, 0, request.c_str())};
  // The final ERR_IO_PENDING read presumably keeps the mock data from
  // reporting end-of-stream after the response has been consumed.
  MockRead reads[] = {MockRead(SYNCHRONOUS, 1, response.c_str()),
                      MockRead(SYNCHRONOUS, ERR_IO_PENDING, 2)};
  IPEndPoint end_point(IPAddress(127, 0, 0, 1), 80);
  SequencedSocketData sequenced_socket_data(
      MockConnect(SYNCHRONOUS, OK, end_point), reads, writes);
  auto socket = std::make_unique<MockTCPClientSocket>(
      AddressList(end_point), nullptr, &sequenced_socket_data);
  const int connect_result = socket->Connect(CompletionOnceCallback());
  EXPECT_EQ(connect_result, OK);
  // Keep a non-owning pointer so the connection state can still be inspected
  // after ownership of the socket has moved into the handle/stream.
  const MockTCPClientSocket* const socket_ptr = socket.get();
  auto handle = std::make_unique<ClientSocketHandle>();
  handle->SetSocket(std::move(socket));
  DummyConnectDelegate delegate;
  WebSocketEndpointLockManager endpoint_lock_manager;
  TestWebSocketStreamRequestAPI stream_request_api;
  std::vector<std::string> extensions = {
      "permessage-deflate; client_max_window_bits"};
  WebSocketBasicHandshakeStream basic_handshake_stream(
      std::move(handle), &delegate, false, {}, extensions, &stream_request_api,
      &endpoint_lock_manager);
  // Fixed key (base64 of "the sample nonce") so the written request is
  // deterministic; presumably it matches the key in WebSocketStandardRequest().
  basic_handshake_stream.SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
  HttpRequestInfo request_info;
  request_info.url = GURL("ws://www.example.com/");
  request_info.method = "GET";
  request_info.traffic_annotation =
      MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
  TestCompletionCallback callback1;
  NetLogWithSource net_log;
  basic_handshake_stream.RegisterRequest(&request_info);
  const int result1 =
      callback1.GetResult(basic_handshake_stream.InitializeStream(
          true, LOWEST, net_log, callback1.callback()));
  EXPECT_EQ(result1, OK);

  auto request_headers = WebSocketCommonTestHeaders();
  HttpResponseInfo response_info;
  TestCompletionCallback callback2;
  const int result2 = callback2.GetResult(basic_handshake_stream.SendRequest(
      request_headers, &response_info, callback2.callback()));
  EXPECT_EQ(result2, OK);

  // The callback passed here must be the one GetResult() waits on. Passing
  // callback2 only worked because the mock read completes synchronously; an
  // asynchronous completion would have left callback3 waiting forever.
  TestCompletionCallback callback3;
  const int result3 = callback3.GetResult(
      basic_handshake_stream.ReadResponseHeaders(callback3.callback()));
  EXPECT_EQ(result3, ERR_INVALID_RESPONSE);

  EXPECT_FALSE(socket_ptr->IsConnected());
}
91
// Checks that DNS aliases set on the underlying socket remain readable at
// every stage: on the socket itself, through the ClientSocketHandle, and
// finally through WebSocketBasicHandshakeStream::GetDnsAliases() after a
// successful handshake.
TEST(WebSocketBasicHandshakeStreamTest, DnsAliasesCanBeAccessed) {
  // Bytes the stream is expected to write, and the response fed back to it.
  std::string request = WebSocketStandardRequest(
      "/", "www.example.org",
      url::Origin::Create(GURL("http://origin.example.org")),
      /*send_additional_request_headers=*/{}, /*extra_headers=*/{});
  std::string response = WebSocketStandardResponse("");
  MockWrite writes[] = {MockWrite(SYNCHRONOUS, 0, request.c_str())};
  // The final ERR_IO_PENDING read presumably keeps the mock data from
  // reporting end-of-stream after the response has been consumed.
  MockRead reads[] = {MockRead(SYNCHRONOUS, 1, response.c_str()),
                      MockRead(SYNCHRONOUS, ERR_IO_PENDING, 2)};

  IPEndPoint end_point(IPAddress(127, 0, 0, 1), 80);
  SequencedSocketData sequenced_socket_data(
      MockConnect(SYNCHRONOUS, OK, end_point), reads, writes);
  auto socket = std::make_unique<MockTCPClientSocket>(
      AddressList(end_point), nullptr, &sequenced_socket_data);
  const int connect_result = socket->Connect(CompletionOnceCallback());
  EXPECT_EQ(connect_result, OK);

  // Attach the aliases and confirm the socket reports them back.
  std::set<std::string> aliases({"alias1", "alias2", "www.example.org"});
  socket->SetDnsAliases(aliases);
  EXPECT_THAT(
      socket->GetDnsAliases(),
      testing::UnorderedElementsAre("alias1", "alias2", "www.example.org"));

  // Keep a non-owning pointer so the connection state can still be inspected
  // after ownership of the socket has moved into the handle/stream.
  const MockTCPClientSocket* const socket_ptr = socket.get();
  auto handle = std::make_unique<ClientSocketHandle>();
  handle->SetSocket(std::move(socket));
  // The aliases must still be visible through the handle.
  EXPECT_THAT(
      handle->socket()->GetDnsAliases(),
      testing::UnorderedElementsAre("alias1", "alias2", "www.example.org"));

  DummyConnectDelegate delegate;
  WebSocketEndpointLockManager endpoint_lock_manager;
  TestWebSocketStreamRequestAPI stream_request_api;
  std::vector<std::string> extensions = {
      "permessage-deflate; client_max_window_bits"};
  WebSocketBasicHandshakeStream basic_handshake_stream(
      std::move(handle), &delegate, false, {}, extensions, &stream_request_api,
      &endpoint_lock_manager);
  // Fixed key (base64 of "the sample nonce") so the written request is
  // deterministic; presumably it matches the key in WebSocketStandardRequest().
  basic_handshake_stream.SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
  HttpRequestInfo request_info;
  request_info.url = GURL("ws://www.example.com/");
  request_info.method = "GET";
  request_info.traffic_annotation =
      MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
  TestCompletionCallback callback1;
  NetLogWithSource net_log;
  basic_handshake_stream.RegisterRequest(&request_info);
  const int result1 =
      callback1.GetResult(basic_handshake_stream.InitializeStream(
          true, LOWEST, net_log, callback1.callback()));
  EXPECT_EQ(result1, OK);

  auto request_headers = WebSocketCommonTestHeaders();
  HttpResponseInfo response_info;
  TestCompletionCallback callback2;
  const int result2 = callback2.GetResult(basic_handshake_stream.SendRequest(
      request_headers, &response_info, callback2.callback()));
  EXPECT_EQ(result2, OK);

  // The callback passed here must be the one GetResult() waits on. Passing
  // callback2 only worked because the mock read completes synchronously; an
  // asynchronous completion would have left callback3 waiting forever.
  TestCompletionCallback callback3;
  const int result3 = callback3.GetResult(
      basic_handshake_stream.ReadResponseHeaders(callback3.callback()));
  EXPECT_EQ(result3, OK);

  // A successful handshake leaves the connection open.
  EXPECT_TRUE(socket_ptr->IsConnected());

  // The stream exposes the aliases in the std::set's sorted order, hence the
  // ordered matcher here.
  EXPECT_THAT(basic_handshake_stream.GetDnsAliases(),
              testing::ElementsAre("alias1", "alias2", "www.example.org"));
}
162
163 } // namespace
164 } // namespace net
165