#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#***************************************************************************
#                                  _   _ ____  _
#  Project                     ___| | | |  _ \| |
#                             / __| | | | |_) | |
#                            | (__| |_| |  _ <| |___
#                             \___|\___/|_| \_\_____|
#
# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
#
# This software is licensed as described in the file COPYING, which
# you should have received as part of this distribution. The terms
# are also available at https://curl.se/docs/copyright.html.
#
# You may opt to use, copy, modify, merge, publish, distribute and/or sell
# copies of the Software, and permit persons to whom the Software is
# furnished to do so, under the terms of the COPYING file.
#
# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
# KIND, either express or implied.
#
# SPDX-License-Identifier: curl
#
###########################################################################
#
import os
import re
from datetime import timedelta, datetime, timezone
from typing import List, Any, Optional

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
from cryptography.x509 import ExtendedKeyUsageOID, NameOID


# Upper-cased curve name (e.g. "SECP256R1") -> curve *class*.
# The class must be instantiated before it is handed to cryptography.
EC_SUPPORTED = {curve.name.upper(): curve for curve in [
    ec.SECP192R1,
    ec.SECP224R1,
    ec.SECP256R1,
    ec.SECP384R1,
]}


def _private_key(key_type):
    """Generate a new private key of the requested type.

    :param key_type: one of
        - an ``int``: RSA key size in bits;
        - a string ``"<bits>"`` or ``"rsa<bits>"`` (case-insensitive): RSA;
        - a curve name present in ``EC_SUPPORTED`` (case-insensitive,
          e.g. ``"secp256r1"``): EC key on that curve;
        - an ``ec.EllipticCurve`` instance or subclass: EC key on that curve.
    :returns: the generated RSA or EC private key
    :raises ValueError: if ``key_type`` is none of the above
    """
    if isinstance(key_type, str):
        key_type = key_type.upper()
        m = re.match(r'^(RSA)?(\d+)$', key_type)
        if m:
            key_type = int(m.group(2))

    if isinstance(key_type, int):
        return rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_type,
            backend=default_backend()
        )

    if not isinstance(key_type, ec.EllipticCurve):
        if isinstance(key_type, str) and key_type in EC_SUPPORTED:
            # EC_SUPPORTED holds curve classes; cryptography wants an instance
            # (passing the class is deprecated since cryptography 42).
            key_type = EC_SUPPORTED[key_type]()
        elif isinstance(key_type, type) and issubclass(key_type, ec.EllipticCurve):
            key_type = key_type()
        else:
            raise ValueError(f"unsupported key type: {key_type!r}")
    return ec.generate_private_key(
        curve=key_type,
        backend=default_backend()
    )


class CertificateSpec:
    """Description of a certificate to create via ``Credentials.issue_cert()``.

    The kind of certificate (server, client or CA) follows from the fields
    that are set; see the ``type`` property.
    """

    def __init__(self, name: Optional[str] = None,
                 domains: Optional[List[str]] = None,
                 email: Optional[str] = None,
                 key_type: Optional[str] = None,
                 single_file: bool = False,
                 valid_from: timedelta = timedelta(days=-1),
                 valid_to: timedelta = timedelta(days=89),
                 client: bool = False,
                 sub_specs: Optional[List['CertificateSpec']] = None):
        """
        :param name: credential name; when unset, the first entry of
            ``domains`` is used
        :param domains: DNS names placed in a server certificate's
            SubjectAlternativeName
        :param email: RFC 822 address placed in a client certificate's
            SubjectAlternativeName
        :param key_type: key type as accepted by ``_private_key()``; ``None``
            means "use the issuer's key type"
        :param single_file: store the private key inside the cert file
        :param valid_from: offset from now for the notBefore date
        :param valid_to: offset from now for the notAfter date
        :param client: create a client certificate (if no ``domains``)
        :param sub_specs: certificates to issue using the new credentials as
            their issuer
        """
        self._name = name
        self.domains = domains
        self.client = client
        self.email = email
        self.key_type = key_type
        self.single_file = single_file
        self.valid_from = valid_from
        self.valid_to = valid_to
        self.sub_specs = sub_specs

    @property
    def name(self) -> Optional[str]:
        """The explicit name, else the first domain, else None."""
        if self._name:
            return self._name
        elif self.domains:
            return self.domains[0]
        return None

    @property
    def type(self) -> Optional[str]:
        """"server" when domains are set, else "client" when ``client`` is
        set, else "ca" when the spec has a name, otherwise None."""
        if self.domains:
            return "server"
        elif self.client:
            return "client"
        elif self.name:
            return "ca"
        return None


