1 /*
2 * Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "webrtc/base/natsocketfactory.h"
12
13 #include "webrtc/base/arraysize.h"
14 #include "webrtc/base/logging.h"
15 #include "webrtc/base/natserver.h"
16 #include "webrtc/base/virtualsocketserver.h"
17
18 namespace rtc {
19
20 // Packs the given socketaddress into the buffer in buf, in the quasi-STUN
21 // format that the natserver uses.
22 // Returns 0 if an invalid address is passed.
PackAddressForNAT(char * buf,size_t buf_size,const SocketAddress & remote_addr)23 size_t PackAddressForNAT(char* buf, size_t buf_size,
24 const SocketAddress& remote_addr) {
25 const IPAddress& ip = remote_addr.ipaddr();
26 int family = ip.family();
27 buf[0] = 0;
28 buf[1] = family;
29 // Writes the port.
30 *(reinterpret_cast<uint16_t*>(&buf[2])) = HostToNetwork16(remote_addr.port());
31 if (family == AF_INET) {
32 ASSERT(buf_size >= kNATEncodedIPv4AddressSize);
33 in_addr v4addr = ip.ipv4_address();
34 memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
35 return kNATEncodedIPv4AddressSize;
36 } else if (family == AF_INET6) {
37 ASSERT(buf_size >= kNATEncodedIPv6AddressSize);
38 in6_addr v6addr = ip.ipv6_address();
39 memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
40 return kNATEncodedIPv6AddressSize;
41 }
42 return 0U;
43 }
44
45 // Decodes the remote address from a packet that has been encoded with the nat's
46 // quasi-STUN format. Returns the length of the address (i.e., the offset into
47 // data where the original packet starts).
UnpackAddressFromNAT(const char * buf,size_t buf_size,SocketAddress * remote_addr)48 size_t UnpackAddressFromNAT(const char* buf, size_t buf_size,
49 SocketAddress* remote_addr) {
50 ASSERT(buf_size >= 8);
51 ASSERT(buf[0] == 0);
52 int family = buf[1];
53 uint16_t port =
54 NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2])));
55 if (family == AF_INET) {
56 const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
57 *remote_addr = SocketAddress(IPAddress(*v4addr), port);
58 return kNATEncodedIPv4AddressSize;
59 } else if (family == AF_INET6) {
60 ASSERT(buf_size >= 20);
61 const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
62 *remote_addr = SocketAddress(IPAddress(*v6addr), port);
63 return kNATEncodedIPv6AddressSize;
64 }
65 return 0U;
66 }
67
68
69 // NATSocket
70 class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
71 public:
NATSocket(NATInternalSocketFactory * sf,int family,int type)72 explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
73 : sf_(sf), family_(family), type_(type), connected_(false),
74 socket_(NULL), buf_(NULL), size_(0) {
75 }
76
~NATSocket()77 ~NATSocket() override {
78 delete socket_;
79 delete[] buf_;
80 }
81
GetLocalAddress() const82 SocketAddress GetLocalAddress() const override {
83 return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
84 }
85
GetRemoteAddress() const86 SocketAddress GetRemoteAddress() const override {
87 return remote_addr_; // will be NIL if not connected
88 }
89
Bind(const SocketAddress & addr)90 int Bind(const SocketAddress& addr) override {
91 if (socket_) { // already bound, bubble up error
92 return -1;
93 }
94
95 int result;
96 socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
97 result = (socket_) ? socket_->Bind(addr) : -1;
98 if (result >= 0) {
99 socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
100 socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
101 socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
102 socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
103 } else {
104 server_addr_.Clear();
105 delete socket_;
106 socket_ = NULL;
107 }
108
109 return result;
110 }
111
Connect(const SocketAddress & addr)112 int Connect(const SocketAddress& addr) override {
113 if (!socket_) { // socket must be bound, for now
114 return -1;
115 }
116
117 int result = 0;
118 if (type_ == SOCK_STREAM) {
119 result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
120 } else {
121 connected_ = true;
122 }
123
124 if (result >= 0) {
125 remote_addr_ = addr;
126 }
127
128 return result;
129 }
130
Send(const void * data,size_t size)131 int Send(const void* data, size_t size) override {
132 ASSERT(connected_);
133 return SendTo(data, size, remote_addr_);
134 }
135
SendTo(const void * data,size_t size,const SocketAddress & addr)136 int SendTo(const void* data,
137 size_t size,
138 const SocketAddress& addr) override {
139 ASSERT(!connected_ || addr == remote_addr_);
140 if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
141 return socket_->SendTo(data, size, addr);
142 }
143 // This array will be too large for IPv4 packets, but only by 12 bytes.
144 scoped_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
145 size_t addrlength = PackAddressForNAT(buf.get(),
146 size + kNATEncodedIPv6AddressSize,
147 addr);
148 size_t encoded_size = size + addrlength;
149 memcpy(buf.get() + addrlength, data, size);
150 int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
151 if (result >= 0) {
152 ASSERT(result == static_cast<int>(encoded_size));
153 result = result - static_cast<int>(addrlength);
154 }
155 return result;
156 }
157
Recv(void * data,size_t size)158 int Recv(void* data, size_t size) override {
159 SocketAddress addr;
160 return RecvFrom(data, size, &addr);
161 }
162
RecvFrom(void * data,size_t size,SocketAddress * out_addr)163 int RecvFrom(void* data, size_t size, SocketAddress* out_addr) override {
164 if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
165 return socket_->RecvFrom(data, size, out_addr);
166 }
167 // Make sure we have enough room to read the requested amount plus the
168 // largest possible header address.
169 SocketAddress remote_addr;
170 Grow(size + kNATEncodedIPv6AddressSize);
171
172 // Read the packet from the socket.
173 int result = socket_->RecvFrom(buf_, size_, &remote_addr);
174 if (result >= 0) {
175 ASSERT(remote_addr == server_addr_);
176
177 // TODO: we need better framing so we know how many bytes we can
178 // return before we need to read the next address. For UDP, this will be
179 // fine as long as the reader always reads everything in the packet.
180 ASSERT((size_t)result < size_);
181
182 // Decode the wire packet into the actual results.
183 SocketAddress real_remote_addr;
184 size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr);
185 memcpy(data, buf_ + addrlength, result - addrlength);
186
187 // Make sure this packet should be delivered before returning it.
188 if (!connected_ || (real_remote_addr == remote_addr_)) {
189 if (out_addr)
190 *out_addr = real_remote_addr;
191 result = result - static_cast<int>(addrlength);
192 } else {
193 LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
194 << real_remote_addr.ToString();
195 result = 0; // Tell the caller we didn't read anything
196 }
197 }
198
199 return result;
200 }
201
Close()202 int Close() override {
203 int result = 0;
204 if (socket_) {
205 result = socket_->Close();
206 if (result >= 0) {
207 connected_ = false;
208 remote_addr_ = SocketAddress();
209 delete socket_;
210 socket_ = NULL;
211 }
212 }
213 return result;
214 }
215
Listen(int backlog)216 int Listen(int backlog) override { return socket_->Listen(backlog); }
Accept(SocketAddress * paddr)217 AsyncSocket* Accept(SocketAddress* paddr) override {
218 return socket_->Accept(paddr);
219 }
GetError() const220 int GetError() const override { return socket_->GetError(); }
SetError(int error)221 void SetError(int error) override { socket_->SetError(error); }
GetState() const222 ConnState GetState() const override {
223 return connected_ ? CS_CONNECTED : CS_CLOSED;
224 }
EstimateMTU(uint16_t * mtu)225 int EstimateMTU(uint16_t* mtu) override { return socket_->EstimateMTU(mtu); }
GetOption(Option opt,int * value)226 int GetOption(Option opt, int* value) override {
227 return socket_->GetOption(opt, value);
228 }
SetOption(Option opt,int value)229 int SetOption(Option opt, int value) override {
230 return socket_->SetOption(opt, value);
231 }
232
OnConnectEvent(AsyncSocket * socket)233 void OnConnectEvent(AsyncSocket* socket) {
234 // If we're NATed, we need to send a message with the real addr to use.
235 ASSERT(socket == socket_);
236 if (server_addr_.IsNil()) {
237 connected_ = true;
238 SignalConnectEvent(this);
239 } else {
240 SendConnectRequest();
241 }
242 }
OnReadEvent(AsyncSocket * socket)243 void OnReadEvent(AsyncSocket* socket) {
244 // If we're NATed, we need to process the connect reply.
245 ASSERT(socket == socket_);
246 if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
247 HandleConnectReply();
248 } else {
249 SignalReadEvent(this);
250 }
251 }
OnWriteEvent(AsyncSocket * socket)252 void OnWriteEvent(AsyncSocket* socket) {
253 ASSERT(socket == socket_);
254 SignalWriteEvent(this);
255 }
OnCloseEvent(AsyncSocket * socket,int error)256 void OnCloseEvent(AsyncSocket* socket, int error) {
257 ASSERT(socket == socket_);
258 SignalCloseEvent(this, error);
259 }
260
261 private:
262 // Makes sure the buffer is at least the given size.
Grow(size_t new_size)263 void Grow(size_t new_size) {
264 if (size_ < new_size) {
265 delete[] buf_;
266 size_ = new_size;
267 buf_ = new char[size_];
268 }
269 }
270
271 // Sends the destination address to the server to tell it to connect.
SendConnectRequest()272 void SendConnectRequest() {
273 char buf[kNATEncodedIPv6AddressSize];
274 size_t length = PackAddressForNAT(buf, arraysize(buf), remote_addr_);
275 socket_->Send(buf, length);
276 }
277
278 // Handles the byte sent back from the server and fires the appropriate event.
HandleConnectReply()279 void HandleConnectReply() {
280 char code;
281 socket_->Recv(&code, sizeof(code));
282 if (code == 0) {
283 connected_ = true;
284 SignalConnectEvent(this);
285 } else {
286 Close();
287 SignalCloseEvent(this, code);
288 }
289 }
290
291 NATInternalSocketFactory* sf_;
292 int family_;
293 int type_;
294 bool connected_;
295 SocketAddress remote_addr_;
296 SocketAddress server_addr_; // address of the NAT server
297 AsyncSocket* socket_;
298 char* buf_;
299 size_t size_;
300 };
301
302 // NATSocketFactory
NATSocketFactory(SocketFactory * factory,const SocketAddress & nat_udp_addr,const SocketAddress & nat_tcp_addr)303 NATSocketFactory::NATSocketFactory(SocketFactory* factory,
304 const SocketAddress& nat_udp_addr,
305 const SocketAddress& nat_tcp_addr)
306 : factory_(factory), nat_udp_addr_(nat_udp_addr),
307 nat_tcp_addr_(nat_tcp_addr) {
308 }
309
CreateSocket(int type)310 Socket* NATSocketFactory::CreateSocket(int type) {
311 return CreateSocket(AF_INET, type);
312 }
313
CreateSocket(int family,int type)314 Socket* NATSocketFactory::CreateSocket(int family, int type) {
315 return new NATSocket(this, family, type);
316 }
317
CreateAsyncSocket(int type)318 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
319 return CreateAsyncSocket(AF_INET, type);
320 }
321
CreateAsyncSocket(int family,int type)322 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
323 return new NATSocket(this, family, type);
324 }
325
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)326 AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type,
327 const SocketAddress& local_addr, SocketAddress* nat_addr) {
328 if (type == SOCK_STREAM) {
329 *nat_addr = nat_tcp_addr_;
330 } else {
331 *nat_addr = nat_udp_addr_;
332 }
333 return factory_->CreateAsyncSocket(family, type);
334 }
335
336 // NATSocketServer
NATSocketServer(SocketServer * server)337 NATSocketServer::NATSocketServer(SocketServer* server)
338 : server_(server), msg_queue_(NULL) {
339 }
340
GetTranslator(const SocketAddress & ext_ip)341 NATSocketServer::Translator* NATSocketServer::GetTranslator(
342 const SocketAddress& ext_ip) {
343 return nats_.Get(ext_ip);
344 }
345
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)346 NATSocketServer::Translator* NATSocketServer::AddTranslator(
347 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
348 // Fail if a translator already exists with this extternal address.
349 if (nats_.Get(ext_ip))
350 return NULL;
351
352 return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
353 }
354
RemoveTranslator(const SocketAddress & ext_ip)355 void NATSocketServer::RemoveTranslator(
356 const SocketAddress& ext_ip) {
357 nats_.Remove(ext_ip);
358 }
359
CreateSocket(int type)360 Socket* NATSocketServer::CreateSocket(int type) {
361 return CreateSocket(AF_INET, type);
362 }
363
CreateSocket(int family,int type)364 Socket* NATSocketServer::CreateSocket(int family, int type) {
365 return new NATSocket(this, family, type);
366 }
367
CreateAsyncSocket(int type)368 AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) {
369 return CreateAsyncSocket(AF_INET, type);
370 }
371
CreateAsyncSocket(int family,int type)372 AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
373 return new NATSocket(this, family, type);
374 }
375
SetMessageQueue(MessageQueue * queue)376 void NATSocketServer::SetMessageQueue(MessageQueue* queue) {
377 msg_queue_ = queue;
378 server_->SetMessageQueue(queue);
379 }
380
Wait(int cms,bool process_io)381 bool NATSocketServer::Wait(int cms, bool process_io) {
382 return server_->Wait(cms, process_io);
383 }
384
WakeUp()385 void NATSocketServer::WakeUp() {
386 server_->WakeUp();
387 }
388
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)389 AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type,
390 const SocketAddress& local_addr, SocketAddress* nat_addr) {
391 AsyncSocket* socket = NULL;
392 Translator* nat = nats_.FindClient(local_addr);
393 if (nat) {
394 socket = nat->internal_factory()->CreateAsyncSocket(family, type);
395 *nat_addr = (type == SOCK_STREAM) ?
396 nat->internal_tcp_address() : nat->internal_udp_address();
397 } else {
398 socket = server_->CreateAsyncSocket(family, type);
399 }
400 return socket;
401 }
402
403 // NATSocketServer::Translator
Translator(NATSocketServer * server,NATType type,const SocketAddress & int_ip,SocketFactory * ext_factory,const SocketAddress & ext_ip)404 NATSocketServer::Translator::Translator(
405 NATSocketServer* server, NATType type, const SocketAddress& int_ip,
406 SocketFactory* ext_factory, const SocketAddress& ext_ip)
407 : server_(server) {
408 // Create a new private network, and a NATServer running on the private
409 // network that bridges to the external network. Also tell the private
410 // network to use the same message queue as us.
411 VirtualSocketServer* internal_server = new VirtualSocketServer(server_);
412 internal_server->SetMessageQueue(server_->queue());
413 internal_factory_.reset(internal_server);
414 nat_server_.reset(new NATServer(type, internal_server, int_ip, int_ip,
415 ext_factory, ext_ip));
416 }
417
418 NATSocketServer::Translator::~Translator() = default;
419
GetTranslator(const SocketAddress & ext_ip)420 NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
421 const SocketAddress& ext_ip) {
422 return nats_.Get(ext_ip);
423 }
424
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)425 NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
426 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
427 // Fail if a translator already exists with this extternal address.
428 if (nats_.Get(ext_ip))
429 return NULL;
430
431 AddClient(ext_ip);
432 return nats_.Add(ext_ip,
433 new Translator(server_, type, int_ip, server_, ext_ip));
434 }
RemoveTranslator(const SocketAddress & ext_ip)435 void NATSocketServer::Translator::RemoveTranslator(
436 const SocketAddress& ext_ip) {
437 nats_.Remove(ext_ip);
438 RemoveClient(ext_ip);
439 }
440
AddClient(const SocketAddress & int_ip)441 bool NATSocketServer::Translator::AddClient(
442 const SocketAddress& int_ip) {
443 // Fail if a client already exists with this internal address.
444 if (clients_.find(int_ip) != clients_.end())
445 return false;
446
447 clients_.insert(int_ip);
448 return true;
449 }
450
RemoveClient(const SocketAddress & int_ip)451 void NATSocketServer::Translator::RemoveClient(
452 const SocketAddress& int_ip) {
453 std::set<SocketAddress>::iterator it = clients_.find(int_ip);
454 if (it != clients_.end()) {
455 clients_.erase(it);
456 }
457 }
458
FindClient(const SocketAddress & int_ip)459 NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
460 const SocketAddress& int_ip) {
461 // See if we have the requested IP, or any of our children do.
462 return (clients_.find(int_ip) != clients_.end()) ?
463 this : nats_.FindClient(int_ip);
464 }
465
466 // NATSocketServer::TranslatorMap
~TranslatorMap()467 NATSocketServer::TranslatorMap::~TranslatorMap() {
468 for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
469 delete it->second;
470 }
471 }
472
Get(const SocketAddress & ext_ip)473 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
474 const SocketAddress& ext_ip) {
475 TranslatorMap::iterator it = find(ext_ip);
476 return (it != end()) ? it->second : NULL;
477 }
478
Add(const SocketAddress & ext_ip,Translator * nat)479 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
480 const SocketAddress& ext_ip, Translator* nat) {
481 (*this)[ext_ip] = nat;
482 return nat;
483 }
484
Remove(const SocketAddress & ext_ip)485 void NATSocketServer::TranslatorMap::Remove(
486 const SocketAddress& ext_ip) {
487 TranslatorMap::iterator it = find(ext_ip);
488 if (it != end()) {
489 delete it->second;
490 erase(it);
491 }
492 }
493
FindClient(const SocketAddress & int_ip)494 NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
495 const SocketAddress& int_ip) {
496 Translator* nat = NULL;
497 for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
498 nat = it->second->FindClient(int_ip);
499 }
500 return nat;
501 }
502
503 } // namespace rtc
504