• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "webrtc/base/natsocketfactory.h"
12 
13 #include "webrtc/base/logging.h"
14 #include "webrtc/base/natserver.h"
15 #include "webrtc/base/virtualsocketserver.h"
16 
17 namespace rtc {
18 
19 // Packs the given socketaddress into the buffer in buf, in the quasi-STUN
20 // format that the natserver uses.
21 // Returns 0 if an invalid address is passed.
PackAddressForNAT(char * buf,size_t buf_size,const SocketAddress & remote_addr)22 size_t PackAddressForNAT(char* buf, size_t buf_size,
23                          const SocketAddress& remote_addr) {
24   const IPAddress& ip = remote_addr.ipaddr();
25   int family = ip.family();
26   buf[0] = 0;
27   buf[1] = family;
28   // Writes the port.
29   *(reinterpret_cast<uint16*>(&buf[2])) = HostToNetwork16(remote_addr.port());
30   if (family == AF_INET) {
31     ASSERT(buf_size >= kNATEncodedIPv4AddressSize);
32     in_addr v4addr = ip.ipv4_address();
33     memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
34     return kNATEncodedIPv4AddressSize;
35   } else if (family == AF_INET6) {
36     ASSERT(buf_size >= kNATEncodedIPv6AddressSize);
37     in6_addr v6addr = ip.ipv6_address();
38     memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
39     return kNATEncodedIPv6AddressSize;
40   }
41   return 0U;
42 }
43 
44 // Decodes the remote address from a packet that has been encoded with the nat's
45 // quasi-STUN format. Returns the length of the address (i.e., the offset into
46 // data where the original packet starts).
UnpackAddressFromNAT(const char * buf,size_t buf_size,SocketAddress * remote_addr)47 size_t UnpackAddressFromNAT(const char* buf, size_t buf_size,
48                             SocketAddress* remote_addr) {
49   ASSERT(buf_size >= 8);
50   ASSERT(buf[0] == 0);
51   int family = buf[1];
52   uint16 port = NetworkToHost16(*(reinterpret_cast<const uint16*>(&buf[2])));
53   if (family == AF_INET) {
54     const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
55     *remote_addr = SocketAddress(IPAddress(*v4addr), port);
56     return kNATEncodedIPv4AddressSize;
57   } else if (family == AF_INET6) {
58     ASSERT(buf_size >= 20);
59     const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
60     *remote_addr = SocketAddress(IPAddress(*v6addr), port);
61     return kNATEncodedIPv6AddressSize;
62   }
63   return 0U;
64 }
65 
66 
67 // NATSocket
68 class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
69  public:
NATSocket(NATInternalSocketFactory * sf,int family,int type)70   explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
71       : sf_(sf), family_(family), type_(type), connected_(false),
72         socket_(NULL), buf_(NULL), size_(0) {
73   }
74 
~NATSocket()75   virtual ~NATSocket() {
76     delete socket_;
77     delete[] buf_;
78   }
79 
GetLocalAddress() const80   virtual SocketAddress GetLocalAddress() const {
81     return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
82   }
83 
GetRemoteAddress() const84   virtual SocketAddress GetRemoteAddress() const {
85     return remote_addr_;  // will be NIL if not connected
86   }
87 
Bind(const SocketAddress & addr)88   virtual int Bind(const SocketAddress& addr) {
89     if (socket_) {  // already bound, bubble up error
90       return -1;
91     }
92 
93     int result;
94     socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
95     result = (socket_) ? socket_->Bind(addr) : -1;
96     if (result >= 0) {
97       socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
98       socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
99       socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
100       socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
101     } else {
102       server_addr_.Clear();
103       delete socket_;
104       socket_ = NULL;
105     }
106 
107     return result;
108   }
109 
Connect(const SocketAddress & addr)110   virtual int Connect(const SocketAddress& addr) {
111     if (!socket_) {  // socket must be bound, for now
112       return -1;
113     }
114 
115     int result = 0;
116     if (type_ == SOCK_STREAM) {
117       result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
118     } else {
119       connected_ = true;
120     }
121 
122     if (result >= 0) {
123       remote_addr_ = addr;
124     }
125 
126     return result;
127   }
128 
Send(const void * data,size_t size)129   virtual int Send(const void* data, size_t size) {
130     ASSERT(connected_);
131     return SendTo(data, size, remote_addr_);
132   }
133 
SendTo(const void * data,size_t size,const SocketAddress & addr)134   virtual int SendTo(const void* data, size_t size, const SocketAddress& addr) {
135     ASSERT(!connected_ || addr == remote_addr_);
136     if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
137       return socket_->SendTo(data, size, addr);
138     }
139     // This array will be too large for IPv4 packets, but only by 12 bytes.
140     scoped_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
141     size_t addrlength = PackAddressForNAT(buf.get(),
142                                           size + kNATEncodedIPv6AddressSize,
143                                           addr);
144     size_t encoded_size = size + addrlength;
145     memcpy(buf.get() + addrlength, data, size);
146     int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
147     if (result >= 0) {
148       ASSERT(result == static_cast<int>(encoded_size));
149       result = result - static_cast<int>(addrlength);
150     }
151     return result;
152   }
153 
Recv(void * data,size_t size)154   virtual int Recv(void* data, size_t size) {
155     SocketAddress addr;
156     return RecvFrom(data, size, &addr);
157   }
158 
RecvFrom(void * data,size_t size,SocketAddress * out_addr)159   virtual int RecvFrom(void* data, size_t size, SocketAddress *out_addr) {
160     if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
161       return socket_->RecvFrom(data, size, out_addr);
162     }
163     // Make sure we have enough room to read the requested amount plus the
164     // largest possible header address.
165     SocketAddress remote_addr;
166     Grow(size + kNATEncodedIPv6AddressSize);
167 
168     // Read the packet from the socket.
169     int result = socket_->RecvFrom(buf_, size_, &remote_addr);
170     if (result >= 0) {
171       ASSERT(remote_addr == server_addr_);
172 
173       // TODO: we need better framing so we know how many bytes we can
174       // return before we need to read the next address. For UDP, this will be
175       // fine as long as the reader always reads everything in the packet.
176       ASSERT((size_t)result < size_);
177 
178       // Decode the wire packet into the actual results.
179       SocketAddress real_remote_addr;
180       size_t addrlength =
181           UnpackAddressFromNAT(buf_, result, &real_remote_addr);
182       memcpy(data, buf_ + addrlength, result - addrlength);
183 
184       // Make sure this packet should be delivered before returning it.
185       if (!connected_ || (real_remote_addr == remote_addr_)) {
186         if (out_addr)
187           *out_addr = real_remote_addr;
188         result = result - static_cast<int>(addrlength);
189       } else {
190         LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
191                       << real_remote_addr.ToString();
192         result = 0;  // Tell the caller we didn't read anything
193       }
194     }
195 
196     return result;
197   }
198 
Close()199   virtual int Close() {
200     int result = 0;
201     if (socket_) {
202       result = socket_->Close();
203       if (result >= 0) {
204         connected_ = false;
205         remote_addr_ = SocketAddress();
206         delete socket_;
207         socket_ = NULL;
208       }
209     }
210     return result;
211   }
212 
Listen(int backlog)213   virtual int Listen(int backlog) {
214     return socket_->Listen(backlog);
215   }
Accept(SocketAddress * paddr)216   virtual AsyncSocket* Accept(SocketAddress *paddr) {
217     return socket_->Accept(paddr);
218   }
GetError() const219   virtual int GetError() const {
220     return socket_->GetError();
221   }
SetError(int error)222   virtual void SetError(int error) {
223     socket_->SetError(error);
224   }
GetState() const225   virtual ConnState GetState() const {
226     return connected_ ? CS_CONNECTED : CS_CLOSED;
227   }
EstimateMTU(uint16 * mtu)228   virtual int EstimateMTU(uint16* mtu) {
229     return socket_->EstimateMTU(mtu);
230   }
GetOption(Option opt,int * value)231   virtual int GetOption(Option opt, int* value) {
232     return socket_->GetOption(opt, value);
233   }
SetOption(Option opt,int value)234   virtual int SetOption(Option opt, int value) {
235     return socket_->SetOption(opt, value);
236   }
237 
OnConnectEvent(AsyncSocket * socket)238   void OnConnectEvent(AsyncSocket* socket) {
239     // If we're NATed, we need to send a request with the real addr to use.
240     ASSERT(socket == socket_);
241     if (server_addr_.IsNil()) {
242       connected_ = true;
243       SignalConnectEvent(this);
244     } else {
245       SendConnectRequest();
246     }
247   }
OnReadEvent(AsyncSocket * socket)248   void OnReadEvent(AsyncSocket* socket) {
249     // If we're NATed, we need to process the connect reply.
250     ASSERT(socket == socket_);
251     if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
252       HandleConnectReply();
253     } else {
254       SignalReadEvent(this);
255     }
256   }
OnWriteEvent(AsyncSocket * socket)257   void OnWriteEvent(AsyncSocket* socket) {
258     ASSERT(socket == socket_);
259     SignalWriteEvent(this);
260   }
OnCloseEvent(AsyncSocket * socket,int error)261   void OnCloseEvent(AsyncSocket* socket, int error) {
262     ASSERT(socket == socket_);
263     SignalCloseEvent(this, error);
264   }
265 
266  private:
267   // Makes sure the buffer is at least the given size.
Grow(size_t new_size)268   void Grow(size_t new_size) {
269     if (size_ < new_size) {
270       delete[] buf_;
271       size_ = new_size;
272       buf_ = new char[size_];
273     }
274   }
275 
276   // Sends the destination address to the server to tell it to connect.
SendConnectRequest()277   void SendConnectRequest() {
278     char buf[256];
279     size_t length = PackAddressForNAT(buf, ARRAY_SIZE(buf), remote_addr_);
280     socket_->Send(buf, length);
281   }
282 
283   // Handles the byte sent back from the server and fires the appropriate event.
HandleConnectReply()284   void HandleConnectReply() {
285     char code;
286     socket_->Recv(&code, sizeof(code));
287     if (code == 0) {
288       SignalConnectEvent(this);
289     } else {
290       Close();
291       SignalCloseEvent(this, code);
292     }
293   }
294 
295   NATInternalSocketFactory* sf_;
296   int family_;
297   int type_;
298   bool connected_;
299   SocketAddress remote_addr_;
300   SocketAddress server_addr_;  // address of the NAT server
301   AsyncSocket* socket_;
302   char* buf_;
303   size_t size_;
304 };
305 
306 // NATSocketFactory
NATSocketFactory(SocketFactory * factory,const SocketAddress & nat_addr)307 NATSocketFactory::NATSocketFactory(SocketFactory* factory,
308                                    const SocketAddress& nat_addr)
309     : factory_(factory), nat_addr_(nat_addr) {
310 }
311 
CreateSocket(int type)312 Socket* NATSocketFactory::CreateSocket(int type) {
313   return CreateSocket(AF_INET, type);
314 }
315 
CreateSocket(int family,int type)316 Socket* NATSocketFactory::CreateSocket(int family, int type) {
317   return new NATSocket(this, family, type);
318 }
319 
CreateAsyncSocket(int type)320 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
321   return CreateAsyncSocket(AF_INET, type);
322 }
323 
CreateAsyncSocket(int family,int type)324 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
325   return new NATSocket(this, family, type);
326 }
327 
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)328 AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type,
329     const SocketAddress& local_addr, SocketAddress* nat_addr) {
330   *nat_addr = nat_addr_;
331   return factory_->CreateAsyncSocket(family, type);
332 }
333 
334 // NATSocketServer
NATSocketServer(SocketServer * server)335 NATSocketServer::NATSocketServer(SocketServer* server)
336     : server_(server), msg_queue_(NULL) {
337 }
338 
GetTranslator(const SocketAddress & ext_ip)339 NATSocketServer::Translator* NATSocketServer::GetTranslator(
340     const SocketAddress& ext_ip) {
341   return nats_.Get(ext_ip);
342 }
343 
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)344 NATSocketServer::Translator* NATSocketServer::AddTranslator(
345     const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
346   // Fail if a translator already exists with this extternal address.
347   if (nats_.Get(ext_ip))
348     return NULL;
349 
350   return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
351 }
352 
RemoveTranslator(const SocketAddress & ext_ip)353 void NATSocketServer::RemoveTranslator(
354     const SocketAddress& ext_ip) {
355   nats_.Remove(ext_ip);
356 }
357 
CreateSocket(int type)358 Socket* NATSocketServer::CreateSocket(int type) {
359   return CreateSocket(AF_INET, type);
360 }
361 
CreateSocket(int family,int type)362 Socket* NATSocketServer::CreateSocket(int family, int type) {
363   return new NATSocket(this, family, type);
364 }
365 
CreateAsyncSocket(int type)366 AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) {
367   return CreateAsyncSocket(AF_INET, type);
368 }
369 
CreateAsyncSocket(int family,int type)370 AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
371   return new NATSocket(this, family, type);
372 }
373 
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)374 AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type,
375     const SocketAddress& local_addr, SocketAddress* nat_addr) {
376   AsyncSocket* socket = NULL;
377   Translator* nat = nats_.FindClient(local_addr);
378   if (nat) {
379     socket = nat->internal_factory()->CreateAsyncSocket(family, type);
380     *nat_addr = (type == SOCK_STREAM) ?
381         nat->internal_tcp_address() : nat->internal_address();
382   } else {
383     socket = server_->CreateAsyncSocket(family, type);
384   }
385   return socket;
386 }
387 
388 // NATSocketServer::Translator
Translator(NATSocketServer * server,NATType type,const SocketAddress & int_ip,SocketFactory * ext_factory,const SocketAddress & ext_ip)389 NATSocketServer::Translator::Translator(
390     NATSocketServer* server, NATType type, const SocketAddress& int_ip,
391     SocketFactory* ext_factory, const SocketAddress& ext_ip)
392     : server_(server) {
393   // Create a new private network, and a NATServer running on the private
394   // network that bridges to the external network. Also tell the private
395   // network to use the same message queue as us.
396   VirtualSocketServer* internal_server = new VirtualSocketServer(server_);
397   internal_server->SetMessageQueue(server_->queue());
398   internal_factory_.reset(internal_server);
399   nat_server_.reset(new NATServer(type, internal_server, int_ip,
400                                   ext_factory, ext_ip));
401 }
402 
403 
GetTranslator(const SocketAddress & ext_ip)404 NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
405     const SocketAddress& ext_ip) {
406   return nats_.Get(ext_ip);
407 }
408 
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)409 NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
410     const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
411   // Fail if a translator already exists with this extternal address.
412   if (nats_.Get(ext_ip))
413     return NULL;
414 
415   AddClient(ext_ip);
416   return nats_.Add(ext_ip,
417                    new Translator(server_, type, int_ip, server_, ext_ip));
418 }
RemoveTranslator(const SocketAddress & ext_ip)419 void NATSocketServer::Translator::RemoveTranslator(
420     const SocketAddress& ext_ip) {
421   nats_.Remove(ext_ip);
422   RemoveClient(ext_ip);
423 }
424 
AddClient(const SocketAddress & int_ip)425 bool NATSocketServer::Translator::AddClient(
426     const SocketAddress& int_ip) {
427   // Fail if a client already exists with this internal address.
428   if (clients_.find(int_ip) != clients_.end())
429     return false;
430 
431   clients_.insert(int_ip);
432   return true;
433 }
434 
RemoveClient(const SocketAddress & int_ip)435 void NATSocketServer::Translator::RemoveClient(
436     const SocketAddress& int_ip) {
437   std::set<SocketAddress>::iterator it = clients_.find(int_ip);
438   if (it != clients_.end()) {
439     clients_.erase(it);
440   }
441 }
442 
FindClient(const SocketAddress & int_ip)443 NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
444     const SocketAddress& int_ip) {
445   // See if we have the requested IP, or any of our children do.
446   return (clients_.find(int_ip) != clients_.end()) ?
447       this : nats_.FindClient(int_ip);
448 }
449 
450 // NATSocketServer::TranslatorMap
~TranslatorMap()451 NATSocketServer::TranslatorMap::~TranslatorMap() {
452   for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
453     delete it->second;
454   }
455 }
456 
Get(const SocketAddress & ext_ip)457 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
458     const SocketAddress& ext_ip) {
459   TranslatorMap::iterator it = find(ext_ip);
460   return (it != end()) ? it->second : NULL;
461 }
462 
Add(const SocketAddress & ext_ip,Translator * nat)463 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
464     const SocketAddress& ext_ip, Translator* nat) {
465   (*this)[ext_ip] = nat;
466   return nat;
467 }
468 
Remove(const SocketAddress & ext_ip)469 void NATSocketServer::TranslatorMap::Remove(
470     const SocketAddress& ext_ip) {
471   TranslatorMap::iterator it = find(ext_ip);
472   if (it != end()) {
473     delete it->second;
474     erase(it);
475   }
476 }
477 
FindClient(const SocketAddress & int_ip)478 NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
479     const SocketAddress& int_ip) {
480   Translator* nat = NULL;
481   for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
482     nat = it->second->FindClient(int_ip);
483   }
484   return nat;
485 }
486 
487 }  // namespace rtc
488