1 /*
2 * Copyright 2015 The WebRTC Project Authors. All rights reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "webrtc/p2p/base/transportcontroller.h"
12
13 #include <algorithm>
14
15 #include "webrtc/base/bind.h"
16 #include "webrtc/base/checks.h"
17 #include "webrtc/base/thread.h"
18 #include "webrtc/p2p/base/dtlstransport.h"
19 #include "webrtc/p2p/base/p2ptransport.h"
20 #include "webrtc/p2p/base/port.h"
21
22 namespace cricket {
23
// Identifiers of the messages that the worker-thread code in this file posts
// to the signaling thread. Each one is dispatched in
// TransportController::OnMessage.
enum {
  MSG_ICECONNECTIONSTATE = 0,  // Payload: TypedMessageData<IceConnectionState>.
  MSG_RECEIVING,               // Payload: TypedMessageData<bool>.
  MSG_ICEGATHERINGSTATE,       // Payload: TypedMessageData<IceGatheringState>.
  MSG_CANDIDATESGATHERED,      // Payload: CandidatesData.
};
30
31 struct CandidatesData : public rtc::MessageData {
CandidatesDatacricket::CandidatesData32 CandidatesData(const std::string& transport_name,
33 const Candidates& candidates)
34 : transport_name(transport_name), candidates(candidates) {}
35
36 std::string transport_name;
37 Candidates candidates;
38 };
39
TransportController(rtc::Thread * signaling_thread,rtc::Thread * worker_thread,PortAllocator * port_allocator)40 TransportController::TransportController(rtc::Thread* signaling_thread,
41 rtc::Thread* worker_thread,
42 PortAllocator* port_allocator)
43 : signaling_thread_(signaling_thread),
44 worker_thread_(worker_thread),
45 port_allocator_(port_allocator) {}
46
~TransportController()47 TransportController::~TransportController() {
48 worker_thread_->Invoke<void>(
49 rtc::Bind(&TransportController::DestroyAllTransports_w, this));
50 signaling_thread_->Clear(this);
51 }
52
SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version)53 bool TransportController::SetSslMaxProtocolVersion(
54 rtc::SSLProtocolVersion version) {
55 return worker_thread_->Invoke<bool>(rtc::Bind(
56 &TransportController::SetSslMaxProtocolVersion_w, this, version));
57 }
58
SetIceConfig(const IceConfig & config)59 void TransportController::SetIceConfig(const IceConfig& config) {
60 worker_thread_->Invoke<void>(
61 rtc::Bind(&TransportController::SetIceConfig_w, this, config));
62 }
63
SetIceRole(IceRole ice_role)64 void TransportController::SetIceRole(IceRole ice_role) {
65 worker_thread_->Invoke<void>(
66 rtc::Bind(&TransportController::SetIceRole_w, this, ice_role));
67 }
68
GetSslRole(const std::string & transport_name,rtc::SSLRole * role)69 bool TransportController::GetSslRole(const std::string& transport_name,
70 rtc::SSLRole* role) {
71 return worker_thread_->Invoke<bool>(rtc::Bind(
72 &TransportController::GetSslRole_w, this, transport_name, role));
73 }
74
SetLocalCertificate(const rtc::scoped_refptr<rtc::RTCCertificate> & certificate)75 bool TransportController::SetLocalCertificate(
76 const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
77 return worker_thread_->Invoke<bool>(rtc::Bind(
78 &TransportController::SetLocalCertificate_w, this, certificate));
79 }
80
GetLocalCertificate(const std::string & transport_name,rtc::scoped_refptr<rtc::RTCCertificate> * certificate)81 bool TransportController::GetLocalCertificate(
82 const std::string& transport_name,
83 rtc::scoped_refptr<rtc::RTCCertificate>* certificate) {
84 return worker_thread_->Invoke<bool>(
85 rtc::Bind(&TransportController::GetLocalCertificate_w, this,
86 transport_name, certificate));
87 }
88
GetRemoteSSLCertificate(const std::string & transport_name,rtc::SSLCertificate ** cert)89 bool TransportController::GetRemoteSSLCertificate(
90 const std::string& transport_name,
91 rtc::SSLCertificate** cert) {
92 return worker_thread_->Invoke<bool>(
93 rtc::Bind(&TransportController::GetRemoteSSLCertificate_w, this,
94 transport_name, cert));
95 }
96
SetLocalTransportDescription(const std::string & transport_name,const TransportDescription & tdesc,ContentAction action,std::string * err)97 bool TransportController::SetLocalTransportDescription(
98 const std::string& transport_name,
99 const TransportDescription& tdesc,
100 ContentAction action,
101 std::string* err) {
102 return worker_thread_->Invoke<bool>(
103 rtc::Bind(&TransportController::SetLocalTransportDescription_w, this,
104 transport_name, tdesc, action, err));
105 }
106
SetRemoteTransportDescription(const std::string & transport_name,const TransportDescription & tdesc,ContentAction action,std::string * err)107 bool TransportController::SetRemoteTransportDescription(
108 const std::string& transport_name,
109 const TransportDescription& tdesc,
110 ContentAction action,
111 std::string* err) {
112 return worker_thread_->Invoke<bool>(
113 rtc::Bind(&TransportController::SetRemoteTransportDescription_w, this,
114 transport_name, tdesc, action, err));
115 }
116
MaybeStartGathering()117 void TransportController::MaybeStartGathering() {
118 worker_thread_->Invoke<void>(
119 rtc::Bind(&TransportController::MaybeStartGathering_w, this));
120 }
121
AddRemoteCandidates(const std::string & transport_name,const Candidates & candidates,std::string * err)122 bool TransportController::AddRemoteCandidates(const std::string& transport_name,
123 const Candidates& candidates,
124 std::string* err) {
125 return worker_thread_->Invoke<bool>(
126 rtc::Bind(&TransportController::AddRemoteCandidates_w, this,
127 transport_name, candidates, err));
128 }
129
ReadyForRemoteCandidates(const std::string & transport_name)130 bool TransportController::ReadyForRemoteCandidates(
131 const std::string& transport_name) {
132 return worker_thread_->Invoke<bool>(rtc::Bind(
133 &TransportController::ReadyForRemoteCandidates_w, this, transport_name));
134 }
135
GetStats(const std::string & transport_name,TransportStats * stats)136 bool TransportController::GetStats(const std::string& transport_name,
137 TransportStats* stats) {
138 return worker_thread_->Invoke<bool>(
139 rtc::Bind(&TransportController::GetStats_w, this, transport_name, stats));
140 }
141
CreateTransportChannel_w(const std::string & transport_name,int component)142 TransportChannel* TransportController::CreateTransportChannel_w(
143 const std::string& transport_name,
144 int component) {
145 RTC_DCHECK(worker_thread_->IsCurrent());
146
147 auto it = FindChannel_w(transport_name, component);
148 if (it != channels_.end()) {
149 // Channel already exists; increment reference count and return.
150 it->AddRef();
151 return it->get();
152 }
153
154 // Need to create a new channel.
155 Transport* transport = GetOrCreateTransport_w(transport_name);
156 TransportChannelImpl* channel = transport->CreateChannel(component);
157 channel->SignalWritableState.connect(
158 this, &TransportController::OnChannelWritableState_w);
159 channel->SignalReceivingState.connect(
160 this, &TransportController::OnChannelReceivingState_w);
161 channel->SignalGatheringState.connect(
162 this, &TransportController::OnChannelGatheringState_w);
163 channel->SignalCandidateGathered.connect(
164 this, &TransportController::OnChannelCandidateGathered_w);
165 channel->SignalRoleConflict.connect(
166 this, &TransportController::OnChannelRoleConflict_w);
167 channel->SignalConnectionRemoved.connect(
168 this, &TransportController::OnChannelConnectionRemoved_w);
169 channels_.insert(channels_.end(), RefCountedChannel(channel))->AddRef();
170 // Adding a channel could cause aggregate state to change.
171 UpdateAggregateStates_w();
172 return channel;
173 }
174
DestroyTransportChannel_w(const std::string & transport_name,int component)175 void TransportController::DestroyTransportChannel_w(
176 const std::string& transport_name,
177 int component) {
178 RTC_DCHECK(worker_thread_->IsCurrent());
179
180 auto it = FindChannel_w(transport_name, component);
181 if (it == channels_.end()) {
182 LOG(LS_WARNING) << "Attempting to delete " << transport_name
183 << " TransportChannel " << component
184 << ", which doesn't exist.";
185 return;
186 }
187
188 it->DecRef();
189 if (it->ref() > 0) {
190 return;
191 }
192
193 channels_.erase(it);
194 Transport* transport = GetTransport_w(transport_name);
195 transport->DestroyChannel(component);
196 // Just as we create a Transport when its first channel is created,
197 // we delete it when its last channel is deleted.
198 if (!transport->HasChannels()) {
199 DestroyTransport_w(transport_name);
200 }
201 // Removing a channel could cause aggregate state to change.
202 UpdateAggregateStates_w();
203 }
204
205 const rtc::scoped_refptr<rtc::RTCCertificate>&
certificate_for_testing()206 TransportController::certificate_for_testing() {
207 return certificate_;
208 }
209
CreateTransport_w(const std::string & transport_name)210 Transport* TransportController::CreateTransport_w(
211 const std::string& transport_name) {
212 RTC_DCHECK(worker_thread_->IsCurrent());
213
214 Transport* transport = new DtlsTransport<P2PTransport>(
215 transport_name, port_allocator(), certificate_);
216 return transport;
217 }
218
GetTransport_w(const std::string & transport_name)219 Transport* TransportController::GetTransport_w(
220 const std::string& transport_name) {
221 RTC_DCHECK(worker_thread_->IsCurrent());
222
223 auto iter = transports_.find(transport_name);
224 return (iter != transports_.end()) ? iter->second : nullptr;
225 }
226
OnMessage(rtc::Message * pmsg)227 void TransportController::OnMessage(rtc::Message* pmsg) {
228 RTC_DCHECK(signaling_thread_->IsCurrent());
229
230 switch (pmsg->message_id) {
231 case MSG_ICECONNECTIONSTATE: {
232 rtc::TypedMessageData<IceConnectionState>* data =
233 static_cast<rtc::TypedMessageData<IceConnectionState>*>(pmsg->pdata);
234 SignalConnectionState(data->data());
235 delete data;
236 break;
237 }
238 case MSG_RECEIVING: {
239 rtc::TypedMessageData<bool>* data =
240 static_cast<rtc::TypedMessageData<bool>*>(pmsg->pdata);
241 SignalReceiving(data->data());
242 delete data;
243 break;
244 }
245 case MSG_ICEGATHERINGSTATE: {
246 rtc::TypedMessageData<IceGatheringState>* data =
247 static_cast<rtc::TypedMessageData<IceGatheringState>*>(pmsg->pdata);
248 SignalGatheringState(data->data());
249 delete data;
250 break;
251 }
252 case MSG_CANDIDATESGATHERED: {
253 CandidatesData* data = static_cast<CandidatesData*>(pmsg->pdata);
254 SignalCandidatesGathered(data->transport_name, data->candidates);
255 delete data;
256 break;
257 }
258 default:
259 ASSERT(false);
260 }
261 }
262
263 std::vector<TransportController::RefCountedChannel>::iterator
FindChannel_w(const std::string & transport_name,int component)264 TransportController::FindChannel_w(const std::string& transport_name,
265 int component) {
266 return std::find_if(
267 channels_.begin(), channels_.end(),
268 [transport_name, component](const RefCountedChannel& channel) {
269 return channel->transport_name() == transport_name &&
270 channel->component() == component;
271 });
272 }
273
GetOrCreateTransport_w(const std::string & transport_name)274 Transport* TransportController::GetOrCreateTransport_w(
275 const std::string& transport_name) {
276 RTC_DCHECK(worker_thread_->IsCurrent());
277
278 Transport* transport = GetTransport_w(transport_name);
279 if (transport) {
280 return transport;
281 }
282
283 transport = CreateTransport_w(transport_name);
284 // The stuff below happens outside of CreateTransport_w so that unit tests
285 // can override CreateTransport_w to return a different type of transport.
286 transport->SetSslMaxProtocolVersion(ssl_max_version_);
287 transport->SetIceConfig(ice_config_);
288 transport->SetIceRole(ice_role_);
289 transport->SetIceTiebreaker(ice_tiebreaker_);
290 if (certificate_) {
291 transport->SetLocalCertificate(certificate_);
292 }
293 transports_[transport_name] = transport;
294
295 return transport;
296 }
297
DestroyTransport_w(const std::string & transport_name)298 void TransportController::DestroyTransport_w(
299 const std::string& transport_name) {
300 RTC_DCHECK(worker_thread_->IsCurrent());
301
302 auto iter = transports_.find(transport_name);
303 if (iter != transports_.end()) {
304 delete iter->second;
305 transports_.erase(transport_name);
306 }
307 }
308
DestroyAllTransports_w()309 void TransportController::DestroyAllTransports_w() {
310 RTC_DCHECK(worker_thread_->IsCurrent());
311
312 for (const auto& kv : transports_) {
313 delete kv.second;
314 }
315 transports_.clear();
316 }
317
SetSslMaxProtocolVersion_w(rtc::SSLProtocolVersion version)318 bool TransportController::SetSslMaxProtocolVersion_w(
319 rtc::SSLProtocolVersion version) {
320 RTC_DCHECK(worker_thread_->IsCurrent());
321
322 // Max SSL version can only be set before transports are created.
323 if (!transports_.empty()) {
324 return false;
325 }
326
327 ssl_max_version_ = version;
328 return true;
329 }
330
SetIceConfig_w(const IceConfig & config)331 void TransportController::SetIceConfig_w(const IceConfig& config) {
332 RTC_DCHECK(worker_thread_->IsCurrent());
333 ice_config_ = config;
334 for (const auto& kv : transports_) {
335 kv.second->SetIceConfig(ice_config_);
336 }
337 }
338
SetIceRole_w(IceRole ice_role)339 void TransportController::SetIceRole_w(IceRole ice_role) {
340 RTC_DCHECK(worker_thread_->IsCurrent());
341 ice_role_ = ice_role;
342 for (const auto& kv : transports_) {
343 kv.second->SetIceRole(ice_role_);
344 }
345 }
346
GetSslRole_w(const std::string & transport_name,rtc::SSLRole * role)347 bool TransportController::GetSslRole_w(const std::string& transport_name,
348 rtc::SSLRole* role) {
349 RTC_DCHECK(worker_thread()->IsCurrent());
350
351 Transport* t = GetTransport_w(transport_name);
352 if (!t) {
353 return false;
354 }
355
356 return t->GetSslRole(role);
357 }
358
SetLocalCertificate_w(const rtc::scoped_refptr<rtc::RTCCertificate> & certificate)359 bool TransportController::SetLocalCertificate_w(
360 const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
361 RTC_DCHECK(worker_thread_->IsCurrent());
362
363 if (certificate_) {
364 return false;
365 }
366 if (!certificate) {
367 return false;
368 }
369 certificate_ = certificate;
370
371 for (const auto& kv : transports_) {
372 kv.second->SetLocalCertificate(certificate_);
373 }
374 return true;
375 }
376
GetLocalCertificate_w(const std::string & transport_name,rtc::scoped_refptr<rtc::RTCCertificate> * certificate)377 bool TransportController::GetLocalCertificate_w(
378 const std::string& transport_name,
379 rtc::scoped_refptr<rtc::RTCCertificate>* certificate) {
380 RTC_DCHECK(worker_thread_->IsCurrent());
381
382 Transport* t = GetTransport_w(transport_name);
383 if (!t) {
384 return false;
385 }
386
387 return t->GetLocalCertificate(certificate);
388 }
389
GetRemoteSSLCertificate_w(const std::string & transport_name,rtc::SSLCertificate ** cert)390 bool TransportController::GetRemoteSSLCertificate_w(
391 const std::string& transport_name,
392 rtc::SSLCertificate** cert) {
393 RTC_DCHECK(worker_thread_->IsCurrent());
394
395 Transport* t = GetTransport_w(transport_name);
396 if (!t) {
397 return false;
398 }
399
400 return t->GetRemoteSSLCertificate(cert);
401 }
402
SetLocalTransportDescription_w(const std::string & transport_name,const TransportDescription & tdesc,ContentAction action,std::string * err)403 bool TransportController::SetLocalTransportDescription_w(
404 const std::string& transport_name,
405 const TransportDescription& tdesc,
406 ContentAction action,
407 std::string* err) {
408 RTC_DCHECK(worker_thread()->IsCurrent());
409
410 Transport* transport = GetTransport_w(transport_name);
411 if (!transport) {
412 // If we didn't find a transport, that's not an error;
413 // it could have been deleted as a result of bundling.
414 // TODO(deadbeef): Make callers smarter so they won't attempt to set a
415 // description on a deleted transport.
416 return true;
417 }
418
419 return transport->SetLocalTransportDescription(tdesc, action, err);
420 }
421
SetRemoteTransportDescription_w(const std::string & transport_name,const TransportDescription & tdesc,ContentAction action,std::string * err)422 bool TransportController::SetRemoteTransportDescription_w(
423 const std::string& transport_name,
424 const TransportDescription& tdesc,
425 ContentAction action,
426 std::string* err) {
427 RTC_DCHECK(worker_thread()->IsCurrent());
428
429 Transport* transport = GetTransport_w(transport_name);
430 if (!transport) {
431 // If we didn't find a transport, that's not an error;
432 // it could have been deleted as a result of bundling.
433 // TODO(deadbeef): Make callers smarter so they won't attempt to set a
434 // description on a deleted transport.
435 return true;
436 }
437
438 return transport->SetRemoteTransportDescription(tdesc, action, err);
439 }
440
MaybeStartGathering_w()441 void TransportController::MaybeStartGathering_w() {
442 for (const auto& kv : transports_) {
443 kv.second->MaybeStartGathering();
444 }
445 }
446
AddRemoteCandidates_w(const std::string & transport_name,const Candidates & candidates,std::string * err)447 bool TransportController::AddRemoteCandidates_w(
448 const std::string& transport_name,
449 const Candidates& candidates,
450 std::string* err) {
451 RTC_DCHECK(worker_thread()->IsCurrent());
452
453 Transport* transport = GetTransport_w(transport_name);
454 if (!transport) {
455 // If we didn't find a transport, that's not an error;
456 // it could have been deleted as a result of bundling.
457 return true;
458 }
459
460 return transport->AddRemoteCandidates(candidates, err);
461 }
462
ReadyForRemoteCandidates_w(const std::string & transport_name)463 bool TransportController::ReadyForRemoteCandidates_w(
464 const std::string& transport_name) {
465 RTC_DCHECK(worker_thread()->IsCurrent());
466
467 Transport* transport = GetTransport_w(transport_name);
468 if (!transport) {
469 return false;
470 }
471 return transport->ready_for_remote_candidates();
472 }
473
GetStats_w(const std::string & transport_name,TransportStats * stats)474 bool TransportController::GetStats_w(const std::string& transport_name,
475 TransportStats* stats) {
476 RTC_DCHECK(worker_thread()->IsCurrent());
477
478 Transport* transport = GetTransport_w(transport_name);
479 if (!transport) {
480 return false;
481 }
482 return transport->GetStats(stats);
483 }
484
OnChannelWritableState_w(TransportChannel * channel)485 void TransportController::OnChannelWritableState_w(TransportChannel* channel) {
486 RTC_DCHECK(worker_thread_->IsCurrent());
487 LOG(LS_INFO) << channel->transport_name() << " TransportChannel "
488 << channel->component() << " writability changed to "
489 << channel->writable() << ".";
490 UpdateAggregateStates_w();
491 }
492
OnChannelReceivingState_w(TransportChannel * channel)493 void TransportController::OnChannelReceivingState_w(TransportChannel* channel) {
494 RTC_DCHECK(worker_thread_->IsCurrent());
495 UpdateAggregateStates_w();
496 }
497
OnChannelGatheringState_w(TransportChannelImpl * channel)498 void TransportController::OnChannelGatheringState_w(
499 TransportChannelImpl* channel) {
500 RTC_DCHECK(worker_thread_->IsCurrent());
501 UpdateAggregateStates_w();
502 }
503
OnChannelCandidateGathered_w(TransportChannelImpl * channel,const Candidate & candidate)504 void TransportController::OnChannelCandidateGathered_w(
505 TransportChannelImpl* channel,
506 const Candidate& candidate) {
507 RTC_DCHECK(worker_thread_->IsCurrent());
508
509 // We should never signal peer-reflexive candidates.
510 if (candidate.type() == PRFLX_PORT_TYPE) {
511 RTC_DCHECK(false);
512 return;
513 }
514 std::vector<Candidate> candidates;
515 candidates.push_back(candidate);
516 CandidatesData* data =
517 new CandidatesData(channel->transport_name(), candidates);
518 signaling_thread_->Post(this, MSG_CANDIDATESGATHERED, data);
519 }
520
OnChannelRoleConflict_w(TransportChannelImpl * channel)521 void TransportController::OnChannelRoleConflict_w(
522 TransportChannelImpl* channel) {
523 RTC_DCHECK(worker_thread_->IsCurrent());
524
525 if (ice_role_switch_) {
526 LOG(LS_WARNING)
527 << "Repeat of role conflict signal from TransportChannelImpl.";
528 return;
529 }
530
531 ice_role_switch_ = true;
532 IceRole reversed_role = (ice_role_ == ICEROLE_CONTROLLING)
533 ? ICEROLE_CONTROLLED
534 : ICEROLE_CONTROLLING;
535 for (const auto& kv : transports_) {
536 kv.second->SetIceRole(reversed_role);
537 }
538 }
539
OnChannelConnectionRemoved_w(TransportChannelImpl * channel)540 void TransportController::OnChannelConnectionRemoved_w(
541 TransportChannelImpl* channel) {
542 RTC_DCHECK(worker_thread_->IsCurrent());
543 LOG(LS_INFO) << channel->transport_name() << " TransportChannel "
544 << channel->component()
545 << " connection removed. Check if state is complete.";
546 UpdateAggregateStates_w();
547 }
548
UpdateAggregateStates_w()549 void TransportController::UpdateAggregateStates_w() {
550 RTC_DCHECK(worker_thread_->IsCurrent());
551
552 IceConnectionState new_connection_state = kIceConnectionConnecting;
553 IceGatheringState new_gathering_state = kIceGatheringNew;
554 bool any_receiving = false;
555 bool any_failed = false;
556 bool all_connected = !channels_.empty();
557 bool all_completed = !channels_.empty();
558 bool any_gathering = false;
559 bool all_done_gathering = !channels_.empty();
560 for (const auto& channel : channels_) {
561 any_receiving = any_receiving || channel->receiving();
562 any_failed = any_failed ||
563 channel->GetState() == TransportChannelState::STATE_FAILED;
564 all_connected = all_connected && channel->writable();
565 all_completed =
566 all_completed && channel->writable() &&
567 channel->GetState() == TransportChannelState::STATE_COMPLETED &&
568 channel->GetIceRole() == ICEROLE_CONTROLLING &&
569 channel->gathering_state() == kIceGatheringComplete;
570 any_gathering =
571 any_gathering || channel->gathering_state() != kIceGatheringNew;
572 all_done_gathering = all_done_gathering &&
573 channel->gathering_state() == kIceGatheringComplete;
574 }
575
576 if (any_failed) {
577 new_connection_state = kIceConnectionFailed;
578 } else if (all_completed) {
579 new_connection_state = kIceConnectionCompleted;
580 } else if (all_connected) {
581 new_connection_state = kIceConnectionConnected;
582 }
583 if (connection_state_ != new_connection_state) {
584 connection_state_ = new_connection_state;
585 signaling_thread_->Post(
586 this, MSG_ICECONNECTIONSTATE,
587 new rtc::TypedMessageData<IceConnectionState>(new_connection_state));
588 }
589
590 if (receiving_ != any_receiving) {
591 receiving_ = any_receiving;
592 signaling_thread_->Post(this, MSG_RECEIVING,
593 new rtc::TypedMessageData<bool>(any_receiving));
594 }
595
596 if (all_done_gathering) {
597 new_gathering_state = kIceGatheringComplete;
598 } else if (any_gathering) {
599 new_gathering_state = kIceGatheringGathering;
600 }
601 if (gathering_state_ != new_gathering_state) {
602 gathering_state_ = new_gathering_state;
603 signaling_thread_->Post(
604 this, MSG_ICEGATHERINGSTATE,
605 new rtc::TypedMessageData<IceGatheringState>(new_gathering_state));
606 }
607 }
608
609 } // namespace cricket
610