class Credentials:
    """A certificate together with its private key and (optional) issuer.

    Credentials may be attached to a ``CertStore``, which they then use to
    persist certificates they issue and to look up credentials by name.
    """

    def __init__(self,
                 name: str,
                 cert: Any,
                 pkey: Any,
                 issuer: Optional['Credentials'] = None):
        self._name = name
        self._cert = cert
        self._pkey = pkey
        self._issuer = issuer
        self._cert_file = None
        self._pkey_file = None
        # initialized so `combined_file` works before set_files() is called
        self._combined_file = None
        self._store = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def subject(self) -> x509.Name:
        return self._cert.subject

    @property
    def key_type(self):
        """"rsa<bits>" for RSA keys, the curve name for EC keys."""
        if isinstance(self._pkey, RSAPrivateKey):
            return f"rsa{self._pkey.key_size}"
        elif isinstance(self._pkey, EllipticCurvePrivateKey):
            return f"{self._pkey.curve.name}"
        else:
            raise Exception(f"unknown key type: {self._pkey}")

    @property
    def private_key(self) -> Any:
        return self._pkey

    @property
    def certificate(self) -> Any:
        return self._cert

    @property
    def cert_pem(self) -> bytes:
        """The certificate, PEM encoded."""
        return self._cert.public_bytes(Encoding.PEM)

    @property
    def pkey_pem(self) -> bytes:
        """The unencrypted private key, PEM encoded: traditional OpenSSL
        format for RSA keys, PKCS#8 otherwise."""
        return self._pkey.private_bytes(
            Encoding.PEM,
            PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8,
            NoEncryption())

    @property
    def issuer(self) -> Optional['Credentials']:
        return self._issuer

    def set_store(self, store: 'CertStore'):
        self._store = store

    def set_files(self, cert_file: str, pkey_file: Optional[str] = None,
                  combined_file: Optional[str] = None):
        """Record where the certificate, key and combined PEM were stored."""
        self._cert_file = cert_file
        self._pkey_file = pkey_file
        self._combined_file = combined_file

    @property
    def cert_file(self) -> str:
        return self._cert_file

    @property
    def pkey_file(self) -> Optional[str]:
        return self._pkey_file

    @property
    def combined_file(self) -> Optional[str]:
        return self._combined_file

    def get_first(self, name) -> Optional['Credentials']:
        """First credentials registered under `name` in the attached store,
        or None (also when no store is attached)."""
        creds = self.get_credentials_for_name(name)
        return creds[0] if creds else None

    def get_credentials_for_name(self, name) -> List['Credentials']:
        """All credentials registered under `name` in the attached store."""
        return self._store.get_credentials_for_name(name) if self._store else []

    def issue_certs(self, specs: List[CertificateSpec],
                    chain: Optional[List['Credentials']] = None) -> List['Credentials']:
        """Issue one certificate per spec; see ``issue_cert()``."""
        return [self.issue_cert(spec=spec, chain=chain) for spec in specs]

    def issue_cert(self, spec: CertificateSpec,
                   chain: Optional[List['Credentials']] = None) -> 'Credentials':
        """Issue a certificate for `spec`, signed by these credentials.

        With a store attached, existing credentials of the same name and key
        type are loaded instead of being regenerated; newly created ones are
        saved (plus a "<name>-ca.pem" chain file for CA specs). Any
        ``spec.sub_specs`` are then issued by the new credentials, using a
        sub-directory of the store named after them.

        :param spec: what to issue
        :param chain: issuer chain above these credentials, passed down to
            sub-specs with these credentials appended
        :returns: the issued (or loaded) credentials
        """
        key_type = spec.key_type if spec.key_type else self.key_type
        creds = None
        if self._store:
            creds = self._store.load_credentials(
                name=spec.name, key_type=key_type, single_file=spec.single_file, issuer=self)
        if creds is None:
            creds = TestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
                                              valid_from=spec.valid_from, valid_to=spec.valid_to)
            if self._store:
                self._store.save(creds, single_file=spec.single_file)
                if spec.type == "ca":
                    self._store.save_chain(creds, "ca", with_root=True)

        if spec.sub_specs:
            if self._store:
                sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
                creds.set_store(sub_store)
            subchain = chain.copy() if chain else []
            subchain.append(self)
            creds.issue_certs(spec.sub_specs, chain=subchain)
        return creds


