1 /*
2 * Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11
12 #include <vector>
13
14 #if HAVE_CONFIG_H
15 #include "config.h"
16 #endif // HAVE_CONFIG_H
17
18 #include "webrtc/base/sslstreamadapterhelper.h"
19
20 #include "webrtc/base/common.h"
21 #include "webrtc/base/logging.h"
22 #include "webrtc/base/stream.h"
23
24 namespace rtc {
25
SSLStreamAdapterHelper(StreamInterface * stream)26 SSLStreamAdapterHelper::SSLStreamAdapterHelper(StreamInterface* stream)
27 : SSLStreamAdapter(stream),
28 state_(SSL_NONE),
29 role_(SSL_CLIENT),
30 ssl_error_code_(0), // Not meaningful yet
31 ssl_mode_(SSL_MODE_TLS),
32 ssl_max_version_(SSL_PROTOCOL_TLS_12) {}
33
34 SSLStreamAdapterHelper::~SSLStreamAdapterHelper() = default;
35
SetIdentity(SSLIdentity * identity)36 void SSLStreamAdapterHelper::SetIdentity(SSLIdentity* identity) {
37 ASSERT(identity_.get() == NULL);
38 identity_.reset(identity);
39 }
40
SetServerRole(SSLRole role)41 void SSLStreamAdapterHelper::SetServerRole(SSLRole role) {
42 role_ = role;
43 }
44
StartSSLWithServer(const char * server_name)45 int SSLStreamAdapterHelper::StartSSLWithServer(const char* server_name) {
46 ASSERT(server_name != NULL && server_name[0] != '\0');
47 ssl_server_name_ = server_name;
48 return StartSSL();
49 }
50
StartSSLWithPeer()51 int SSLStreamAdapterHelper::StartSSLWithPeer() {
52 ASSERT(ssl_server_name_.empty());
53 // It is permitted to specify peer_certificate_ only later.
54 return StartSSL();
55 }
56
SetMode(SSLMode mode)57 void SSLStreamAdapterHelper::SetMode(SSLMode mode) {
58 ASSERT(state_ == SSL_NONE);
59 ssl_mode_ = mode;
60 }
61
SetMaxProtocolVersion(SSLProtocolVersion version)62 void SSLStreamAdapterHelper::SetMaxProtocolVersion(SSLProtocolVersion version) {
63 ssl_max_version_ = version;
64 }
65
GetState() const66 StreamState SSLStreamAdapterHelper::GetState() const {
67 switch (state_) {
68 case SSL_WAIT:
69 case SSL_CONNECTING:
70 return SS_OPENING;
71 case SSL_CONNECTED:
72 return SS_OPEN;
73 default:
74 return SS_CLOSED;
75 };
76 // not reached
77 }
78
GetPeerCertificate(SSLCertificate ** cert) const79 bool SSLStreamAdapterHelper::GetPeerCertificate(SSLCertificate** cert) const {
80 if (!peer_certificate_)
81 return false;
82
83 *cert = peer_certificate_->GetReference();
84 return true;
85 }
86
SetPeerCertificateDigest(const std::string & digest_alg,const unsigned char * digest_val,size_t digest_len)87 bool SSLStreamAdapterHelper::SetPeerCertificateDigest(
88 const std::string &digest_alg,
89 const unsigned char* digest_val,
90 size_t digest_len) {
91 ASSERT(peer_certificate_.get() == NULL);
92 ASSERT(peer_certificate_digest_algorithm_.empty());
93 ASSERT(ssl_server_name_.empty());
94 size_t expected_len;
95
96 if (!GetDigestLength(digest_alg, &expected_len)) {
97 LOG(LS_WARNING) << "Unknown digest algorithm: " << digest_alg;
98 return false;
99 }
100 if (expected_len != digest_len)
101 return false;
102
103 peer_certificate_digest_value_.SetData(digest_val, digest_len);
104 peer_certificate_digest_algorithm_ = digest_alg;
105
106 return true;
107 }
108
Error(const char * context,int err,bool signal)109 void SSLStreamAdapterHelper::Error(const char* context, int err, bool signal) {
110 LOG(LS_WARNING) << "SSLStreamAdapterHelper::Error("
111 << context << ", " << err << "," << signal << ")";
112 state_ = SSL_ERROR;
113 ssl_error_code_ = err;
114 Cleanup();
115 if (signal)
116 StreamAdapterInterface::OnEvent(stream(), SE_CLOSE, err);
117 }
118
Close()119 void SSLStreamAdapterHelper::Close() {
120 Cleanup();
121 ASSERT(state_ == SSL_CLOSED || state_ == SSL_ERROR);
122 StreamAdapterInterface::Close();
123 }
124
StartSSL()125 int SSLStreamAdapterHelper::StartSSL() {
126 ASSERT(state_ == SSL_NONE);
127
128 if (StreamAdapterInterface::GetState() != SS_OPEN) {
129 state_ = SSL_WAIT;
130 return 0;
131 }
132
133 state_ = SSL_CONNECTING;
134 int err = BeginSSL();
135 if (err) {
136 Error("BeginSSL", err, false);
137 return err;
138 }
139
140 return 0;
141 }
142
143 } // namespace rtc
144
145