• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "remoting/protocol/channel_multiplexer.h"
6 
7 #include <string.h>
8 
9 #include "base/bind.h"
10 #include "base/callback.h"
11 #include "base/location.h"
12 #include "base/single_thread_task_runner.h"
13 #include "base/stl_util.h"
14 #include "base/thread_task_runner_handle.h"
15 #include "net/base/net_errors.h"
16 #include "net/socket/stream_socket.h"
17 #include "remoting/protocol/util.h"
18 
19 namespace remoting {
20 namespace protocol {
21 
22 namespace {
23 const int kChannelIdUnknown = -1;
24 const int kMaxPacketSize = 1024;
25 
26 class PendingPacket {
27  public:
PendingPacket(scoped_ptr<MultiplexPacket> packet,const base::Closure & done_task)28   PendingPacket(scoped_ptr<MultiplexPacket> packet,
29                 const base::Closure& done_task)
30       : packet(packet.Pass()),
31         done_task(done_task),
32         pos(0U) {
33   }
~PendingPacket()34   ~PendingPacket() {
35     done_task.Run();
36   }
37 
is_empty()38   bool is_empty() { return pos >= packet->data().size(); }
39 
Read(char * buffer,size_t size)40   int Read(char* buffer, size_t size) {
41     size = std::min(size, packet->data().size() - pos);
42     memcpy(buffer, packet->data().data() + pos, size);
43     pos += size;
44     return size;
45   }
46 
47  private:
48   scoped_ptr<MultiplexPacket> packet;
49   base::Closure done_task;
50   size_t pos;
51 
52   DISALLOW_COPY_AND_ASSIGN(PendingPacket);
53 };
54 
55 }  // namespace
56 
// Definition of the static channel-name constant. It is not referenced
// anywhere else in this file.
// NOTE(review): presumably what callers pass as |base_channel_name| - confirm.
const char ChannelMultiplexer::kMuxChannelName[] = "mux";
58 
59 struct ChannelMultiplexer::PendingChannel {
PendingChannelremoting::protocol::ChannelMultiplexer::PendingChannel60   PendingChannel(const std::string& name,
61                  const StreamChannelCallback& callback)
62       : name(name), callback(callback) {
63   }
64   std::string name;
65   StreamChannelCallback callback;
66 };
67 
68 class ChannelMultiplexer::MuxChannel {
69  public:
70   MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
71              int send_id);
72   ~MuxChannel();
73 
name()74   const std::string& name() { return name_; }
receive_id()75   int receive_id() { return receive_id_; }
set_receive_id(int id)76   void set_receive_id(int id) { receive_id_ = id; }
77 
78   // Called by ChannelMultiplexer.
79   scoped_ptr<net::StreamSocket> CreateSocket();
80   void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
81                         const base::Closure& done_task);
82   void OnWriteFailed();
83 
84   // Called by MuxSocket.
85   void OnSocketDestroyed();
86   bool DoWrite(scoped_ptr<MultiplexPacket> packet,
87                const base::Closure& done_task);
88   int DoRead(net::IOBuffer* buffer, int buffer_len);
89 
90  private:
91   ChannelMultiplexer* multiplexer_;
92   std::string name_;
93   int send_id_;
94   bool id_sent_;
95   int receive_id_;
96   MuxSocket* socket_;
97   std::list<PendingPacket*> pending_packets_;
98 
99   DISALLOW_COPY_AND_ASSIGN(MuxChannel);
100 };
101 
102 class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
103                                       public base::NonThreadSafe,
104                                       public base::SupportsWeakPtr<MuxSocket> {
105  public:
106   MuxSocket(MuxChannel* channel);
107   virtual ~MuxSocket();
108 
109   void OnWriteComplete();
110   void OnWriteFailed();
111   void OnPacketReceived();
112 
113   // net::StreamSocket interface.
114   virtual int Read(net::IOBuffer* buffer, int buffer_len,
115                    const net::CompletionCallback& callback) OVERRIDE;
116   virtual int Write(net::IOBuffer* buffer, int buffer_len,
117                     const net::CompletionCallback& callback) OVERRIDE;
118 
SetReceiveBufferSize(int32 size)119   virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
120     NOTIMPLEMENTED();
121     return false;
122   }
SetSendBufferSize(int32 size)123   virtual bool SetSendBufferSize(int32 size) OVERRIDE {
124     NOTIMPLEMENTED();
125     return false;
126   }
127 
Connect(const net::CompletionCallback & callback)128   virtual int Connect(const net::CompletionCallback& callback) OVERRIDE {
129     NOTIMPLEMENTED();
130     return net::ERR_FAILED;
131   }
Disconnect()132   virtual void Disconnect() OVERRIDE {
133     NOTIMPLEMENTED();
134   }
IsConnected() const135   virtual bool IsConnected() const OVERRIDE {
136     NOTIMPLEMENTED();
137     return true;
138   }
IsConnectedAndIdle() const139   virtual bool IsConnectedAndIdle() const OVERRIDE {
140     NOTIMPLEMENTED();
141     return false;
142   }
GetPeerAddress(net::IPEndPoint * address) const143   virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE {
144     NOTIMPLEMENTED();
145     return net::ERR_FAILED;
146   }
GetLocalAddress(net::IPEndPoint * address) const147   virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE {
148     NOTIMPLEMENTED();
149     return net::ERR_FAILED;
150   }
NetLog() const151   virtual const net::BoundNetLog& NetLog() const OVERRIDE {
152     NOTIMPLEMENTED();
153     return net_log_;
154   }
SetSubresourceSpeculation()155   virtual void SetSubresourceSpeculation() OVERRIDE {
156     NOTIMPLEMENTED();
157   }
SetOmniboxSpeculation()158   virtual void SetOmniboxSpeculation() OVERRIDE {
159     NOTIMPLEMENTED();
160   }
WasEverUsed() const161   virtual bool WasEverUsed() const OVERRIDE {
162     return true;
163   }
UsingTCPFastOpen() const164   virtual bool UsingTCPFastOpen() const OVERRIDE {
165     return false;
166   }
WasNpnNegotiated() const167   virtual bool WasNpnNegotiated() const OVERRIDE {
168     return false;
169   }
GetNegotiatedProtocol() const170   virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE {
171     return net::kProtoUnknown;
172   }
GetSSLInfo(net::SSLInfo * ssl_info)173   virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE {
174     NOTIMPLEMENTED();
175     return false;
176   }
177 
178  private:
179   MuxChannel* channel_;
180 
181   net::CompletionCallback read_callback_;
182   scoped_refptr<net::IOBuffer> read_buffer_;
183   int read_buffer_size_;
184 
185   bool write_pending_;
186   int write_result_;
187   net::CompletionCallback write_callback_;
188 
189   net::BoundNetLog net_log_;
190 
191   DISALLOW_COPY_AND_ASSIGN(MuxSocket);
192 };
193 
194 
MuxChannel(ChannelMultiplexer * multiplexer,const std::string & name,int send_id)195 ChannelMultiplexer::MuxChannel::MuxChannel(
196     ChannelMultiplexer* multiplexer,
197     const std::string& name,
198     int send_id)
199     : multiplexer_(multiplexer),
200       name_(name),
201       send_id_(send_id),
202       id_sent_(false),
203       receive_id_(kChannelIdUnknown),
204       socket_(NULL) {
205 }
206 
~MuxChannel()207 ChannelMultiplexer::MuxChannel::~MuxChannel() {
208   // Socket must be destroyed before the channel.
209   DCHECK(!socket_);
210   STLDeleteElements(&pending_packets_);
211 }
212 
CreateSocket()213 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
214   DCHECK(!socket_);  // Can't create more than one socket per channel.
215   scoped_ptr<MuxSocket> result(new MuxSocket(this));
216   socket_ = result.get();
217   return result.PassAs<net::StreamSocket>();
218 }
219 
OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,const base::Closure & done_task)220 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
221     scoped_ptr<MultiplexPacket> packet,
222     const base::Closure& done_task) {
223   DCHECK_EQ(packet->channel_id(), receive_id_);
224   if (packet->data().size() > 0) {
225     pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
226     if (socket_) {
227       // Notify the socket that we have more data.
228       socket_->OnPacketReceived();
229     }
230   }
231 }
232 
OnWriteFailed()233 void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
234   if (socket_)
235     socket_->OnWriteFailed();
236 }
237 
OnSocketDestroyed()238 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
239   DCHECK(socket_);
240   socket_ = NULL;
241 }
242 
DoWrite(scoped_ptr<MultiplexPacket> packet,const base::Closure & done_task)243 bool ChannelMultiplexer::MuxChannel::DoWrite(
244     scoped_ptr<MultiplexPacket> packet,
245     const base::Closure& done_task) {
246   packet->set_channel_id(send_id_);
247   if (!id_sent_) {
248     packet->set_channel_name(name_);
249     id_sent_ = true;
250   }
251   return multiplexer_->DoWrite(packet.Pass(), done_task);
252 }
253 
DoRead(net::IOBuffer * buffer,int buffer_len)254 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer,
255                                            int buffer_len) {
256   int pos = 0;
257   while (buffer_len > 0 && !pending_packets_.empty()) {
258     DCHECK(!pending_packets_.front()->is_empty());
259     int result = pending_packets_.front()->Read(
260         buffer->data() + pos, buffer_len);
261     DCHECK_LE(result, buffer_len);
262     pos += result;
263     buffer_len -= pos;
264     if (pending_packets_.front()->is_empty()) {
265       delete pending_packets_.front();
266       pending_packets_.erase(pending_packets_.begin());
267     }
268   }
269   return pos;
270 }
271 
MuxSocket(MuxChannel * channel)272 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
273     : channel_(channel),
274       read_buffer_size_(0),
275       write_pending_(false),
276       write_result_(0) {
277 }
278 
~MuxSocket()279 ChannelMultiplexer::MuxSocket::~MuxSocket() {
280   channel_->OnSocketDestroyed();
281 }
282 
Read(net::IOBuffer * buffer,int buffer_len,const net::CompletionCallback & callback)283 int ChannelMultiplexer::MuxSocket::Read(
284     net::IOBuffer* buffer, int buffer_len,
285     const net::CompletionCallback& callback) {
286   DCHECK(CalledOnValidThread());
287   DCHECK(read_callback_.is_null());
288 
289   int result = channel_->DoRead(buffer, buffer_len);
290   if (result == 0) {
291     read_buffer_ = buffer;
292     read_buffer_size_ = buffer_len;
293     read_callback_ = callback;
294     return net::ERR_IO_PENDING;
295   }
296   return result;
297 }
298 
Write(net::IOBuffer * buffer,int buffer_len,const net::CompletionCallback & callback)299 int ChannelMultiplexer::MuxSocket::Write(
300     net::IOBuffer* buffer, int buffer_len,
301     const net::CompletionCallback& callback) {
302   DCHECK(CalledOnValidThread());
303 
304   scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
305   size_t size = std::min(kMaxPacketSize, buffer_len);
306   packet->mutable_data()->assign(buffer->data(), size);
307 
308   write_pending_ = true;
309   bool result = channel_->DoWrite(packet.Pass(), base::Bind(
310       &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
311 
312   if (!result) {
313     // Cannot complete the write, e.g. if the connection has been terminated.
314     return net::ERR_FAILED;
315   }
316 
317   // OnWriteComplete() might be called above synchronously.
318   if (write_pending_) {
319     DCHECK(write_callback_.is_null());
320     write_callback_ = callback;
321     write_result_ = size;
322     return net::ERR_IO_PENDING;
323   }
324 
325   return size;
326 }
327 
OnWriteComplete()328 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
329   write_pending_ = false;
330   if (!write_callback_.is_null()) {
331     net::CompletionCallback cb;
332     std::swap(cb, write_callback_);
333     cb.Run(write_result_);
334   }
335 }
336 
OnWriteFailed()337 void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
338   if (!write_callback_.is_null()) {
339     net::CompletionCallback cb;
340     std::swap(cb, write_callback_);
341     cb.Run(net::ERR_FAILED);
342   }
343 }
344 
OnPacketReceived()345 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
346   if (!read_callback_.is_null()) {
347     int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
348     read_buffer_ = NULL;
349     DCHECK_GT(result, 0);
350     net::CompletionCallback cb;
351     std::swap(cb, read_callback_);
352     cb.Run(result);
353   }
354 }
355 
ChannelMultiplexer(ChannelFactory * factory,const std::string & base_channel_name)356 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory,
357                                        const std::string& base_channel_name)
358     : base_channel_factory_(factory),
359       base_channel_name_(base_channel_name),
360       next_channel_id_(0),
361       weak_factory_(this) {
362 }
363 
~ChannelMultiplexer()364 ChannelMultiplexer::~ChannelMultiplexer() {
365   DCHECK(pending_channels_.empty());
366   STLDeleteValues(&channels_);
367 
368   // Cancel creation of the base channel if it hasn't finished.
369   if (base_channel_factory_)
370     base_channel_factory_->CancelChannelCreation(base_channel_name_);
371 }
372 
CreateStreamChannel(const std::string & name,const StreamChannelCallback & callback)373 void ChannelMultiplexer::CreateStreamChannel(
374     const std::string& name,
375     const StreamChannelCallback& callback) {
376   if (base_channel_.get()) {
377     // Already have |base_channel_|. Create new multiplexed channel
378     // synchronously.
379     callback.Run(GetOrCreateChannel(name)->CreateSocket());
380   } else if (!base_channel_.get() && !base_channel_factory_) {
381     // Fail synchronously if we failed to create |base_channel_|.
382     callback.Run(scoped_ptr<net::StreamSocket>());
383   } else {
384     // Still waiting for the |base_channel_|.
385     pending_channels_.push_back(PendingChannel(name, callback));
386 
387     // If this is the first multiplexed channel then create the base channel.
388     if (pending_channels_.size() == 1U) {
389       base_channel_factory_->CreateStreamChannel(
390           base_channel_name_,
391           base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
392                      base::Unretained(this)));
393     }
394   }
395 }
396 
CreateDatagramChannel(const std::string & name,const DatagramChannelCallback & callback)397 void ChannelMultiplexer::CreateDatagramChannel(
398     const std::string& name,
399     const DatagramChannelCallback& callback) {
400   NOTIMPLEMENTED();
401   callback.Run(scoped_ptr<net::Socket>());
402 }
403 
CancelChannelCreation(const std::string & name)404 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
405   for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
406        it != pending_channels_.end(); ++it) {
407     if (it->name == name) {
408       pending_channels_.erase(it);
409       return;
410     }
411   }
412 }
413 
OnBaseChannelReady(scoped_ptr<net::StreamSocket> socket)414 void ChannelMultiplexer::OnBaseChannelReady(
415     scoped_ptr<net::StreamSocket> socket) {
416   base_channel_factory_ = NULL;
417   base_channel_ = socket.Pass();
418 
419   if (base_channel_.get()) {
420     // Initialize reader and writer.
421     reader_.Init(base_channel_.get(),
422                  base::Bind(&ChannelMultiplexer::OnIncomingPacket,
423                             base::Unretained(this)));
424     writer_.Init(base_channel_.get(),
425                  base::Bind(&ChannelMultiplexer::OnWriteFailed,
426                             base::Unretained(this)));
427   }
428 
429   DoCreatePendingChannels();
430 }
431 
DoCreatePendingChannels()432 void ChannelMultiplexer::DoCreatePendingChannels() {
433   if (pending_channels_.empty())
434     return;
435 
436   // Every time this function is called it connects a single channel and posts a
437   // separate task to connect other channels. This is necessary because the
438   // callback may destroy the multiplexer or somehow else modify
439   // |pending_channels_| list (e.g. call CancelChannelCreation()).
440   base::ThreadTaskRunnerHandle::Get()->PostTask(
441       FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
442                             weak_factory_.GetWeakPtr()));
443 
444   PendingChannel c = pending_channels_.front();
445   pending_channels_.erase(pending_channels_.begin());
446   scoped_ptr<net::StreamSocket> socket;
447   if (base_channel_.get())
448     socket = GetOrCreateChannel(c.name)->CreateSocket();
449   c.callback.Run(socket.Pass());
450 }
451 
GetOrCreateChannel(const std::string & name)452 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
453     const std::string& name) {
454   // Check if we already have a channel with the requested name.
455   std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
456   if (it != channels_.end())
457     return it->second;
458 
459   // Create a new channel if we haven't found existing one.
460   MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
461   ++next_channel_id_;
462   channels_[channel->name()] = channel;
463   return channel;
464 }
465 
466 
OnWriteFailed(int error)467 void ChannelMultiplexer::OnWriteFailed(int error) {
468   for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
469        it != channels_.end(); ++it) {
470     base::ThreadTaskRunnerHandle::Get()->PostTask(
471         FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed,
472                               weak_factory_.GetWeakPtr(), it->second->name()));
473   }
474 }
475 
NotifyWriteFailed(const std::string & name)476 void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) {
477   std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
478   if (it != channels_.end()) {
479     it->second->OnWriteFailed();
480   }
481 }
482 
OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,const base::Closure & done_task)483 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
484                                           const base::Closure& done_task) {
485   if (!packet->has_channel_id()) {
486     LOG(ERROR) << "Received packet without channel_id.";
487     done_task.Run();
488     return;
489   }
490 
491   int receive_id = packet->channel_id();
492   MuxChannel* channel = NULL;
493   std::map<int, MuxChannel*>::iterator it =
494       channels_by_receive_id_.find(receive_id);
495   if (it != channels_by_receive_id_.end()) {
496     channel = it->second;
497   } else {
498     // This is a new |channel_id| we haven't seen before. Look it up by name.
499     if (!packet->has_channel_name()) {
500       LOG(ERROR) << "Received packet with unknown channel_id and "
501           "without channel_name.";
502       done_task.Run();
503       return;
504     }
505     channel = GetOrCreateChannel(packet->channel_name());
506     channel->set_receive_id(receive_id);
507     channels_by_receive_id_[receive_id] = channel;
508   }
509 
510   channel->OnIncomingPacket(packet.Pass(), done_task);
511 }
512 
DoWrite(scoped_ptr<MultiplexPacket> packet,const base::Closure & done_task)513 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
514                                  const base::Closure& done_task) {
515   return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
516 }
517 
518 }  // namespace protocol
519 }  // namespace remoting
520