class CertStore:
    """A directory holding PEM files for credentials, plus an in-memory
    index of the credentials saved to / loaded from it by name."""

    def __init__(self, fpath: str):
        """Use directory `fpath`, creating it if it does not exist."""
        self._store_dir = fpath
        if not os.path.exists(self._store_dir):
            os.makedirs(self._store_dir)
        self._creds_by_name = {}

    @property
    def path(self) -> str:
        return self._store_dir

    def save(self, creds: Credentials, name: Optional[str] = None,
             chain: Optional[List[Credentials]] = None,
             single_file: bool = False) -> None:
        """Write `creds` to the store and register them under `name`.

        Writes the certificate (followed by the certificates of `chain`) to
        the cert file, the key to the pkey file, and certificate + chain +
        key to the combined file. With `single_file`, the key is appended to
        the cert file instead of being written to a separate pkey file.

        :param name: name to store under; defaults to ``creds.name``
        """
        name = name if name is not None else creds.name
        cert_file = self.get_cert_file(name=name, key_type=creds.key_type)
        pkey_file = None if single_file else self.get_pkey_file(name=name, key_type=creds.key_type)
        comb_file = self.get_combined_file(name=name, key_type=creds.key_type)
        chain_pem = b"".join(c.cert_pem for c in chain) if chain else b""
        with open(cert_file, "wb") as fd:
            fd.write(creds.cert_pem)
            fd.write(chain_pem)
            if pkey_file is None:
                fd.write(creds.pkey_pem)
        if pkey_file is not None:
            with open(pkey_file, "wb") as fd:
                fd.write(creds.pkey_pem)
        with open(comb_file, "wb") as fd:
            fd.write(creds.cert_pem)
            fd.write(chain_pem)
            fd.write(creds.pkey_pem)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)

    def save_chain(self, creds: Credentials, infix: str, with_root=False):
        """Write "<name>-<infix>.pem" containing the certificates of `creds`
        and all its issuers, leaf first. Unless `with_root` is set, the last
        issuer (the root) is left out when the chain has more than one entry.
        """
        name = creds.name
        chain = [creds]
        while creds.issuer is not None:
            creds = creds.issuer
            chain.append(creds)
        if not with_root and len(chain) > 1:
            chain = chain[:-1]
        chain_file = os.path.join(self._store_dir, f'{name}-{infix}.pem')
        with open(chain_file, "wb") as fd:
            for c in chain:
                fd.write(c.cert_pem)

    def _add_credentials(self, name: str, creds: Credentials):
        self._creds_by_name.setdefault(name, []).append(creds)

    def get_credentials_for_name(self, name) -> List[Credentials]:
        """Credentials registered under `name`, in registration order."""
        return self._creds_by_name.get(name, [])

    def get_cert_file(self, name: str, key_type=None) -> str:
        key_infix = ".{0}".format(key_type) if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem')

    def get_pkey_file(self, name: str, key_type=None) -> str:
        key_infix = ".{0}".format(key_type) if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem')

    def get_combined_file(self, name: str, key_type=None) -> str:
        # NOTE: key_type is accepted for symmetry with the getters above but
        # is not part of the file name, so credentials with the same name and
        # different key types share (and overwrite) one combined file.
        return os.path.join(self._store_dir, f'{name}.pem')

    def load_pem_cert(self, fpath: str) -> x509.Certificate:
        with open(fpath) as fd:
            return x509.load_pem_x509_certificate(fd.read().encode())

    def load_pem_pkey(self, fpath: str):
        with open(fpath) as fd:
            return load_pem_private_key(fd.read().encode(), password=None)

    def load_credentials(self, name: str, key_type=None,
                         single_file: bool = False,
                         issuer: Optional[Credentials] = None):
        """Load previously saved credentials, or return None when the cert or
        key file is missing. Loaded credentials are attached to this store
        and registered under `name`.

        :param single_file: read the key from the cert file
        """
        cert_file = self.get_cert_file(name=name, key_type=key_type)
        pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
        comb_file = self.get_combined_file(name=name, key_type=key_type)
        if os.path.isfile(cert_file) and os.path.isfile(pkey_file):
            cert = self.load_pem_cert(cert_file)
            pkey = self.load_pem_pkey(pkey_file)
            creds = Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
            creds.set_store(self)
            creds.set_files(cert_file, pkey_file, comb_file)
            self._add_credentials(name, creds)
            return creds
        return None


