• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/ssl_client_socket.h"
6 
7 #include <string>
8 
9 #include "base/containers/flat_tree.h"
10 #include "base/logging.h"
11 #include "base/observer_list.h"
12 #include "base/values.h"
13 #include "net/cert/x509_certificate_net_log_param.h"
14 #include "net/log/net_log.h"
15 #include "net/log/net_log_event_type.h"
16 #include "net/socket/ssl_client_socket_impl.h"
17 #include "net/socket/stream_socket.h"
18 #include "net/ssl/ssl_client_session_cache.h"
19 #include "net/ssl/ssl_key_logger.h"
20 
21 namespace net {
22 
23 namespace {
24 
25 // Returns true if |first_cert| and |second_cert| represent the same certificate
26 // (with the same chain), or if they're both NULL.
AreCertificatesEqual(const scoped_refptr<X509Certificate> & first_cert,const scoped_refptr<X509Certificate> & second_cert)27 bool AreCertificatesEqual(const scoped_refptr<X509Certificate>& first_cert,
28                           const scoped_refptr<X509Certificate>& second_cert) {
29   return (!first_cert && !second_cert) ||
30          (first_cert && second_cert &&
31           first_cert->EqualsIncludingChain(second_cert.get()));
32 }
33 
34 // Returns a base::Value::Dict value NetLog parameter with the expected format
35 // for events of type CLEAR_CACHED_CLIENT_CERT.
NetLogClearCachedClientCertParams(const net::HostPortPair & host,const scoped_refptr<net::X509Certificate> & cert,bool is_cleared)36 base::Value::Dict NetLogClearCachedClientCertParams(
37     const net::HostPortPair& host,
38     const scoped_refptr<net::X509Certificate>& cert,
39     bool is_cleared) {
40   base::Value::Dict dict;
41   dict.Set("host", host.ToString());
42   dict.Set("certificates", cert ? net::NetLogX509CertificateList(cert.get())
43                                 : base::Value(base::Value::List()));
44   dict.Set("is_cleared", is_cleared);
45   return dict;
46 }
47 
48 }  // namespace
49 
50 SSLClientSocket::SSLClientSocket() = default;
51 
52 // static
SetSSLKeyLogger(std::unique_ptr<SSLKeyLogger> logger)53 void SSLClientSocket::SetSSLKeyLogger(std::unique_ptr<SSLKeyLogger> logger) {
54   SSLClientSocketImpl::SetSSLKeyLogger(std::move(logger));
55 }
56 
57 // static
SerializeNextProtos(const NextProtoVector & next_protos)58 std::vector<uint8_t> SSLClientSocket::SerializeNextProtos(
59     const NextProtoVector& next_protos) {
60   std::vector<uint8_t> wire_protos;
61   for (const NextProto next_proto : next_protos) {
62     const std::string proto = NextProtoToString(next_proto);
63     if (proto.size() > 255) {
64       LOG(WARNING) << "Ignoring overlong ALPN protocol: " << proto;
65       continue;
66     }
67     if (proto.size() == 0) {
68       LOG(WARNING) << "Ignoring empty ALPN protocol";
69       continue;
70     }
71     wire_protos.push_back(proto.size());
72     for (const char ch : proto) {
73       wire_protos.push_back(static_cast<uint8_t>(ch));
74     }
75   }
76 
77   return wire_protos;
78 }
79 
SSLClientContext(SSLConfigService * ssl_config_service,CertVerifier * cert_verifier,TransportSecurityState * transport_security_state,CTPolicyEnforcer * ct_policy_enforcer,SSLClientSessionCache * ssl_client_session_cache,SCTAuditingDelegate * sct_auditing_delegate)80 SSLClientContext::SSLClientContext(
81     SSLConfigService* ssl_config_service,
82     CertVerifier* cert_verifier,
83     TransportSecurityState* transport_security_state,
84     CTPolicyEnforcer* ct_policy_enforcer,
85     SSLClientSessionCache* ssl_client_session_cache,
86     SCTAuditingDelegate* sct_auditing_delegate)
87     : ssl_config_service_(ssl_config_service),
88       cert_verifier_(cert_verifier),
89       transport_security_state_(transport_security_state),
90       ct_policy_enforcer_(ct_policy_enforcer),
91       ssl_client_session_cache_(ssl_client_session_cache),
92       sct_auditing_delegate_(sct_auditing_delegate) {
93   CHECK(cert_verifier_);
94   CHECK(transport_security_state_);
95   CHECK(ct_policy_enforcer_);
96 
97   if (ssl_config_service_) {
98     config_ = ssl_config_service_->GetSSLContextConfig();
99     ssl_config_service_->AddObserver(this);
100   }
101   cert_verifier_->AddObserver(this);
102   CertDatabase::GetInstance()->AddObserver(this);
103 }
104 
~SSLClientContext()105 SSLClientContext::~SSLClientContext() {
106   if (ssl_config_service_) {
107     ssl_config_service_->RemoveObserver(this);
108   }
109   cert_verifier_->RemoveObserver(this);
110   CertDatabase::GetInstance()->RemoveObserver(this);
111 }
112 
CreateSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config)113 std::unique_ptr<SSLClientSocket> SSLClientContext::CreateSSLClientSocket(
114     std::unique_ptr<StreamSocket> stream_socket,
115     const HostPortPair& host_and_port,
116     const SSLConfig& ssl_config) {
117   return std::make_unique<SSLClientSocketImpl>(this, std::move(stream_socket),
118                                                host_and_port, ssl_config);
119 }
120 
GetClientCertificate(const HostPortPair & server,scoped_refptr<X509Certificate> * client_cert,scoped_refptr<SSLPrivateKey> * private_key)121 bool SSLClientContext::GetClientCertificate(
122     const HostPortPair& server,
123     scoped_refptr<X509Certificate>* client_cert,
124     scoped_refptr<SSLPrivateKey>* private_key) {
125   return ssl_client_auth_cache_.Lookup(server, client_cert, private_key);
126 }
127 
SetClientCertificate(const HostPortPair & server,scoped_refptr<X509Certificate> client_cert,scoped_refptr<SSLPrivateKey> private_key)128 void SSLClientContext::SetClientCertificate(
129     const HostPortPair& server,
130     scoped_refptr<X509Certificate> client_cert,
131     scoped_refptr<SSLPrivateKey> private_key) {
132   ssl_client_auth_cache_.Add(server, std::move(client_cert),
133                              std::move(private_key));
134 
135   if (ssl_client_session_cache_) {
136     // Session resumption bypasses client certificate negotiation, so flush all
137     // associated sessions when preferences change.
138     ssl_client_session_cache_->FlushForServers({server});
139   }
140   NotifySSLConfigForServersChanged({server});
141 }
142 
ClearClientCertificate(const HostPortPair & server)143 bool SSLClientContext::ClearClientCertificate(const HostPortPair& server) {
144   if (!ssl_client_auth_cache_.Remove(server)) {
145     return false;
146   }
147 
148   if (ssl_client_session_cache_) {
149     // Session resumption bypasses client certificate negotiation, so flush all
150     // associated sessions when preferences change.
151     ssl_client_session_cache_->FlushForServers({server});
152   }
153   NotifySSLConfigForServersChanged({server});
154   return true;
155 }
156 
AddObserver(Observer * observer)157 void SSLClientContext::AddObserver(Observer* observer) {
158   observers_.AddObserver(observer);
159 }
160 
RemoveObserver(Observer * observer)161 void SSLClientContext::RemoveObserver(Observer* observer) {
162   observers_.RemoveObserver(observer);
163 }
164 
OnSSLContextConfigChanged()165 void SSLClientContext::OnSSLContextConfigChanged() {
166   config_ = ssl_config_service_->GetSSLContextConfig();
167   if (ssl_client_session_cache_) {
168     ssl_client_session_cache_->Flush();
169   }
170   NotifySSLConfigChanged(SSLConfigChangeType::kSSLConfigChanged);
171 }
172 
OnCertVerifierChanged()173 void SSLClientContext::OnCertVerifierChanged() {
174   NotifySSLConfigChanged(SSLConfigChangeType::kCertVerifierChanged);
175 }
176 
OnTrustStoreChanged()177 void SSLClientContext::OnTrustStoreChanged() {
178   NotifySSLConfigChanged(SSLConfigChangeType::kCertDatabaseChanged);
179 }
180 
OnClientCertStoreChanged()181 void SSLClientContext::OnClientCertStoreChanged() {
182   base::flat_set<HostPortPair> servers =
183       ssl_client_auth_cache_.GetCachedServers();
184   ssl_client_auth_cache_.Clear();
185   if (ssl_client_session_cache_) {
186     ssl_client_session_cache_->FlushForServers(servers);
187   }
188   NotifySSLConfigForServersChanged(servers);
189 }
190 
ClearClientCertificateIfNeeded(const net::HostPortPair & host,const scoped_refptr<net::X509Certificate> & certificate)191 void SSLClientContext::ClearClientCertificateIfNeeded(
192     const net::HostPortPair& host,
193     const scoped_refptr<net::X509Certificate>& certificate) {
194   scoped_refptr<X509Certificate> cached_certificate;
195   scoped_refptr<SSLPrivateKey> cached_private_key;
196   if (!ssl_client_auth_cache_.Lookup(host, &cached_certificate,
197                                      &cached_private_key) ||
198       AreCertificatesEqual(cached_certificate, certificate)) {
199     // No cached client certificate preference for this host.
200     net::NetLog::Get()->AddGlobalEntry(
201         NetLogEventType::CLEAR_CACHED_CLIENT_CERT, [&]() {
202           return NetLogClearCachedClientCertParams(host, certificate,
203                                                    /*is_cleared=*/false);
204         });
205     return;
206   }
207 
208   net::NetLog::Get()->AddGlobalEntry(
209       NetLogEventType::CLEAR_CACHED_CLIENT_CERT, [&]() {
210         return NetLogClearCachedClientCertParams(host, certificate,
211                                                  /*is_cleared=*/true);
212       });
213 
214   ssl_client_auth_cache_.Remove(host);
215 
216   if (ssl_client_session_cache_) {
217     ssl_client_session_cache_->FlushForServers({host});
218   }
219 
220   NotifySSLConfigForServersChanged({host});
221 }
222 
NotifySSLConfigChanged(SSLConfigChangeType change_type)223 void SSLClientContext::NotifySSLConfigChanged(SSLConfigChangeType change_type) {
224   for (Observer& observer : observers_) {
225     observer.OnSSLConfigChanged(change_type);
226   }
227 }
228 
NotifySSLConfigForServersChanged(const base::flat_set<HostPortPair> & servers)229 void SSLClientContext::NotifySSLConfigForServersChanged(
230     const base::flat_set<HostPortPair>& servers) {
231   for (Observer& observer : observers_) {
232     observer.OnSSLConfigForServersChanged(servers);
233   }
234 }
235 
236 }  // namespace net
237