1 /*
2 * Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "webrtc/base/natsocketfactory.h"
12
13 #include "webrtc/base/logging.h"
14 #include "webrtc/base/natserver.h"
15 #include "webrtc/base/virtualsocketserver.h"
16
17 namespace rtc {
18
19 // Packs the given socketaddress into the buffer in buf, in the quasi-STUN
20 // format that the natserver uses.
21 // Returns 0 if an invalid address is passed.
PackAddressForNAT(char * buf,size_t buf_size,const SocketAddress & remote_addr)22 size_t PackAddressForNAT(char* buf, size_t buf_size,
23 const SocketAddress& remote_addr) {
24 const IPAddress& ip = remote_addr.ipaddr();
25 int family = ip.family();
26 buf[0] = 0;
27 buf[1] = family;
28 // Writes the port.
29 *(reinterpret_cast<uint16*>(&buf[2])) = HostToNetwork16(remote_addr.port());
30 if (family == AF_INET) {
31 ASSERT(buf_size >= kNATEncodedIPv4AddressSize);
32 in_addr v4addr = ip.ipv4_address();
33 memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
34 return kNATEncodedIPv4AddressSize;
35 } else if (family == AF_INET6) {
36 ASSERT(buf_size >= kNATEncodedIPv6AddressSize);
37 in6_addr v6addr = ip.ipv6_address();
38 memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
39 return kNATEncodedIPv6AddressSize;
40 }
41 return 0U;
42 }
43
44 // Decodes the remote address from a packet that has been encoded with the nat's
45 // quasi-STUN format. Returns the length of the address (i.e., the offset into
46 // data where the original packet starts).
UnpackAddressFromNAT(const char * buf,size_t buf_size,SocketAddress * remote_addr)47 size_t UnpackAddressFromNAT(const char* buf, size_t buf_size,
48 SocketAddress* remote_addr) {
49 ASSERT(buf_size >= 8);
50 ASSERT(buf[0] == 0);
51 int family = buf[1];
52 uint16 port = NetworkToHost16(*(reinterpret_cast<const uint16*>(&buf[2])));
53 if (family == AF_INET) {
54 const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
55 *remote_addr = SocketAddress(IPAddress(*v4addr), port);
56 return kNATEncodedIPv4AddressSize;
57 } else if (family == AF_INET6) {
58 ASSERT(buf_size >= 20);
59 const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
60 *remote_addr = SocketAddress(IPAddress(*v6addr), port);
61 return kNATEncodedIPv6AddressSize;
62 }
63 return 0U;
64 }
65
66
67 // NATSocket
68 class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
69 public:
NATSocket(NATInternalSocketFactory * sf,int family,int type)70 explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
71 : sf_(sf), family_(family), type_(type), connected_(false),
72 socket_(NULL), buf_(NULL), size_(0) {
73 }
74
~NATSocket()75 virtual ~NATSocket() {
76 delete socket_;
77 delete[] buf_;
78 }
79
GetLocalAddress() const80 virtual SocketAddress GetLocalAddress() const {
81 return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
82 }
83
GetRemoteAddress() const84 virtual SocketAddress GetRemoteAddress() const {
85 return remote_addr_; // will be NIL if not connected
86 }
87
Bind(const SocketAddress & addr)88 virtual int Bind(const SocketAddress& addr) {
89 if (socket_) { // already bound, bubble up error
90 return -1;
91 }
92
93 int result;
94 socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
95 result = (socket_) ? socket_->Bind(addr) : -1;
96 if (result >= 0) {
97 socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
98 socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
99 socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
100 socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
101 } else {
102 server_addr_.Clear();
103 delete socket_;
104 socket_ = NULL;
105 }
106
107 return result;
108 }
109
Connect(const SocketAddress & addr)110 virtual int Connect(const SocketAddress& addr) {
111 if (!socket_) { // socket must be bound, for now
112 return -1;
113 }
114
115 int result = 0;
116 if (type_ == SOCK_STREAM) {
117 result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
118 } else {
119 connected_ = true;
120 }
121
122 if (result >= 0) {
123 remote_addr_ = addr;
124 }
125
126 return result;
127 }
128
Send(const void * data,size_t size)129 virtual int Send(const void* data, size_t size) {
130 ASSERT(connected_);
131 return SendTo(data, size, remote_addr_);
132 }
133
SendTo(const void * data,size_t size,const SocketAddress & addr)134 virtual int SendTo(const void* data, size_t size, const SocketAddress& addr) {
135 ASSERT(!connected_ || addr == remote_addr_);
136 if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
137 return socket_->SendTo(data, size, addr);
138 }
139 // This array will be too large for IPv4 packets, but only by 12 bytes.
140 scoped_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
141 size_t addrlength = PackAddressForNAT(buf.get(),
142 size + kNATEncodedIPv6AddressSize,
143 addr);
144 size_t encoded_size = size + addrlength;
145 memcpy(buf.get() + addrlength, data, size);
146 int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
147 if (result >= 0) {
148 ASSERT(result == static_cast<int>(encoded_size));
149 result = result - static_cast<int>(addrlength);
150 }
151 return result;
152 }
153
Recv(void * data,size_t size)154 virtual int Recv(void* data, size_t size) {
155 SocketAddress addr;
156 return RecvFrom(data, size, &addr);
157 }
158
RecvFrom(void * data,size_t size,SocketAddress * out_addr)159 virtual int RecvFrom(void* data, size_t size, SocketAddress *out_addr) {
160 if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
161 return socket_->RecvFrom(data, size, out_addr);
162 }
163 // Make sure we have enough room to read the requested amount plus the
164 // largest possible header address.
165 SocketAddress remote_addr;
166 Grow(size + kNATEncodedIPv6AddressSize);
167
168 // Read the packet from the socket.
169 int result = socket_->RecvFrom(buf_, size_, &remote_addr);
170 if (result >= 0) {
171 ASSERT(remote_addr == server_addr_);
172
173 // TODO: we need better framing so we know how many bytes we can
174 // return before we need to read the next address. For UDP, this will be
175 // fine as long as the reader always reads everything in the packet.
176 ASSERT((size_t)result < size_);
177
178 // Decode the wire packet into the actual results.
179 SocketAddress real_remote_addr;
180 size_t addrlength =
181 UnpackAddressFromNAT(buf_, result, &real_remote_addr);
182 memcpy(data, buf_ + addrlength, result - addrlength);
183
184 // Make sure this packet should be delivered before returning it.
185 if (!connected_ || (real_remote_addr == remote_addr_)) {
186 if (out_addr)
187 *out_addr = real_remote_addr;
188 result = result - static_cast<int>(addrlength);
189 } else {
190 LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
191 << real_remote_addr.ToString();
192 result = 0; // Tell the caller we didn't read anything
193 }
194 }
195
196 return result;
197 }
198
Close()199 virtual int Close() {
200 int result = 0;
201 if (socket_) {
202 result = socket_->Close();
203 if (result >= 0) {
204 connected_ = false;
205 remote_addr_ = SocketAddress();
206 delete socket_;
207 socket_ = NULL;
208 }
209 }
210 return result;
211 }
212
Listen(int backlog)213 virtual int Listen(int backlog) {
214 return socket_->Listen(backlog);
215 }
Accept(SocketAddress * paddr)216 virtual AsyncSocket* Accept(SocketAddress *paddr) {
217 return socket_->Accept(paddr);
218 }
GetError() const219 virtual int GetError() const {
220 return socket_->GetError();
221 }
SetError(int error)222 virtual void SetError(int error) {
223 socket_->SetError(error);
224 }
GetState() const225 virtual ConnState GetState() const {
226 return connected_ ? CS_CONNECTED : CS_CLOSED;
227 }
EstimateMTU(uint16 * mtu)228 virtual int EstimateMTU(uint16* mtu) {
229 return socket_->EstimateMTU(mtu);
230 }
GetOption(Option opt,int * value)231 virtual int GetOption(Option opt, int* value) {
232 return socket_->GetOption(opt, value);
233 }
SetOption(Option opt,int value)234 virtual int SetOption(Option opt, int value) {
235 return socket_->SetOption(opt, value);
236 }
237
OnConnectEvent(AsyncSocket * socket)238 void OnConnectEvent(AsyncSocket* socket) {
239 // If we're NATed, we need to send a request with the real addr to use.
240 ASSERT(socket == socket_);
241 if (server_addr_.IsNil()) {
242 connected_ = true;
243 SignalConnectEvent(this);
244 } else {
245 SendConnectRequest();
246 }
247 }
OnReadEvent(AsyncSocket * socket)248 void OnReadEvent(AsyncSocket* socket) {
249 // If we're NATed, we need to process the connect reply.
250 ASSERT(socket == socket_);
251 if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
252 HandleConnectReply();
253 } else {
254 SignalReadEvent(this);
255 }
256 }
OnWriteEvent(AsyncSocket * socket)257 void OnWriteEvent(AsyncSocket* socket) {
258 ASSERT(socket == socket_);
259 SignalWriteEvent(this);
260 }
OnCloseEvent(AsyncSocket * socket,int error)261 void OnCloseEvent(AsyncSocket* socket, int error) {
262 ASSERT(socket == socket_);
263 SignalCloseEvent(this, error);
264 }
265
266 private:
267 // Makes sure the buffer is at least the given size.
Grow(size_t new_size)268 void Grow(size_t new_size) {
269 if (size_ < new_size) {
270 delete[] buf_;
271 size_ = new_size;
272 buf_ = new char[size_];
273 }
274 }
275
276 // Sends the destination address to the server to tell it to connect.
SendConnectRequest()277 void SendConnectRequest() {
278 char buf[256];
279 size_t length = PackAddressForNAT(buf, ARRAY_SIZE(buf), remote_addr_);
280 socket_->Send(buf, length);
281 }
282
283 // Handles the byte sent back from the server and fires the appropriate event.
HandleConnectReply()284 void HandleConnectReply() {
285 char code;
286 socket_->Recv(&code, sizeof(code));
287 if (code == 0) {
288 SignalConnectEvent(this);
289 } else {
290 Close();
291 SignalCloseEvent(this, code);
292 }
293 }
294
295 NATInternalSocketFactory* sf_;
296 int family_;
297 int type_;
298 bool connected_;
299 SocketAddress remote_addr_;
300 SocketAddress server_addr_; // address of the NAT server
301 AsyncSocket* socket_;
302 char* buf_;
303 size_t size_;
304 };
305
306 // NATSocketFactory
NATSocketFactory(SocketFactory * factory,const SocketAddress & nat_addr)307 NATSocketFactory::NATSocketFactory(SocketFactory* factory,
308 const SocketAddress& nat_addr)
309 : factory_(factory), nat_addr_(nat_addr) {
310 }
311
CreateSocket(int type)312 Socket* NATSocketFactory::CreateSocket(int type) {
313 return CreateSocket(AF_INET, type);
314 }
315
CreateSocket(int family,int type)316 Socket* NATSocketFactory::CreateSocket(int family, int type) {
317 return new NATSocket(this, family, type);
318 }
319
CreateAsyncSocket(int type)320 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
321 return CreateAsyncSocket(AF_INET, type);
322 }
323
CreateAsyncSocket(int family,int type)324 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
325 return new NATSocket(this, family, type);
326 }
327
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)328 AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type,
329 const SocketAddress& local_addr, SocketAddress* nat_addr) {
330 *nat_addr = nat_addr_;
331 return factory_->CreateAsyncSocket(family, type);
332 }
333
334 // NATSocketServer
NATSocketServer(SocketServer * server)335 NATSocketServer::NATSocketServer(SocketServer* server)
336 : server_(server), msg_queue_(NULL) {
337 }
338
GetTranslator(const SocketAddress & ext_ip)339 NATSocketServer::Translator* NATSocketServer::GetTranslator(
340 const SocketAddress& ext_ip) {
341 return nats_.Get(ext_ip);
342 }
343
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)344 NATSocketServer::Translator* NATSocketServer::AddTranslator(
345 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
346 // Fail if a translator already exists with this extternal address.
347 if (nats_.Get(ext_ip))
348 return NULL;
349
350 return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
351 }
352
RemoveTranslator(const SocketAddress & ext_ip)353 void NATSocketServer::RemoveTranslator(
354 const SocketAddress& ext_ip) {
355 nats_.Remove(ext_ip);
356 }
357
CreateSocket(int type)358 Socket* NATSocketServer::CreateSocket(int type) {
359 return CreateSocket(AF_INET, type);
360 }
361
CreateSocket(int family,int type)362 Socket* NATSocketServer::CreateSocket(int family, int type) {
363 return new NATSocket(this, family, type);
364 }
365
CreateAsyncSocket(int type)366 AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) {
367 return CreateAsyncSocket(AF_INET, type);
368 }
369
CreateAsyncSocket(int family,int type)370 AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
371 return new NATSocket(this, family, type);
372 }
373
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)374 AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type,
375 const SocketAddress& local_addr, SocketAddress* nat_addr) {
376 AsyncSocket* socket = NULL;
377 Translator* nat = nats_.FindClient(local_addr);
378 if (nat) {
379 socket = nat->internal_factory()->CreateAsyncSocket(family, type);
380 *nat_addr = (type == SOCK_STREAM) ?
381 nat->internal_tcp_address() : nat->internal_address();
382 } else {
383 socket = server_->CreateAsyncSocket(family, type);
384 }
385 return socket;
386 }
387
388 // NATSocketServer::Translator
Translator(NATSocketServer * server,NATType type,const SocketAddress & int_ip,SocketFactory * ext_factory,const SocketAddress & ext_ip)389 NATSocketServer::Translator::Translator(
390 NATSocketServer* server, NATType type, const SocketAddress& int_ip,
391 SocketFactory* ext_factory, const SocketAddress& ext_ip)
392 : server_(server) {
393 // Create a new private network, and a NATServer running on the private
394 // network that bridges to the external network. Also tell the private
395 // network to use the same message queue as us.
396 VirtualSocketServer* internal_server = new VirtualSocketServer(server_);
397 internal_server->SetMessageQueue(server_->queue());
398 internal_factory_.reset(internal_server);
399 nat_server_.reset(new NATServer(type, internal_server, int_ip,
400 ext_factory, ext_ip));
401 }
402
403
GetTranslator(const SocketAddress & ext_ip)404 NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
405 const SocketAddress& ext_ip) {
406 return nats_.Get(ext_ip);
407 }
408
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)409 NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
410 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
411 // Fail if a translator already exists with this extternal address.
412 if (nats_.Get(ext_ip))
413 return NULL;
414
415 AddClient(ext_ip);
416 return nats_.Add(ext_ip,
417 new Translator(server_, type, int_ip, server_, ext_ip));
418 }
RemoveTranslator(const SocketAddress & ext_ip)419 void NATSocketServer::Translator::RemoveTranslator(
420 const SocketAddress& ext_ip) {
421 nats_.Remove(ext_ip);
422 RemoveClient(ext_ip);
423 }
424
AddClient(const SocketAddress & int_ip)425 bool NATSocketServer::Translator::AddClient(
426 const SocketAddress& int_ip) {
427 // Fail if a client already exists with this internal address.
428 if (clients_.find(int_ip) != clients_.end())
429 return false;
430
431 clients_.insert(int_ip);
432 return true;
433 }
434
RemoveClient(const SocketAddress & int_ip)435 void NATSocketServer::Translator::RemoveClient(
436 const SocketAddress& int_ip) {
437 std::set<SocketAddress>::iterator it = clients_.find(int_ip);
438 if (it != clients_.end()) {
439 clients_.erase(it);
440 }
441 }
442
FindClient(const SocketAddress & int_ip)443 NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
444 const SocketAddress& int_ip) {
445 // See if we have the requested IP, or any of our children do.
446 return (clients_.find(int_ip) != clients_.end()) ?
447 this : nats_.FindClient(int_ip);
448 }
449
450 // NATSocketServer::TranslatorMap
~TranslatorMap()451 NATSocketServer::TranslatorMap::~TranslatorMap() {
452 for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
453 delete it->second;
454 }
455 }
456
Get(const SocketAddress & ext_ip)457 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
458 const SocketAddress& ext_ip) {
459 TranslatorMap::iterator it = find(ext_ip);
460 return (it != end()) ? it->second : NULL;
461 }
462
Add(const SocketAddress & ext_ip,Translator * nat)463 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
464 const SocketAddress& ext_ip, Translator* nat) {
465 (*this)[ext_ip] = nat;
466 return nat;
467 }
468
Remove(const SocketAddress & ext_ip)469 void NATSocketServer::TranslatorMap::Remove(
470 const SocketAddress& ext_ip) {
471 TranslatorMap::iterator it = find(ext_ip);
472 if (it != end()) {
473 delete it->second;
474 erase(it);
475 }
476 }
477
FindClient(const SocketAddress & int_ip)478 NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
479 const SocketAddress& int_ip) {
480 Translator* nat = NULL;
481 for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
482 nat = it->second->FindClient(int_ip);
483 }
484 return nat;
485 }
486
487 } // namespace rtc
488