1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "remoting/protocol/channel_multiplexer.h"
6
7 #include <string.h>
8
9 #include "base/bind.h"
10 #include "base/callback.h"
11 #include "base/location.h"
12 #include "base/single_thread_task_runner.h"
13 #include "base/stl_util.h"
14 #include "base/thread_task_runner_handle.h"
15 #include "net/base/net_errors.h"
16 #include "net/socket/stream_socket.h"
17 #include "remoting/protocol/util.h"
18
19 namespace remoting {
20 namespace protocol {
21
22 namespace {
23 const int kChannelIdUnknown = -1;
24 const int kMaxPacketSize = 1024;
25
26 class PendingPacket {
27 public:
PendingPacket(scoped_ptr<MultiplexPacket> packet,const base::Closure & done_task)28 PendingPacket(scoped_ptr<MultiplexPacket> packet,
29 const base::Closure& done_task)
30 : packet(packet.Pass()),
31 done_task(done_task),
32 pos(0U) {
33 }
~PendingPacket()34 ~PendingPacket() {
35 done_task.Run();
36 }
37
is_empty()38 bool is_empty() { return pos >= packet->data().size(); }
39
Read(char * buffer,size_t size)40 int Read(char* buffer, size_t size) {
41 size = std::min(size, packet->data().size() - pos);
42 memcpy(buffer, packet->data().data() + pos, size);
43 pos += size;
44 return size;
45 }
46
47 private:
48 scoped_ptr<MultiplexPacket> packet;
49 base::Closure done_task;
50 size_t pos;
51
52 DISALLOW_COPY_AND_ASSIGN(PendingPacket);
53 };
54
55 } // namespace
56
57 const char ChannelMultiplexer::kMuxChannelName[] = "mux";
58
59 struct ChannelMultiplexer::PendingChannel {
PendingChannelremoting::protocol::ChannelMultiplexer::PendingChannel60 PendingChannel(const std::string& name,
61 const StreamChannelCallback& callback)
62 : name(name), callback(callback) {
63 }
64 std::string name;
65 StreamChannelCallback callback;
66 };
67
68 class ChannelMultiplexer::MuxChannel {
69 public:
70 MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
71 int send_id);
72 ~MuxChannel();
73
name()74 const std::string& name() { return name_; }
receive_id()75 int receive_id() { return receive_id_; }
set_receive_id(int id)76 void set_receive_id(int id) { receive_id_ = id; }
77
78 // Called by ChannelMultiplexer.
79 scoped_ptr<net::StreamSocket> CreateSocket();
80 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
81 const base::Closure& done_task);
82 void OnWriteFailed();
83
84 // Called by MuxSocket.
85 void OnSocketDestroyed();
86 bool DoWrite(scoped_ptr<MultiplexPacket> packet,
87 const base::Closure& done_task);
88 int DoRead(net::IOBuffer* buffer, int buffer_len);
89
90 private:
91 ChannelMultiplexer* multiplexer_;
92 std::string name_;
93 int send_id_;
94 bool id_sent_;
95 int receive_id_;
96 MuxSocket* socket_;
97 std::list<PendingPacket*> pending_packets_;
98
99 DISALLOW_COPY_AND_ASSIGN(MuxChannel);
100 };
101
102 class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
103 public base::NonThreadSafe,
104 public base::SupportsWeakPtr<MuxSocket> {
105 public:
106 MuxSocket(MuxChannel* channel);
107 virtual ~MuxSocket();
108
109 void OnWriteComplete();
110 void OnWriteFailed();
111 void OnPacketReceived();
112
113 // net::StreamSocket interface.
114 virtual int Read(net::IOBuffer* buffer, int buffer_len,
115 const net::CompletionCallback& callback) OVERRIDE;
116 virtual int Write(net::IOBuffer* buffer, int buffer_len,
117 const net::CompletionCallback& callback) OVERRIDE;
118
SetReceiveBufferSize(int32 size)119 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
120 NOTIMPLEMENTED();
121 return false;
122 }
SetSendBufferSize(int32 size)123 virtual bool SetSendBufferSize(int32 size) OVERRIDE {
124 NOTIMPLEMENTED();
125 return false;
126 }
127
Connect(const net::CompletionCallback & callback)128 virtual int Connect(const net::CompletionCallback& callback) OVERRIDE {
129 NOTIMPLEMENTED();
130 return net::ERR_FAILED;
131 }
Disconnect()132 virtual void Disconnect() OVERRIDE {
133 NOTIMPLEMENTED();
134 }
IsConnected() const135 virtual bool IsConnected() const OVERRIDE {
136 NOTIMPLEMENTED();
137 return true;
138 }
IsConnectedAndIdle() const139 virtual bool IsConnectedAndIdle() const OVERRIDE {
140 NOTIMPLEMENTED();
141 return false;
142 }
GetPeerAddress(net::IPEndPoint * address) const143 virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE {
144 NOTIMPLEMENTED();
145 return net::ERR_FAILED;
146 }
GetLocalAddress(net::IPEndPoint * address) const147 virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE {
148 NOTIMPLEMENTED();
149 return net::ERR_FAILED;
150 }
NetLog() const151 virtual const net::BoundNetLog& NetLog() const OVERRIDE {
152 NOTIMPLEMENTED();
153 return net_log_;
154 }
SetSubresourceSpeculation()155 virtual void SetSubresourceSpeculation() OVERRIDE {
156 NOTIMPLEMENTED();
157 }
SetOmniboxSpeculation()158 virtual void SetOmniboxSpeculation() OVERRIDE {
159 NOTIMPLEMENTED();
160 }
WasEverUsed() const161 virtual bool WasEverUsed() const OVERRIDE {
162 return true;
163 }
UsingTCPFastOpen() const164 virtual bool UsingTCPFastOpen() const OVERRIDE {
165 return false;
166 }
WasNpnNegotiated() const167 virtual bool WasNpnNegotiated() const OVERRIDE {
168 return false;
169 }
GetNegotiatedProtocol() const170 virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE {
171 return net::kProtoUnknown;
172 }
GetSSLInfo(net::SSLInfo * ssl_info)173 virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE {
174 NOTIMPLEMENTED();
175 return false;
176 }
177
178 private:
179 MuxChannel* channel_;
180
181 net::CompletionCallback read_callback_;
182 scoped_refptr<net::IOBuffer> read_buffer_;
183 int read_buffer_size_;
184
185 bool write_pending_;
186 int write_result_;
187 net::CompletionCallback write_callback_;
188
189 net::BoundNetLog net_log_;
190
191 DISALLOW_COPY_AND_ASSIGN(MuxSocket);
192 };
193
194
MuxChannel(ChannelMultiplexer * multiplexer,const std::string & name,int send_id)195 ChannelMultiplexer::MuxChannel::MuxChannel(
196 ChannelMultiplexer* multiplexer,
197 const std::string& name,
198 int send_id)
199 : multiplexer_(multiplexer),
200 name_(name),
201 send_id_(send_id),
202 id_sent_(false),
203 receive_id_(kChannelIdUnknown),
204 socket_(NULL) {
205 }
206
~MuxChannel()207 ChannelMultiplexer::MuxChannel::~MuxChannel() {
208 // Socket must be destroyed before the channel.
209 DCHECK(!socket_);
210 STLDeleteElements(&pending_packets_);
211 }
212
CreateSocket()213 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
214 DCHECK(!socket_); // Can't create more than one socket per channel.
215 scoped_ptr<MuxSocket> result(new MuxSocket(this));
216 socket_ = result.get();
217 return result.PassAs<net::StreamSocket>();
218 }
219
OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,const base::Closure & done_task)220 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
221 scoped_ptr<MultiplexPacket> packet,
222 const base::Closure& done_task) {
223 DCHECK_EQ(packet->channel_id(), receive_id_);
224 if (packet->data().size() > 0) {
225 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
226 if (socket_) {
227 // Notify the socket that we have more data.
228 socket_->OnPacketReceived();
229 }
230 }
231 }
232
OnWriteFailed()233 void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
234 if (socket_)
235 socket_->OnWriteFailed();
236 }
237
OnSocketDestroyed()238 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
239 DCHECK(socket_);
240 socket_ = NULL;
241 }
242
DoWrite(scoped_ptr<MultiplexPacket> packet,const base::Closure & done_task)243 bool ChannelMultiplexer::MuxChannel::DoWrite(
244 scoped_ptr<MultiplexPacket> packet,
245 const base::Closure& done_task) {
246 packet->set_channel_id(send_id_);
247 if (!id_sent_) {
248 packet->set_channel_name(name_);
249 id_sent_ = true;
250 }
251 return multiplexer_->DoWrite(packet.Pass(), done_task);
252 }
253
DoRead(net::IOBuffer * buffer,int buffer_len)254 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer,
255 int buffer_len) {
256 int pos = 0;
257 while (buffer_len > 0 && !pending_packets_.empty()) {
258 DCHECK(!pending_packets_.front()->is_empty());
259 int result = pending_packets_.front()->Read(
260 buffer->data() + pos, buffer_len);
261 DCHECK_LE(result, buffer_len);
262 pos += result;
263 buffer_len -= pos;
264 if (pending_packets_.front()->is_empty()) {
265 delete pending_packets_.front();
266 pending_packets_.erase(pending_packets_.begin());
267 }
268 }
269 return pos;
270 }
271
MuxSocket(MuxChannel * channel)272 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
273 : channel_(channel),
274 read_buffer_size_(0),
275 write_pending_(false),
276 write_result_(0) {
277 }
278
~MuxSocket()279 ChannelMultiplexer::MuxSocket::~MuxSocket() {
280 channel_->OnSocketDestroyed();
281 }
282
Read(net::IOBuffer * buffer,int buffer_len,const net::CompletionCallback & callback)283 int ChannelMultiplexer::MuxSocket::Read(
284 net::IOBuffer* buffer, int buffer_len,
285 const net::CompletionCallback& callback) {
286 DCHECK(CalledOnValidThread());
287 DCHECK(read_callback_.is_null());
288
289 int result = channel_->DoRead(buffer, buffer_len);
290 if (result == 0) {
291 read_buffer_ = buffer;
292 read_buffer_size_ = buffer_len;
293 read_callback_ = callback;
294 return net::ERR_IO_PENDING;
295 }
296 return result;
297 }
298
Write(net::IOBuffer * buffer,int buffer_len,const net::CompletionCallback & callback)299 int ChannelMultiplexer::MuxSocket::Write(
300 net::IOBuffer* buffer, int buffer_len,
301 const net::CompletionCallback& callback) {
302 DCHECK(CalledOnValidThread());
303
304 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
305 size_t size = std::min(kMaxPacketSize, buffer_len);
306 packet->mutable_data()->assign(buffer->data(), size);
307
308 write_pending_ = true;
309 bool result = channel_->DoWrite(packet.Pass(), base::Bind(
310 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
311
312 if (!result) {
313 // Cannot complete the write, e.g. if the connection has been terminated.
314 return net::ERR_FAILED;
315 }
316
317 // OnWriteComplete() might be called above synchronously.
318 if (write_pending_) {
319 DCHECK(write_callback_.is_null());
320 write_callback_ = callback;
321 write_result_ = size;
322 return net::ERR_IO_PENDING;
323 }
324
325 return size;
326 }
327
OnWriteComplete()328 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
329 write_pending_ = false;
330 if (!write_callback_.is_null()) {
331 net::CompletionCallback cb;
332 std::swap(cb, write_callback_);
333 cb.Run(write_result_);
334 }
335 }
336
OnWriteFailed()337 void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
338 if (!write_callback_.is_null()) {
339 net::CompletionCallback cb;
340 std::swap(cb, write_callback_);
341 cb.Run(net::ERR_FAILED);
342 }
343 }
344
OnPacketReceived()345 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
346 if (!read_callback_.is_null()) {
347 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
348 read_buffer_ = NULL;
349 DCHECK_GT(result, 0);
350 net::CompletionCallback cb;
351 std::swap(cb, read_callback_);
352 cb.Run(result);
353 }
354 }
355
ChannelMultiplexer(ChannelFactory * factory,const std::string & base_channel_name)356 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory,
357 const std::string& base_channel_name)
358 : base_channel_factory_(factory),
359 base_channel_name_(base_channel_name),
360 next_channel_id_(0),
361 weak_factory_(this) {
362 }
363
~ChannelMultiplexer()364 ChannelMultiplexer::~ChannelMultiplexer() {
365 DCHECK(pending_channels_.empty());
366 STLDeleteValues(&channels_);
367
368 // Cancel creation of the base channel if it hasn't finished.
369 if (base_channel_factory_)
370 base_channel_factory_->CancelChannelCreation(base_channel_name_);
371 }
372
CreateStreamChannel(const std::string & name,const StreamChannelCallback & callback)373 void ChannelMultiplexer::CreateStreamChannel(
374 const std::string& name,
375 const StreamChannelCallback& callback) {
376 if (base_channel_.get()) {
377 // Already have |base_channel_|. Create new multiplexed channel
378 // synchronously.
379 callback.Run(GetOrCreateChannel(name)->CreateSocket());
380 } else if (!base_channel_.get() && !base_channel_factory_) {
381 // Fail synchronously if we failed to create |base_channel_|.
382 callback.Run(scoped_ptr<net::StreamSocket>());
383 } else {
384 // Still waiting for the |base_channel_|.
385 pending_channels_.push_back(PendingChannel(name, callback));
386
387 // If this is the first multiplexed channel then create the base channel.
388 if (pending_channels_.size() == 1U) {
389 base_channel_factory_->CreateStreamChannel(
390 base_channel_name_,
391 base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
392 base::Unretained(this)));
393 }
394 }
395 }
396
CreateDatagramChannel(const std::string & name,const DatagramChannelCallback & callback)397 void ChannelMultiplexer::CreateDatagramChannel(
398 const std::string& name,
399 const DatagramChannelCallback& callback) {
400 NOTIMPLEMENTED();
401 callback.Run(scoped_ptr<net::Socket>());
402 }
403
CancelChannelCreation(const std::string & name)404 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
405 for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
406 it != pending_channels_.end(); ++it) {
407 if (it->name == name) {
408 pending_channels_.erase(it);
409 return;
410 }
411 }
412 }
413
OnBaseChannelReady(scoped_ptr<net::StreamSocket> socket)414 void ChannelMultiplexer::OnBaseChannelReady(
415 scoped_ptr<net::StreamSocket> socket) {
416 base_channel_factory_ = NULL;
417 base_channel_ = socket.Pass();
418
419 if (base_channel_.get()) {
420 // Initialize reader and writer.
421 reader_.Init(base_channel_.get(),
422 base::Bind(&ChannelMultiplexer::OnIncomingPacket,
423 base::Unretained(this)));
424 writer_.Init(base_channel_.get(),
425 base::Bind(&ChannelMultiplexer::OnWriteFailed,
426 base::Unretained(this)));
427 }
428
429 DoCreatePendingChannels();
430 }
431
DoCreatePendingChannels()432 void ChannelMultiplexer::DoCreatePendingChannels() {
433 if (pending_channels_.empty())
434 return;
435
436 // Every time this function is called it connects a single channel and posts a
437 // separate task to connect other channels. This is necessary because the
438 // callback may destroy the multiplexer or somehow else modify
439 // |pending_channels_| list (e.g. call CancelChannelCreation()).
440 base::ThreadTaskRunnerHandle::Get()->PostTask(
441 FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
442 weak_factory_.GetWeakPtr()));
443
444 PendingChannel c = pending_channels_.front();
445 pending_channels_.erase(pending_channels_.begin());
446 scoped_ptr<net::StreamSocket> socket;
447 if (base_channel_.get())
448 socket = GetOrCreateChannel(c.name)->CreateSocket();
449 c.callback.Run(socket.Pass());
450 }
451
GetOrCreateChannel(const std::string & name)452 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
453 const std::string& name) {
454 // Check if we already have a channel with the requested name.
455 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
456 if (it != channels_.end())
457 return it->second;
458
459 // Create a new channel if we haven't found existing one.
460 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
461 ++next_channel_id_;
462 channels_[channel->name()] = channel;
463 return channel;
464 }
465
466
OnWriteFailed(int error)467 void ChannelMultiplexer::OnWriteFailed(int error) {
468 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
469 it != channels_.end(); ++it) {
470 base::ThreadTaskRunnerHandle::Get()->PostTask(
471 FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed,
472 weak_factory_.GetWeakPtr(), it->second->name()));
473 }
474 }
475
NotifyWriteFailed(const std::string & name)476 void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) {
477 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
478 if (it != channels_.end()) {
479 it->second->OnWriteFailed();
480 }
481 }
482
OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,const base::Closure & done_task)483 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
484 const base::Closure& done_task) {
485 if (!packet->has_channel_id()) {
486 LOG(ERROR) << "Received packet without channel_id.";
487 done_task.Run();
488 return;
489 }
490
491 int receive_id = packet->channel_id();
492 MuxChannel* channel = NULL;
493 std::map<int, MuxChannel*>::iterator it =
494 channels_by_receive_id_.find(receive_id);
495 if (it != channels_by_receive_id_.end()) {
496 channel = it->second;
497 } else {
498 // This is a new |channel_id| we haven't seen before. Look it up by name.
499 if (!packet->has_channel_name()) {
500 LOG(ERROR) << "Received packet with unknown channel_id and "
501 "without channel_name.";
502 done_task.Run();
503 return;
504 }
505 channel = GetOrCreateChannel(packet->channel_name());
506 channel->set_receive_id(receive_id);
507 channels_by_receive_id_[receive_id] = channel;
508 }
509
510 channel->OnIncomingPacket(packet.Pass(), done_task);
511 }
512
DoWrite(scoped_ptr<MultiplexPacket> packet,const base::Closure & done_task)513 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
514 const base::Closure& done_task) {
515 return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
516 }
517
518 } // namespace protocol
519 } // namespace remoting
520