1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/services/network/web_socket_impl.h"
6
7 #include "base/logging.h"
8 #include "base/message_loop/message_loop.h"
9 #include "mojo/common/handle_watcher.h"
10 #include "mojo/services/network/network_context.h"
11 #include "mojo/services/public/cpp/network/web_socket_read_queue.h"
12 #include "mojo/services/public/cpp/network/web_socket_write_queue.h"
13 #include "net/websockets/websocket_channel.h"
14 #include "net/websockets/websocket_errors.h"
15 #include "net/websockets/websocket_event_interface.h"
16 #include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode
17 #include "net/websockets/websocket_handshake_request_info.h"
18 #include "net/websockets/websocket_handshake_response_info.h"
19 #include "url/origin.h"
20
21 namespace mojo {
22
23 template <>
24 struct TypeConverter<net::WebSocketFrameHeader::OpCode,
25 WebSocket::MessageType> {
Convertmojo::TypeConverter26 static net::WebSocketFrameHeader::OpCode Convert(
27 WebSocket::MessageType type) {
28 DCHECK(type == WebSocket::MESSAGE_TYPE_CONTINUATION ||
29 type == WebSocket::MESSAGE_TYPE_TEXT ||
30 type == WebSocket::MESSAGE_TYPE_BINARY);
31 typedef net::WebSocketFrameHeader::OpCode OpCode;
32 // These compile asserts verify that the same underlying values are used for
33 // both types, so we can simply cast between them.
34 COMPILE_ASSERT(static_cast<OpCode>(WebSocket::MESSAGE_TYPE_CONTINUATION) ==
35 net::WebSocketFrameHeader::kOpCodeContinuation,
36 enum_values_must_match_for_opcode_continuation);
37 COMPILE_ASSERT(static_cast<OpCode>(WebSocket::MESSAGE_TYPE_TEXT) ==
38 net::WebSocketFrameHeader::kOpCodeText,
39 enum_values_must_match_for_opcode_text);
40 COMPILE_ASSERT(static_cast<OpCode>(WebSocket::MESSAGE_TYPE_BINARY) ==
41 net::WebSocketFrameHeader::kOpCodeBinary,
42 enum_values_must_match_for_opcode_binary);
43 return static_cast<OpCode>(type);
44 }
45 };
46
47 template <>
48 struct TypeConverter<WebSocket::MessageType,
49 net::WebSocketFrameHeader::OpCode> {
Convertmojo::TypeConverter50 static WebSocket::MessageType Convert(
51 net::WebSocketFrameHeader::OpCode type) {
52 DCHECK(type == net::WebSocketFrameHeader::kOpCodeContinuation ||
53 type == net::WebSocketFrameHeader::kOpCodeText ||
54 type == net::WebSocketFrameHeader::kOpCodeBinary);
55 return static_cast<WebSocket::MessageType>(type);
56 }
57 };
58
59 namespace {
60
61 typedef net::WebSocketEventInterface::ChannelState ChannelState;
62
63 struct WebSocketEventHandler : public net::WebSocketEventInterface {
64 public:
WebSocketEventHandlermojo::__anonb40c803d0111::WebSocketEventHandler65 WebSocketEventHandler(WebSocketClientPtr client)
66 : client_(client.Pass()) {
67 }
~WebSocketEventHandlermojo::__anonb40c803d0111::WebSocketEventHandler68 virtual ~WebSocketEventHandler() {}
69
70 private:
71 // net::WebSocketEventInterface methods:
72 virtual ChannelState OnAddChannelResponse(
73 bool fail,
74 const std::string& selected_subprotocol,
75 const std::string& extensions) OVERRIDE;
76 virtual ChannelState OnDataFrame(bool fin,
77 WebSocketMessageType type,
78 const std::vector<char>& data) OVERRIDE;
79 virtual ChannelState OnClosingHandshake() OVERRIDE;
80 virtual ChannelState OnFlowControl(int64 quota) OVERRIDE;
81 virtual ChannelState OnDropChannel(bool was_clean,
82 uint16 code,
83 const std::string& reason) OVERRIDE;
84 virtual ChannelState OnFailChannel(const std::string& message) OVERRIDE;
85 virtual ChannelState OnStartOpeningHandshake(
86 scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE;
87 virtual ChannelState OnFinishOpeningHandshake(
88 scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE;
89 virtual ChannelState OnSSLCertificateError(
90 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
91 const GURL& url,
92 const net::SSLInfo& ssl_info,
93 bool fatal) OVERRIDE;
94
95 // Called once we've written to |receive_stream_|.
96 void DidWriteToReceiveStream(bool fin,
97 net::WebSocketFrameHeader::OpCode type,
98 uint32_t num_bytes,
99 const char* buffer);
100 WebSocketClientPtr client_;
101 ScopedDataPipeProducerHandle receive_stream_;
102 scoped_ptr<WebSocketWriteQueue> write_queue_;
103
104 DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler);
105 };
106
OnAddChannelResponse(bool fail,const std::string & selected_protocol,const std::string & extensions)107 ChannelState WebSocketEventHandler::OnAddChannelResponse(
108 bool fail,
109 const std::string& selected_protocol,
110 const std::string& extensions) {
111 DataPipe data_pipe;
112 receive_stream_ = data_pipe.producer_handle.Pass();
113 write_queue_.reset(new WebSocketWriteQueue(receive_stream_.get()));
114 client_->DidConnect(
115 fail, selected_protocol, extensions, data_pipe.consumer_handle.Pass());
116 if (fail)
117 return WebSocketEventInterface::CHANNEL_DELETED;
118 return WebSocketEventInterface::CHANNEL_ALIVE;
119 }
120
OnDataFrame(bool fin,net::WebSocketFrameHeader::OpCode type,const std::vector<char> & data)121 ChannelState WebSocketEventHandler::OnDataFrame(
122 bool fin,
123 net::WebSocketFrameHeader::OpCode type,
124 const std::vector<char>& data) {
125 uint32_t size = static_cast<uint32_t>(data.size());
126 write_queue_->Write(
127 &data[0], size,
128 base::Bind(&WebSocketEventHandler::DidWriteToReceiveStream,
129 base::Unretained(this),
130 fin, type, size));
131 return WebSocketEventInterface::CHANNEL_ALIVE;
132 }
133
OnClosingHandshake()134 ChannelState WebSocketEventHandler::OnClosingHandshake() {
135 return WebSocketEventInterface::CHANNEL_ALIVE;
136 }
137
OnFlowControl(int64 quota)138 ChannelState WebSocketEventHandler::OnFlowControl(int64 quota) {
139 client_->DidReceiveFlowControl(quota);
140 return WebSocketEventInterface::CHANNEL_ALIVE;
141 }
142
OnDropChannel(bool was_clean,uint16 code,const std::string & reason)143 ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean,
144 uint16 code,
145 const std::string& reason) {
146 client_->DidClose(was_clean, code, reason);
147 return WebSocketEventInterface::CHANNEL_DELETED;
148 }
149
OnFailChannel(const std::string & message)150 ChannelState WebSocketEventHandler::OnFailChannel(const std::string& message) {
151 client_->DidFail(message);
152 return WebSocketEventInterface::CHANNEL_DELETED;
153 }
154
OnStartOpeningHandshake(scoped_ptr<net::WebSocketHandshakeRequestInfo> request)155 ChannelState WebSocketEventHandler::OnStartOpeningHandshake(
156 scoped_ptr<net::WebSocketHandshakeRequestInfo> request) {
157 return WebSocketEventInterface::CHANNEL_ALIVE;
158 }
159
OnFinishOpeningHandshake(scoped_ptr<net::WebSocketHandshakeResponseInfo> response)160 ChannelState WebSocketEventHandler::OnFinishOpeningHandshake(
161 scoped_ptr<net::WebSocketHandshakeResponseInfo> response) {
162 return WebSocketEventInterface::CHANNEL_ALIVE;
163 }
164
OnSSLCertificateError(scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,const GURL & url,const net::SSLInfo & ssl_info,bool fatal)165 ChannelState WebSocketEventHandler::OnSSLCertificateError(
166 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
167 const GURL& url,
168 const net::SSLInfo& ssl_info,
169 bool fatal) {
170 client_->DidFail("SSL Error");
171 return WebSocketEventInterface::CHANNEL_DELETED;
172 }
173
DidWriteToReceiveStream(bool fin,net::WebSocketFrameHeader::OpCode type,uint32_t num_bytes,const char * buffer)174 void WebSocketEventHandler::DidWriteToReceiveStream(
175 bool fin,
176 net::WebSocketFrameHeader::OpCode type,
177 uint32_t num_bytes,
178 const char* buffer) {
179 client_->DidReceiveData(
180 fin, ConvertTo<WebSocket::MessageType>(type), num_bytes);
181 }
182
} // namespace
184
WebSocketImpl(NetworkContext * context)185 WebSocketImpl::WebSocketImpl(NetworkContext* context) : context_(context) {
186 }
187
~WebSocketImpl()188 WebSocketImpl::~WebSocketImpl() {
189 }
190
Connect(const String & url,Array<String> protocols,const String & origin,ScopedDataPipeConsumerHandle send_stream,WebSocketClientPtr client)191 void WebSocketImpl::Connect(const String& url,
192 Array<String> protocols,
193 const String& origin,
194 ScopedDataPipeConsumerHandle send_stream,
195 WebSocketClientPtr client) {
196 DCHECK(!channel_);
197 send_stream_ = send_stream.Pass();
198 read_queue_.reset(new WebSocketReadQueue(send_stream_.get()));
199 scoped_ptr<net::WebSocketEventInterface> event_interface(
200 new WebSocketEventHandler(client.Pass()));
201 channel_.reset(new net::WebSocketChannel(event_interface.Pass(),
202 context_->url_request_context()));
203 channel_->SendAddChannelRequest(GURL(url.get()),
204 protocols.To<std::vector<std::string> >(),
205 url::Origin(origin.get()));
206 }
207
Send(bool fin,WebSocket::MessageType type,uint32_t num_bytes)208 void WebSocketImpl::Send(bool fin,
209 WebSocket::MessageType type,
210 uint32_t num_bytes) {
211 DCHECK(channel_);
212 read_queue_->Read(num_bytes,
213 base::Bind(&WebSocketImpl::DidReadFromSendStream,
214 base::Unretained(this),
215 fin, type, num_bytes));
216 }
217
FlowControl(int64_t quota)218 void WebSocketImpl::FlowControl(int64_t quota) {
219 DCHECK(channel_);
220 channel_->SendFlowControl(quota);
221 }
222
Close(uint16_t code,const String & reason)223 void WebSocketImpl::Close(uint16_t code, const String& reason) {
224 DCHECK(channel_);
225 channel_->StartClosingHandshake(code, reason);
226 }
227
DidReadFromSendStream(bool fin,WebSocket::MessageType type,uint32_t num_bytes,const char * data)228 void WebSocketImpl::DidReadFromSendStream(bool fin,
229 WebSocket::MessageType type,
230 uint32_t num_bytes,
231 const char* data) {
232 std::vector<char> buffer(num_bytes);
233 memcpy(&buffer[0], data, num_bytes);
234 DCHECK(channel_);
235 channel_->SendFrame(
236 fin, ConvertTo<net::WebSocketFrameHeader::OpCode>(type), buffer);
237 }
238
239 } // namespace mojo
240