class TestCA:
    """Factory for a test certificate hierarchy (root CA, intermediate CAs,
    server and client certificates)."""

    @classmethod
    def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
        """Load the root CA stored as "ca" in `store_dir`, or create a new
        self-signed CA and save it there under that name.

        NOTE(review): freshly created credentials are named `name`, reloaded
        ones are named "ca".
        """
        store = CertStore(fpath=store_dir)
        creds = store.load_credentials(name="ca", key_type=key_type, issuer=None)
        if creds is None:
            creds = TestCA._make_ca_credentials(name=name, key_type=key_type)
            store.save(creds, name="ca")
            creds.set_store(store)
        return creds

    @staticmethod
    def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
                           valid_from: timedelta = timedelta(days=-1),
                           valid_to: timedelta = timedelta(days=89),
                           ) -> Credentials:
        """Create credentials for `spec`, signed by `issuer`.

        A spec with domains yields a server certificate, one with ``client``
        set a client certificate, one with only a name an intermediate CA.

        :returns: the new credentials (not yet saved to any store)
        :raises Exception: if the spec matches none of the above
        """
        if spec.domains:
            creds = TestCA._make_server_credentials(name=spec.name, domains=spec.domains,
                                                    issuer=issuer, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.client:
            creds = TestCA._make_client_credentials(name=spec.name, issuer=issuer,
                                                    email=spec.email, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.name:
            creds = TestCA._make_ca_credentials(name=spec.name, issuer=issuer,
                                                valid_from=valid_from, valid_to=valid_to,
                                                key_type=key_type)
        else:
            raise Exception(f"unrecognized certificate specification: {spec}")
        return creds

    @staticmethod
    def _make_x509_name(org_name: Optional[str] = None, common_name: Optional[str] = None,
                        parent: Optional[x509.Name] = None) -> x509.Name:
        """Build a subject name from one new attribute followed by the
        attributes of `parent`.

        `org_name` becomes O (or OU when there is a parent); otherwise
        `common_name` becomes CN. If both are given, only `org_name` is used.
        """
        name_pieces = []
        if org_name:
            oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
            name_pieces.append(x509.NameAttribute(oid, org_name))
        elif common_name:
            name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
        if parent:
            name_pieces.extend([rdn for rdn in parent])
        return x509.Name(name_pieces)

    @staticmethod
    def _make_csr(
            subject: x509.Name,
            pkey: Any,
            issuer_subject: Optional[x509.Name],
            valid_from_delta: Optional[timedelta] = None,
            valid_until_delta: Optional[timedelta] = None
    ):
        """Start an ``x509.CertificateBuilder`` (despite the name, not a CSR)
        for `pkey`'s public key.

        :param issuer_subject: issuer name; None makes it self-issued
        :param valid_from_delta: offset from now for notBefore (None: now)
        :param valid_until_delta: offset from now for notAfter (None: now)
        """
        pubkey = pkey.public_key()
        issuer_subject = issuer_subject if issuer_subject is not None else subject

        # cryptography interprets naive datetimes as UTC, so use an aware UTC
        # timestamp instead of local time, and take it once for both bounds.
        now = datetime.now(timezone.utc)
        valid_from = now + valid_from_delta if valid_from_delta is not None else now
        valid_until = now + valid_until_delta if valid_until_delta is not None else now

        return (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer_subject)
            .public_key(pubkey)
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            .serial_number(x509.random_serial_number())
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(pubkey),
                critical=False,
            )
        )

    @staticmethod
    def _add_ca_usages(csr: Any) -> Any:
        """Add CA basic constraints, key usages and extended key usages."""
        return csr.add_extension(
            x509.BasicConstraints(ca=True, path_length=9),
            critical=True,
        ).add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False),
            critical=True
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
                ExtendedKeyUsageOID.SERVER_AUTH,
                ExtendedKeyUsageOID.CODE_SIGNING,
            ]),
            critical=True
        )

    @staticmethod
    def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
        """Add server leaf extensions: non-CA constraints, the issuer's key
        identifier, DNS SubjectAltNames for `domains` and serverAuth."""
        return csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        ).add_extension(
            x509.SubjectAlternativeName([x509.DNSName(domain) for domain in domains]),
            critical=True,
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.SERVER_AUTH,
            ]),
            critical=False
        )

    @staticmethod
    def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: Optional[str] = None) -> Any:
        """Add client leaf extensions: non-CA constraints, the issuer's key
        identifier, an RFC 822 SubjectAltName (if `rfc82name` is given) and
        clientAuth."""
        cert = csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        )
        # CertificateBuilder is immutable: add_extension() returns a new
        # builder, so the result must be kept or the extension is lost.
        if rfc82name:
            cert = cert.add_extension(
                x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
                critical=True,
            )
        cert = cert.add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
            ]),
            critical=True
        )
        return cert

    @staticmethod
    def _make_ca_credentials(name, key_type: Any,
                             issuer: Optional[Credentials] = None,
                             valid_from: timedelta = timedelta(days=-1),
                             valid_to: timedelta = timedelta(days=89),
                             ) -> Credentials:
        """Create CA credentials signed by `issuer`, or self-signed when
        `issuer` is None."""
        pkey = _private_key(key_type=key_type)
        if issuer is not None:
            issuer_subject = issuer.certificate.subject
            issuer_key = issuer.private_key
        else:
            issuer_subject = None
            issuer_key = pkey
        subject = TestCA._make_x509_name(org_name=name, parent=issuer.subject if issuer else None)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer_subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_ca_usages(csr)
        cert = csr.sign(private_key=issuer_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create server credentials for `domains`, signed by `issuer`."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_client_credentials(name: str,
                                 issuer: Credentials, email: Optional[str],
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create client credentials (optionally carrying `email`), signed
        by `issuer`."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)