1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/base/dnssec_keyset.h"
6
7 #include <cryptohi.h>
8 #include <cryptoht.h>
9 #include <keyhi.h>
10
11 #include "base/logging.h"
12 #include "base/memory/scoped_ptr.h"
13 #include "base/time.h"
14 #include "crypto/nss_util.h"
15 #include "net/base/dns_util.h"
16
namespace {

// DER-encoded AlgorithmIdentifier structures for the RRSIG signature
// algorithms supported by CheckSignature(). Both have the shape
//   SEQUENCE (13 bytes) {
//     OBJECT IDENTIFIER (9 bytes) 1.2.840.113549.1.1.x,
//     NULL
//   }
// and are decoded in VerifySignature() with SECOID_AlgorithmIDTemplate.

// sha1WithRSAEncryption, OID 1.2.840.113549.1.1.5.
const unsigned char kRSAWithSHA1[] = {
  0x30, 0x0d,  // SEQUENCE, 13 bytes
  0x06, 0x09,  // OBJECT IDENTIFIER, 9 bytes
  0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05,
  0x05, 0x00,  // NULL
};

// sha256WithRSAEncryption, OID 1.2.840.113549.1.1.11.
const unsigned char kRSAWithSHA256[] = {
  0x30, 0x0d,  // SEQUENCE, 13 bytes
  0x06, 0x09,  // OBJECT IDENTIFIER, 9 bytes
  0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b,
  0x05, 0x00,  // NULL
};

}  // namespace
29
30 namespace net {
31
DNSSECKeySet()32 DNSSECKeySet::DNSSECKeySet()
33 : ignore_timestamps_(false) {
34 }
35
~DNSSECKeySet()36 DNSSECKeySet::~DNSSECKeySet() {
37 }
38
AddKey(const base::StringPiece & dnskey)39 bool DNSSECKeySet::AddKey(const base::StringPiece& dnskey) {
40 uint16 keyid = DNSKEYToKeyID(dnskey);
41 std::string der_encoded = ASN1WrapDNSKEY(dnskey);
42 if (der_encoded.empty())
43 return false;
44
45 keyids_.push_back(keyid);
46 public_keys_.push_back(der_encoded);
47 return true;
48 }
49
CheckSignature(const base::StringPiece & name,const base::StringPiece & zone,const base::StringPiece & signature,uint16 rrtype,const std::vector<base::StringPiece> & rrdatas)50 bool DNSSECKeySet::CheckSignature(
51 const base::StringPiece& name,
52 const base::StringPiece& zone,
53 const base::StringPiece& signature,
54 uint16 rrtype,
55 const std::vector<base::StringPiece>& rrdatas) {
56 // signature has this format:
57 // algorithm uint8
58 // labels uint8
59 // ttl uint32
60 // expires uint32
61 // begins uint32
62 // keyid uint16
63 //
64 // followed by the actual signature.
65 if (signature.size() < 16)
66 return false;
67 const unsigned char* sigdata =
68 reinterpret_cast<const unsigned char*>(signature.data());
69
70 uint8 algorithm = sigdata[0];
71 uint32 expires = static_cast<uint32>(sigdata[6]) << 24 |
72 static_cast<uint32>(sigdata[7]) << 16 |
73 static_cast<uint32>(sigdata[8]) << 8 |
74 static_cast<uint32>(sigdata[9]);
75 uint32 begins = static_cast<uint32>(sigdata[10]) << 24 |
76 static_cast<uint32>(sigdata[11]) << 16 |
77 static_cast<uint32>(sigdata[12]) << 8 |
78 static_cast<uint32>(sigdata[13]);
79 uint16 keyid = static_cast<uint16>(sigdata[14]) << 8 |
80 static_cast<uint16>(sigdata[15]);
81
82 if (!ignore_timestamps_) {
83 uint32 now = static_cast<uint32>(base::Time::Now().ToTimeT());
84 if (now < begins || now >= expires)
85 return false;
86 }
87
88 base::StringPiece sig(signature.data() + 16, signature.size() - 16);
89
90 // You should have RFC 4034, 3.1.8.1 open when reading this code.
91 unsigned signed_data_len = 0;
92 signed_data_len += 2; // rrtype
93 signed_data_len += 16; // (see signature format, above)
94 signed_data_len += zone.size();
95
96 for (std::vector<base::StringPiece>::const_iterator
97 i = rrdatas.begin(); i != rrdatas.end(); i++) {
98 signed_data_len += name.size();
99 signed_data_len += 2; // rrtype
100 signed_data_len += 2; // class
101 signed_data_len += 4; // ttl
102 signed_data_len += 2; // RRDATA length
103 signed_data_len += i->size();
104 }
105
106 scoped_array<unsigned char> signed_data(new unsigned char[signed_data_len]);
107 unsigned j = 0;
108
109 signed_data[j++] = static_cast<uint8>(rrtype >> 8);
110 signed_data[j++] = static_cast<uint8>(rrtype);
111 memcpy(&signed_data[j], sigdata, 16);
112 j += 16;
113 memcpy(&signed_data[j], zone.data(), zone.size());
114 j += zone.size();
115
116 for (std::vector<base::StringPiece>::const_iterator
117 i = rrdatas.begin(); i != rrdatas.end(); i++) {
118 memcpy(&signed_data[j], name.data(), name.size());
119 j += name.size();
120 signed_data[j++] = static_cast<uint8>(rrtype >> 8);
121 signed_data[j++] = static_cast<uint8>(rrtype);
122 signed_data[j++] = 0; // CLASS (always IN = 1)
123 signed_data[j++] = 1;
124 // Copy the TTL from |signature|.
125 memcpy(&signed_data[j], signature.data() + 2, sizeof(uint32));
126 j += sizeof(uint32);
127 unsigned rrdata_len = i->size();
128 signed_data[j++] = rrdata_len >> 8;
129 signed_data[j++] = rrdata_len;
130 memcpy(&signed_data[j], i->data(), i->size());
131 j += i->size();
132 }
133
134 DCHECK_EQ(j, signed_data_len);
135
136 base::StringPiece signature_algorithm;
137 if (algorithm == kDNSSEC_RSA_SHA1 ||
138 algorithm == kDNSSEC_RSA_SHA1_NSEC3) {
139 signature_algorithm = base::StringPiece(
140 reinterpret_cast<const char*>(kRSAWithSHA1),
141 sizeof(kRSAWithSHA1));
142 } else if (algorithm == kDNSSEC_RSA_SHA256) {
143 signature_algorithm = base::StringPiece(
144 reinterpret_cast<const char*>(kRSAWithSHA256),
145 sizeof(kRSAWithSHA256));
146 } else {
147 // Unknown algorithm.
148 return false;
149 }
150
151 // Check the signature with each trusted key which has a matching keyid.
152 DCHECK_EQ(public_keys_.size(), keyids_.size());
153 for (unsigned i = 0; i < public_keys_.size(); i++) {
154 if (keyids_[i] != keyid)
155 continue;
156
157 if (VerifySignature(
158 signature_algorithm, sig, public_keys_[i],
159 base::StringPiece(reinterpret_cast<const char*>(signed_data.get()),
160 signed_data_len))) {
161 return true;
162 }
163 }
164
165 return false;
166 }
167
168 // static
DNSKEYToKeyID(const base::StringPiece & dnskey)169 uint16 DNSSECKeySet::DNSKEYToKeyID(const base::StringPiece& dnskey) {
170 const unsigned char* data =
171 reinterpret_cast<const unsigned char*>(dnskey.data());
172
173 // RFC 4043: App B
174 uint32 ac = 0;
175 for (unsigned i = 0; i < dnskey.size(); i++) {
176 if (i & 1) {
177 ac += data[i];
178 } else {
179 ac += static_cast<uint32>(data[i]) << 8;
180 }
181 }
182 ac += (ac >> 16) & 0xffff;
183 return ac;
184 }
185
IgnoreTimestamps()186 void DNSSECKeySet::IgnoreTimestamps() {
187 ignore_timestamps_ = true;
188 }
189
VerifySignature(base::StringPiece signature_algorithm,base::StringPiece signature,base::StringPiece public_key,base::StringPiece signed_data)190 bool DNSSECKeySet::VerifySignature(
191 base::StringPiece signature_algorithm,
192 base::StringPiece signature,
193 base::StringPiece public_key,
194 base::StringPiece signed_data) {
195 // This code is largely a copy-and-paste from
196 // crypto/signature_verifier_nss.cc. We can't change
197 // crypto::SignatureVerifier to always use NSS because we want the ability to
198 // be FIPS 140-2 compliant. However, we can't use crypto::SignatureVerifier
199 // here because some platforms don't support SHA256 signatures. Therefore, we
200 // use NSS directly.
201
202 crypto::EnsureNSSInit();
203
204 CERTSubjectPublicKeyInfo* spki = NULL;
205 SECItem spki_der;
206 spki_der.type = siBuffer;
207 spki_der.data = (uint8*) public_key.data();
208 spki_der.len = public_key.size();
209 spki = SECKEY_DecodeDERSubjectPublicKeyInfo(&spki_der);
210 if (!spki)
211 return false;
212 SECKEYPublicKey* pub_key = SECKEY_ExtractPublicKey(spki);
213 SECKEY_DestroySubjectPublicKeyInfo(spki); // Done with spki.
214 if (!pub_key)
215 return false;
216
217 PLArenaPool* arena = PORT_NewArena(DER_DEFAULT_CHUNKSIZE);
218 if (!arena) {
219 SECKEY_DestroyPublicKey(pub_key);
220 return false;
221 }
222
223 SECItem sig_alg_der;
224 sig_alg_der.type = siBuffer;
225 sig_alg_der.data = (uint8*) signature_algorithm.data();
226 sig_alg_der.len = signature_algorithm.size();
227 SECAlgorithmID sig_alg_id;
228 SECStatus rv;
229 rv = SEC_QuickDERDecodeItem(arena, &sig_alg_id, SECOID_AlgorithmIDTemplate,
230 &sig_alg_der);
231 if (rv != SECSuccess) {
232 SECKEY_DestroyPublicKey(pub_key);
233 PORT_FreeArena(arena, PR_TRUE);
234 return false;
235 }
236
237 SECItem sig;
238 sig.type = siBuffer;
239 sig.data = (uint8*) signature.data();
240 sig.len = signature.size();
241 SECOidTag hash_alg_tag;
242 VFYContext* vfy_context =
243 VFY_CreateContextWithAlgorithmID(pub_key, &sig,
244 &sig_alg_id, &hash_alg_tag,
245 NULL);
246 SECKEY_DestroyPublicKey(pub_key);
247 PORT_FreeArena(arena, PR_TRUE); // Done with sig_alg_id.
248 if (!vfy_context) {
249 // A corrupted RSA signature could be detected without the data, so
250 // VFY_CreateContextWithAlgorithmID may fail with SEC_ERROR_BAD_SIGNATURE
251 // (-8182).
252 return false;
253 }
254
255 rv = VFY_Begin(vfy_context);
256 if (rv != SECSuccess) {
257 NOTREACHED();
258 return false;
259 }
260 rv = VFY_Update(vfy_context, (uint8*) signed_data.data(), signed_data.size());
261 if (rv != SECSuccess) {
262 NOTREACHED();
263 return false;
264 }
265 rv = VFY_End(vfy_context);
266 VFY_DestroyContext(vfy_context, PR_TRUE);
267
268 return rv == SECSuccess;
269 }
270
// DER encoding of the AlgorithmIdentifier for rsaEncryption. It is emitted as
// the |algorithm| field of the SubjectPublicKeyInfo built by ASN1WrapDNSKEY:
//   AlgorithmIdentifier ::= SEQUENCE {
//     algorithm   OBJECT IDENTIFIER,  -- 1.2.840.113549.1.1.1
//     parameters  NULL }
static const unsigned char kASN1AlgorithmIdentifierRSA[] = {
  0x30,  // SEQUENCE
  0x0d,  // length (13 bytes: 2 + 9 for the OID, 2 for the NULL)
  0x06,  // OBJECT IDENTIFIER
  0x09,  // length (9 bytes)
  // OID 1.2.840.113549.1.1.1 (rsaEncryption)
  0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01,
  // NULL of length 0
  0x05, 0x00,
};
282
// EncodeASN1Length writes a DER length header at |out[*j]|.
// On entry, |*length| counts every byte still to be written from this point:
// the length header itself plus the content it describes. The header size
// (1, 2 or 3 bytes) is chosen so that header + content adds up to |*length|.
// On return, |*j| has advanced past the header, and |*length| equals the
// content length that was encoded.
static void EncodeASN1Length(unsigned char* out, unsigned* j,
                             unsigned* length) {
  unsigned header_size;
  if (*length - 1 < 128)
    header_size = 1;  // Short form: the length fits in 7 bits.
  else if (*length - 2 < 256)
    header_size = 2;  // Long form with one length byte.
  else
    header_size = 3;  // Long form with two length bytes.

  const unsigned content = *length - header_size;
  unsigned pos = *j;
  switch (header_size) {
    case 1:
      out[pos++] = static_cast<unsigned char>(content);
      break;
    case 2:
      out[pos++] = 0x81;
      out[pos++] = static_cast<unsigned char>(content);
      break;
    default:
      out[pos++] = 0x82;
      out[pos++] = static_cast<unsigned char>(content >> 8);
      out[pos++] = static_cast<unsigned char>(content);
      break;
  }
  *j = pos;
  *length = content;
}
302
303 // AdvanceForASN1Length returns the number of bytes required to encode a ASN1
304 // DER length value of |remaining|.
AdvanceForASN1Length(unsigned remaining)305 static unsigned AdvanceForASN1Length(unsigned remaining) {
306 if (remaining < 128) {
307 return 1;
308 } else if (remaining < 256) {
309 return 2;
310 } else if (remaining < 65536) {
311 return 3;
312 } else {
313 NOTREACHED();
314 return 3;
315 }
316 }
317
318 // ASN1WrapDNSKEY converts the DNSKEY RDATA in |dnskey| into the ASN.1 wrapped
319 // format expected by NSS. To wit:
320 // SubjectPublicKeyInfo ::= SEQUENCE {
321 // algorithm AlgorithmIdentifier,
322 // subjectPublicKey BIT STRING }
ASN1WrapDNSKEY(const base::StringPiece & dnskey)323 std::string DNSSECKeySet::ASN1WrapDNSKEY(const base::StringPiece& dnskey) {
324 const unsigned char* data =
325 reinterpret_cast<const unsigned char*>(dnskey.data());
326
327 if (dnskey.size() < 5 || dnskey.size() > 32767)
328 return "";
329 const uint8 algorithm = data[3];
330 if (algorithm != kDNSSEC_RSA_SHA1 &&
331 algorithm != kDNSSEC_RSA_SHA1_NSEC3 &&
332 algorithm != kDNSSEC_RSA_SHA256) {
333 return "";
334 }
335
336 unsigned exp_length;
337 unsigned exp_offset;
338 // First we extract the public exponent.
339 if (data[4] == 0) {
340 if (dnskey.size() < 7)
341 return "";
342 exp_length = static_cast<unsigned>(data[5]) << 8 |
343 static_cast<unsigned>(data[6]);
344 exp_offset = 7;
345 } else {
346 exp_length = static_cast<unsigned>(data[4]);
347 exp_offset = 5;
348 }
349
350 // We refuse to deal with large public exponents.
351 if (exp_length > 3)
352 return "";
353 if (dnskey.size() < exp_offset + exp_length)
354 return "";
355
356 unsigned exp = 0;
357 for (unsigned i = 0; i < exp_length; i++) {
358 exp <<= 8;
359 exp |= static_cast<unsigned>(data[exp_offset + i]);
360 }
361
362 unsigned n_offset = exp_offset + exp_length;
363 unsigned n_length = dnskey.size() - n_offset;
364
365 // Anything smaller than 512 bits is too weak to be trusted.
366 if (n_length < 64)
367 return "";
368
369 // If the MSB of exp is true then we need to prefix a zero byte to stop the
370 // ASN.1 encoding from being negative.
371 if (exp & (1 << ((8 * exp_length) - 1)))
372 exp_length++;
373
374 // Likewise with the modulus
375 unsigned n_padding = data[n_offset] & 0x80 ? 1 : 0;
376
377 // We now calculate the length of the full ASN.1 encoded public key. We're
378 // working backwards from the end of the structure. Keep in mind that it's:
379 // SEQUENCE
380 // AlgorithmIdentifier
381 // BITSTRING
382 // SEQUENCE
383 // INTEGER
384 // INTEGER
385 unsigned length = 0;
386 length += exp_length; // exponent data
387 length++; // we know that |exp_length| < 128
388 length++; // INTEGER tag for exponent
389 length += n_length + n_padding;
390 length += AdvanceForASN1Length(n_length + n_padding);
391 length++; // INTEGER tag for modulus
392 length += AdvanceForASN1Length(length); // SEQUENCE length
393 length++; // SEQUENCE tag
394 length++; // BITSTRING unused bits
395 length += AdvanceForASN1Length(length); // BITSTRING length
396 length++; // BITSTRING tag
397 length += sizeof(kASN1AlgorithmIdentifierRSA);
398 length += AdvanceForASN1Length(length); // SEQUENCE length
399 length++; // SEQUENCE tag
400
401 scoped_array<unsigned char> out(new unsigned char[length]);
402
403 // Now we walk forwards and serialise the ASN.1, undoing the steps above.
404 unsigned j = 0;
405 out[j++] = 0x30; // SEQUENCE
406 length--;
407 EncodeASN1Length(out.get(), &j, &length);
408 memcpy(&out[j], kASN1AlgorithmIdentifierRSA,
409 sizeof(kASN1AlgorithmIdentifierRSA));
410 j += sizeof(kASN1AlgorithmIdentifierRSA);
411 length -= sizeof(kASN1AlgorithmIdentifierRSA);
412 out[j++] = 3; // BITSTRING tag
413 length--;
414 EncodeASN1Length(out.get(), &j, &length);
415 out[j++] = 0; // BITSTRING unused bits
416 length--;
417 out[j++] = 0x30; // SEQUENCE
418 length--;
419 EncodeASN1Length(out.get(), &j, &length);
420 out[j++] = 2; // INTEGER
421 length--;
422 unsigned l = n_length + n_padding;
423 if (l < 128) {
424 out[j++] = l;
425 length--;
426 } else if (l < 256) {
427 out[j++] = 0x80 | 1;
428 out[j++] = l;
429 length -= 2;
430 } else if (l < 65536) {
431 out[j++] = 0x80 | 2;
432 out[j++] = l >> 8;
433 out[j++] = l;
434 length -= 3;
435 } else {
436 NOTREACHED();
437 }
438
439 if (n_padding) {
440 out[j++] = 0;
441 length--;
442 }
443 memcpy(&out[j], &data[n_offset], n_length);
444 j += n_length;
445 length -= n_length;
446 out[j++] = 2; // INTEGER
447 length--;
448 out[j++] = exp_length;
449 length--;
450 for (unsigned i = exp_length - 1; i < exp_length; i--) {
451 out[j++] = exp >> (8 * i);
452 length--;
453 }
454
455 DCHECK_EQ(0u, length);
456
457 return std::string(reinterpret_cast<char*>(out.get()), j);
458 }
459
460 } // namespace net
461