1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/cert/mock_cert_verifier.h"
6
7 #include <memory>
8 #include <utility>
9
10 #include "base/callback_list.h"
11 #include "base/functional/bind.h"
12 #include "base/location.h"
13 #include "base/memory/raw_ptr.h"
14 #include "base/memory/ref_counted.h"
15 #include "base/memory/weak_ptr.h"
16 #include "base/strings/pattern.h"
17 #include "base/strings/string_util.h"
18 #include "base/task/single_thread_task_runner.h"
19 #include "net/base/net_errors.h"
20 #include "net/cert/cert_status_flags.h"
21 #include "net/cert/cert_verify_result.h"
22 #include "net/cert/x509_certificate.h"
23
24 namespace net {
25
26 struct MockCertVerifier::Rule {
Rulenet::MockCertVerifier::Rule27 Rule(scoped_refptr<X509Certificate> cert_arg,
28 const std::string& hostname_arg,
29 const CertVerifyResult& result_arg,
30 int rv_arg)
31 : cert(std::move(cert_arg)),
32 hostname(hostname_arg),
33 result(result_arg),
34 rv(rv_arg) {
35 DCHECK(cert);
36 DCHECK(result.verified_cert);
37 }
38
39 scoped_refptr<X509Certificate> cert;
40 std::string hostname;
41 CertVerifyResult result;
42 int rv;
43 };
44
45 class MockCertVerifier::MockRequest : public CertVerifier::Request {
46 public:
MockRequest(MockCertVerifier * parent,CertVerifyResult * result,CompletionOnceCallback callback)47 MockRequest(MockCertVerifier* parent,
48 CertVerifyResult* result,
49 CompletionOnceCallback callback)
50 : result_(result), callback_(std::move(callback)) {
51 subscription_ = parent->request_list_.Add(
52 base::BindOnce(&MockRequest::Cleanup, weak_factory_.GetWeakPtr()));
53 }
54
ReturnResultLater(int rv,const CertVerifyResult & result)55 void ReturnResultLater(int rv, const CertVerifyResult& result) {
56 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
57 FROM_HERE, base::BindOnce(&MockRequest::ReturnResult,
58 weak_factory_.GetWeakPtr(), rv, result));
59 }
60
61 private:
ReturnResult(int rv,const CertVerifyResult & result)62 void ReturnResult(int rv, const CertVerifyResult& result) {
63 // If the MockCertVerifier has been deleted, the callback will have been
64 // reset to null.
65 if (!callback_)
66 return;
67
68 *result_ = result;
69 std::move(callback_).Run(rv);
70 }
71
Cleanup()72 void Cleanup() {
73 // Note: May delete |this_|.
74 std::move(callback_).Reset();
75 }
76
77 raw_ptr<CertVerifyResult> result_;
78 CompletionOnceCallback callback_;
79 base::CallbackListSubscription subscription_;
80
81 base::WeakPtrFactory<MockRequest> weak_factory_{this};
82 };
83
84 MockCertVerifier::MockCertVerifier() = default;
85
~MockCertVerifier()86 MockCertVerifier::~MockCertVerifier() {
87 // Reset the callbacks for any outstanding MockRequests to fulfill the
88 // respective net::CertVerifier contract.
89 request_list_.Notify();
90 }
91
Verify(const RequestParams & params,CertVerifyResult * verify_result,CompletionOnceCallback callback,std::unique_ptr<Request> * out_req,const NetLogWithSource & net_log)92 int MockCertVerifier::Verify(const RequestParams& params,
93 CertVerifyResult* verify_result,
94 CompletionOnceCallback callback,
95 std::unique_ptr<Request>* out_req,
96 const NetLogWithSource& net_log) {
97 if (!async_) {
98 return VerifyImpl(params, verify_result);
99 }
100
101 auto request =
102 std::make_unique<MockRequest>(this, verify_result, std::move(callback));
103 CertVerifyResult result;
104 int rv = VerifyImpl(params, &result);
105 request->ReturnResultLater(rv, result);
106 *out_req = std::move(request);
107 return ERR_IO_PENDING;
108 }
109
AddObserver(Observer * observer)110 void MockCertVerifier::AddObserver(Observer* observer) {
111 observers_.AddObserver(observer);
112 }
113
RemoveObserver(Observer * observer)114 void MockCertVerifier::RemoveObserver(Observer* observer) {
115 observers_.RemoveObserver(observer);
116 }
117
AddResultForCert(scoped_refptr<X509Certificate> cert,const CertVerifyResult & verify_result,int rv)118 void MockCertVerifier::AddResultForCert(scoped_refptr<X509Certificate> cert,
119 const CertVerifyResult& verify_result,
120 int rv) {
121 AddResultForCertAndHost(std::move(cert), "*", verify_result, rv);
122 }
123
AddResultForCertAndHost(scoped_refptr<X509Certificate> cert,const std::string & host_pattern,const CertVerifyResult & verify_result,int rv)124 void MockCertVerifier::AddResultForCertAndHost(
125 scoped_refptr<X509Certificate> cert,
126 const std::string& host_pattern,
127 const CertVerifyResult& verify_result,
128 int rv) {
129 rules_.push_back(Rule(std::move(cert), host_pattern, verify_result, rv));
130 }
131
ClearRules()132 void MockCertVerifier::ClearRules() {
133 rules_.clear();
134 }
135
SimulateOnCertVerifierChanged()136 void MockCertVerifier::SimulateOnCertVerifierChanged() {
137 for (Observer& observer : observers_) {
138 observer.OnCertVerifierChanged();
139 }
140 }
141
VerifyImpl(const RequestParams & params,CertVerifyResult * verify_result)142 int MockCertVerifier::VerifyImpl(const RequestParams& params,
143 CertVerifyResult* verify_result) {
144 for (const Rule& rule : rules_) {
145 // Check just the server cert. Intermediates will be ignored.
146 if (!rule.cert->EqualsExcludingChain(params.certificate().get()))
147 continue;
148 if (!base::MatchPattern(params.hostname(), rule.hostname))
149 continue;
150 *verify_result = rule.result;
151 return rule.rv;
152 }
153
154 // Fall through to the default.
155 verify_result->verified_cert = params.certificate();
156 verify_result->cert_status = MapNetErrorToCertStatus(default_result_);
157 return default_result_;
158 }
159
CertVerifierObserverCounter(CertVerifier * verifier)160 CertVerifierObserverCounter::CertVerifierObserverCounter(
161 CertVerifier* verifier) {
162 obs_.Observe(verifier);
163 }
164
165 CertVerifierObserverCounter::~CertVerifierObserverCounter() = default;
166
OnCertVerifierChanged()167 void CertVerifierObserverCounter::OnCertVerifierChanged() {
168 change_count_++;
169 }
170
171 } // namespace net
172