• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#***************************************************************************
4#                                  _   _ ____  _
5#  Project                     ___| | | |  _ \| |
6#                             / __| | | | |_) | |
7#                            | (__| |_| |  _ <| |___
8#                             \___|\___/|_| \_\_____|
9#
10# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
11#
12# This software is licensed as described in the file COPYING, which
13# you should have received as part of this distribution. The terms
14# are also available at https://curl.se/docs/copyright.html.
15#
16# You may opt to use, copy, modify, merge, publish, distribute and/or sell
17# copies of the Software, and permit persons to whom the Software is
18# furnished to do so, under the terms of the COPYING file.
19#
20# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
21# KIND, either express or implied.
22#
23# SPDX-License-Identifier: curl
24#
25###########################################################################
26#
import ipaddress
import os
import re
from datetime import timedelta, datetime, timezone
from typing import List, Any, Optional

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
from cryptography.x509 import ExtendedKeyUsageOID, NameOID
40
41
# Elliptic curves that _private_key() accepts by name, keyed by the
# upper-cased curve name (e.g. "SECP256R1").  The values are the curve
# classes themselves, not instances.
EC_SUPPORTED = {
    curve.name.upper(): curve
    for curve in (ec.SECP192R1, ec.SECP224R1, ec.SECP256R1, ec.SECP384R1)
}
49
50
def _private_key(key_type):
    """Generate a new private key of the requested type.

    :param key_type: one of
        - an ``int``, or a string ``"<bits>"`` / ``"rsa<bits>"`` (any case):
          an RSA key of that size with public exponent 65537;
        - a curve name that is a key of ``EC_SUPPORTED`` (any case),
          e.g. ``"secp256r1"``;
        - an ``ec.EllipticCurve`` instance or ``ec.EllipticCurve`` subclass.
    :returns: the generated RSA or EC private key.
    :raises ValueError: if *key_type* is a string naming no supported type.
    """
    if isinstance(key_type, str):
        requested = key_type
        key_type = key_type.upper()
        m = re.match(r'^(RSA)?(\d+)$', key_type)
        if m:
            key_type = int(m.group(2))
        elif key_type in EC_SUPPORTED:
            key_type = EC_SUPPORTED[key_type]
        else:
            raise ValueError(f"unsupported key type: {requested}")

    if isinstance(key_type, int):
        return rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_type,
            backend=default_backend()
        )
    # ec.generate_private_key() expects a curve *instance*.  EC_SUPPORTED
    # (and some callers) hold curve classes, and passing a class is deprecated
    # by cryptography, so instantiate it here.
    if isinstance(key_type, type) and issubclass(key_type, ec.EllipticCurve):
        key_type = key_type()
    return ec.generate_private_key(
        curve=key_type,
        backend=default_backend()
    )
70
71
class CertificateSpec:
    """Declarative description of a certificate to be issued.

    The kind of certificate is derived from the fields that are set (see
    ``type``):
    - a non-empty ``domains`` list makes a server certificate;
    - otherwise ``client=True`` makes a client certificate;
    - otherwise a name alone makes a CA.

    When issued through ``Credentials.issue_cert``, the certificates in
    ``sub_specs`` are in turn issued by the resulting credentials.
    ``valid_from`` and ``valid_to`` are offsets from the time of issuing.
    """

    def __init__(self, name: Optional[str] = None,
                 domains: Optional[List[str]] = None,
                 email: Optional[str] = None,
                 key_type: Optional[str] = None,
                 single_file: bool = False,
                 valid_from: timedelta = timedelta(days=-1),
                 valid_to: timedelta = timedelta(days=89),
                 client: bool = False,
                 sub_specs: Optional[List['CertificateSpec']] = None):
        self._name = name
        self.domains = domains
        self.client = client
        self.email = email
        self.key_type = key_type
        self.single_file = single_file
        self.valid_from = valid_from
        self.valid_to = valid_to
        self.sub_specs = sub_specs

    def __repr__(self) -> str:
        # Shown e.g. in "unrecognized certificate specification" errors, so
        # include every field that decides what gets issued.
        return (f"{self.__class__.__name__}(name={self._name!r}, "
                f"domains={self.domains!r}, email={self.email!r}, "
                f"key_type={self.key_type!r}, client={self.client!r}, "
                f"single_file={self.single_file!r}, "
                f"valid_from={self.valid_from!r}, valid_to={self.valid_to!r}, "
                f"sub_specs={self.sub_specs!r})")

    @property
    def name(self) -> Optional[str]:
        """The explicit name if set, else the first domain, else ``None``."""
        if self._name:
            return self._name
        if self.domains:
            return self.domains[0]
        return None

    @property
    def type(self) -> Optional[str]:
        """``"server"``, ``"client"``, ``"ca"``, or ``None`` if nothing identifies the spec."""
        if self.domains:
            return "server"
        if self.client:
            return "client"
        if self.name:
            return "ca"
        return None
110
111
class Credentials:
    """A certificate, its private key and, optionally, the credentials that issued it.

    Credentials can be attached to a ``CertStore`` (``set_store``).
    ``issue_cert`` uses that store to reuse previously saved credentials and
    to persist new ones.  The credentials also remember the files they were
    written to (``set_files``).
    """

    def __init__(self,
                 name: str,
                 cert: Any,
                 pkey: Any,
                 issuer: Optional['Credentials'] = None):
        self._name = name
        self._cert = cert
        self._pkey = pkey
        self._issuer = issuer
        # All file locations stay None until set_files() is called.
        # Initialising _combined_file too keeps `combined_file` from raising
        # AttributeError on credentials that were never saved or loaded.
        self._cert_file = None
        self._pkey_file = None
        self._combined_file = None
        self._store = None

    @property
    def name(self) -> str:
        """The name these credentials were created or loaded under."""
        return self._name

    @property
    def subject(self) -> x509.Name:
        """Subject name of the certificate."""
        return self._cert.subject

    @property
    def key_type(self) -> str:
        """``"rsa<bits>"`` for RSA keys, the curve name for EC keys.

        ``issue_cert`` uses this as the default key type for certificates
        issued by these credentials.

        :raises Exception: if the private key is neither RSA nor EC.
        """
        if isinstance(self._pkey, RSAPrivateKey):
            return f"rsa{self._pkey.key_size}"
        if isinstance(self._pkey, EllipticCurvePrivateKey):
            return f"{self._pkey.curve.name}"
        raise Exception(f"unknown key type: {self._pkey}")

    @property
    def private_key(self) -> Any:
        return self._pkey

    @property
    def certificate(self) -> Any:
        return self._cert

    @property
    def cert_pem(self) -> bytes:
        """The certificate, PEM-encoded."""
        return self._cert.public_bytes(Encoding.PEM)

    @property
    def pkey_pem(self) -> bytes:
        """The unencrypted private key, PEM-encoded.

        RSA keys use ``PrivateFormat.TraditionalOpenSSL``; all other keys use
        ``PrivateFormat.PKCS8``.
        """
        fmt = (PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa')
               else PrivateFormat.PKCS8)
        return self._pkey.private_bytes(Encoding.PEM, fmt, NoEncryption())

    @property
    def issuer(self) -> Optional['Credentials']:
        """The issuing credentials, or ``None`` for a root."""
        return self._issuer

    def set_store(self, store: 'CertStore'):
        self._store = store

    def set_files(self, cert_file: str, pkey_file: Optional[str] = None,
                  combined_file: Optional[str] = None):
        self._cert_file = cert_file
        self._pkey_file = pkey_file
        self._combined_file = combined_file

    @property
    def cert_file(self) -> Optional[str]:
        return self._cert_file

    @property
    def pkey_file(self) -> Optional[str]:
        return self._pkey_file

    @property
    def combined_file(self) -> Optional[str]:
        return self._combined_file

    def get_first(self, name) -> Optional['Credentials']:
        """First credentials indexed under *name* in our store, or ``None``."""
        creds = self.get_credentials_for_name(name)
        return creds[0] if creds else None

    def get_credentials_for_name(self, name) -> List['Credentials']:
        """All credentials indexed under *name* in our store (empty without a store)."""
        return self._store.get_credentials_for_name(name) if self._store else []

    def issue_certs(self, specs: List[CertificateSpec],
                    chain: Optional[List['Credentials']] = None) -> List['Credentials']:
        """Issue one set of credentials per spec; see ``issue_cert``."""
        return [self.issue_cert(spec=spec, chain=chain) for spec in specs]

    def issue_cert(self, spec: CertificateSpec,
                   chain: Optional[List['Credentials']] = None) -> 'Credentials':
        """Issue, or reuse from the store, the credentials described by *spec*.

        With a store, credentials previously saved under ``spec.name`` and the
        key type are loaded and reused.  Otherwise new credentials are signed
        by these ones and, if there is a store, saved to it.  New CA specs
        also get a ``<name>-ca.pem`` chain file that includes the root.

        Any ``spec.sub_specs`` are then issued by the resulting credentials.
        When there is a store, they go into a sub-store named after those
        credentials.

        :param spec: what to issue.  ``spec.key_type`` defaults to our key type.
        :param chain: issuers above these credentials, passed on to sub-specs.
        """
        key_type = spec.key_type or self.key_type
        creds = None
        if self._store:
            creds = self._store.load_credentials(
                name=spec.name, key_type=key_type, single_file=spec.single_file, issuer=self)
        if creds is None:
            creds = TestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
                                              valid_from=spec.valid_from, valid_to=spec.valid_to)
            if self._store:
                # NOTE(review): CertStore.load_credentials() attaches its store
                # to loaded credentials, but save() does not, so only reused
                # credentials get a store here.  Confirm this asymmetry is
                # intended.
                self._store.save(creds, single_file=spec.single_file)
                if spec.type == "ca":
                    self._store.save_chain(creds, "ca", with_root=True)

        if spec.sub_specs:
            if self._store:
                sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
                creds.set_store(sub_store)
            subchain = [*(chain or []), self]
            creds.issue_certs(spec.sub_specs, chain=subchain)
        return creds
222
223
class CertStore:
    """A directory of PEM files holding credentials, plus an in-memory index by name.

    File names are:
    - ``<name>[.<key_type>].cert.pem`` for the certificate file;
    - ``<name>[.<key_type>].pkey.pem`` for the private key file;
    - ``<name>.pem`` for the combined file.
    """

    def __init__(self, fpath: str):
        self._store_dir = fpath
        # exist_ok=True instead of check-then-create: there is no window in
        # which another process can create the directory first and make
        # makedirs() fail.
        os.makedirs(self._store_dir, exist_ok=True)
        self._creds_by_name = {}

    @property
    def path(self) -> str:
        """Directory holding this store's files."""
        return self._store_dir

    def save(self, creds: Credentials, name: Optional[str] = None,
             chain: Optional[List[Credentials]] = None,
             single_file: bool = False) -> None:
        """Write *creds* to the store and index them under *name*.

        Three files are written:
        - the certificate file: the certificate, then the certificates in
          *chain*, followed by the private key when *single_file* is set;
        - the private key file, which is skipped when *single_file* is set;
        - the combined file: the certificate, the *chain* certificates and
          the private key.

        :param name: file and index name; defaults to ``creds.name``.
        """
        name = name if name is not None else creds.name
        key_type = creds.key_type
        cert_file = self.get_cert_file(name=name, key_type=key_type)
        pkey_file = None if single_file else self.get_pkey_file(name=name, key_type=key_type)
        comb_file = self.get_combined_file(name=name, key_type=key_type)
        certs_pem = creds.cert_pem + b"".join(c.cert_pem for c in (chain or []))
        pkey_pem = creds.pkey_pem
        with open(cert_file, "wb") as fd:
            fd.write(certs_pem + pkey_pem if pkey_file is None else certs_pem)
        if pkey_file is not None:
            with open(pkey_file, "wb") as fd:
                fd.write(pkey_pem)
        with open(comb_file, "wb") as fd:
            fd.write(certs_pem + pkey_pem)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)

    def save_chain(self, creds: Credentials, infix: str, with_root=False):
        """Write the certificates from *creds* up through its issuers to ``<name>-<infix>.pem``.

        Certificates are written starting with *creds*.  Unless *with_root*
        is set, the top-most certificate (the one without an issuer) is left
        out, provided there is more than one certificate.
        """
        chain = [creds]
        while chain[-1].issuer is not None:
            chain.append(chain[-1].issuer)
        if not with_root and len(chain) > 1:
            chain.pop()
        chain_file = os.path.join(self._store_dir, f'{creds.name}-{infix}.pem')
        with open(chain_file, "wb") as fd:
            fd.write(b"".join(c.cert_pem for c in chain))

    def _add_credentials(self, name: str, creds: Credentials):
        self._creds_by_name.setdefault(name, []).append(creds)

    def get_credentials_for_name(self, name) -> List[Credentials]:
        """Credentials indexed under *name* (the internal list), or a new empty list."""
        return self._creds_by_name.get(name, [])

    def get_cert_file(self, name: str, key_type=None) -> str:
        key_infix = f".{key_type}" if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem')

    def get_pkey_file(self, name: str, key_type=None) -> str:
        key_infix = f".{key_type}" if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem')

    def get_combined_file(self, name: str, key_type=None) -> str:
        # NOTE(review): key_type is ignored, so credentials with the same name
        # but different key types share, and overwrite, one combined file.
        # Confirm whether the file name should include the key type.
        return os.path.join(self._store_dir, f'{name}.pem')

    def load_pem_cert(self, fpath: str) -> x509.Certificate:
        """Load the first PEM certificate in *fpath*."""
        # PEM is ASCII; read raw bytes instead of decoding and re-encoding.
        with open(fpath, "rb") as fd:
            return x509.load_pem_x509_certificate(fd.read())

    def load_pem_pkey(self, fpath: str):
        """Load an unencrypted PEM private key from *fpath*."""
        with open(fpath, "rb") as fd:
            return load_pem_private_key(fd.read(), password=None)

    def load_credentials(self, name: str, key_type=None,
                         single_file: bool = False,
                         issuer: Optional[Credentials] = None) -> Optional[Credentials]:
        """Load credentials previously saved under *name* and *key_type*.

        With *single_file* set, the private key is read from the certificate
        file.  Loaded credentials are attached to this store and indexed
        under *name*.

        :returns: the credentials, or ``None`` if the certificate or key file is missing.
        """
        cert_file = self.get_cert_file(name=name, key_type=key_type)
        # NOTE(review): for single_file this records pkey_file == cert_file,
        # whereas save() records pkey_file as None.  Confirm no caller depends
        # on the difference.
        pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
        comb_file = self.get_combined_file(name=name, key_type=key_type)
        if not (os.path.isfile(cert_file) and os.path.isfile(pkey_file)):
            return None
        cert = self.load_pem_cert(cert_file)
        pkey = self.load_pem_pkey(pkey_file)
        creds = Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
        creds.set_store(self)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)
        return creds
319
320
class TestCA:
    """Builds the credentials of a test PKI: a root CA and the certificates it issues."""

    @classmethod
    def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
        """Return the root CA kept in *store_dir*, creating and saving it if absent.

        :param name: organization name in the subject of a newly created root.
        :param store_dir: directory of the ``CertStore`` holding the root's files.
        :param key_type: key type of the root, e.g. ``"rsa2048"``.
        """
        store = CertStore(fpath=store_dir)
        # The root's files are always stored under the name "ca".
        # NOTE(review): a root loaded from disk is therefore named "ca", while
        # a freshly created root is named *name*.  Confirm that callers don't
        # rely on the root's Credentials.name.
        creds = store.load_credentials(name="ca", key_type=key_type, issuer=None)
        if creds is None:
            creds = TestCA._make_ca_credentials(name=name, key_type=key_type)
            store.save(creds, name="ca")
            creds.set_store(store)
        return creds

    @staticmethod
    def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
                           valid_from: timedelta = timedelta(days=-1),
                           valid_to: timedelta = timedelta(days=89),
                           ) -> Credentials:
        """Create new credentials described by *spec* and signed by *issuer*.

        The kind of credentials created is chosen as follows:
        - a server certificate if ``spec.domains`` is non-empty;
        - otherwise a client certificate if ``spec.client`` is set;
        - otherwise a CA certificate if the spec has a name.

        :param key_type: key type of the new private key (see ``_private_key``).
        :param valid_from: start of validity, as an offset from now.
        :param valid_to: end of validity, as an offset from now.
        :returns: the new credentials.  They are not saved anywhere.
        :raises Exception: if *spec* matches none of these kinds.
        """
        if spec.domains:
            creds = TestCA._make_server_credentials(name=spec.name, domains=spec.domains,
                                                    issuer=issuer, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.client:
            creds = TestCA._make_client_credentials(name=spec.name, issuer=issuer,
                                                    email=spec.email, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.name:
            creds = TestCA._make_ca_credentials(name=spec.name, issuer=issuer,
                                                valid_from=valid_from, valid_to=valid_to,
                                                key_type=key_type)
        else:
            raise Exception(f"unrecognized certificate specification: {spec}")
        return creds

    @staticmethod
    def _make_x509_name(org_name: Optional[str] = None, common_name: Optional[str] = None,
                        parent: Optional[x509.Name] = None) -> x509.Name:
        """Build a name from one new attribute followed by all of *parent*'s attributes.

        *org_name* takes precedence over *common_name*.  Under a *parent* it
        becomes an organizational unit; otherwise it becomes an organization.
        """
        name_pieces = []
        if org_name:
            oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
            name_pieces.append(x509.NameAttribute(oid, org_name))
        elif common_name:
            name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
        if parent:
            name_pieces.extend(parent)
        return x509.Name(name_pieces)

    @staticmethod
    def _make_csr(
            subject: x509.Name,
            pkey: Any,
            issuer_subject: Optional[x509.Name],
            valid_from_delta: Optional[timedelta] = None,
            valid_until_delta: Optional[timedelta] = None
    ):
        """Start a ``CertificateBuilder`` for *pkey*'s public key.

        :param issuer_subject: issuer name; ``None`` means self-issued
            (the issuer is *subject*).
        :param valid_from_delta: offset from now for not-valid-before;
            ``None`` means now.
        :param valid_until_delta: offset from now for not-valid-after;
            ``None`` means now.
        :returns: the builder with names, key, validity, a random serial and
            a subject key identifier set.
        """
        pubkey = pkey.public_key()
        issuer_subject = issuer_subject if issuer_subject is not None else subject

        # cryptography interprets naive datetimes as UTC.  A naive *local* time
        # would shift the validity window by the UTC offset, so use an aware
        # UTC timestamp.  One "now" keeps both bounds relative to the same
        # instant.
        now = datetime.now(timezone.utc)
        valid_from = now + valid_from_delta if valid_from_delta is not None else now
        valid_until = now + valid_until_delta if valid_until_delta is not None else now

        return (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer_subject)
            .public_key(pubkey)
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            .serial_number(x509.random_serial_number())
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(pubkey),
                critical=False,
            )
        )

    @staticmethod
    def _add_ca_usages(csr: Any) -> Any:
        """Add CA basic constraints, key usages and extended key usages to the builder."""
        return csr.add_extension(
            x509.BasicConstraints(ca=True, path_length=9),
            critical=True,
        ).add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False),
            critical=True
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
                ExtendedKeyUsageOID.SERVER_AUTH,
                ExtendedKeyUsageOID.CODE_SIGNING,
            ]),
            critical=True
        )

    @staticmethod
    def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
        """Add server-certificate extensions to the builder.

        Every domain becomes a DNS subject alternative name.  Domains that
        parse as IP addresses are additionally listed as iPAddress entries,
        so clients connecting to an IP literal can match it.
        """
        alt_names = []
        for domain in domains:
            alt_names.append(x509.DNSName(domain))
            try:
                ip = ipaddress.ip_address(domain)
            except ValueError:
                continue
            alt_names.append(x509.IPAddress(ip))
        return csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        ).add_extension(
            x509.SubjectAlternativeName(alt_names),
            critical=True,
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.SERVER_AUTH,
            ]),
            critical=False
        )

    @staticmethod
    def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: Optional[str] = None) -> Any:
        """Add client-certificate extensions to the builder.

        These are an optional e-mail subject alternative name and the
        clientAuth extended key usage.
        """
        # CertificateBuilder is immutable: add_extension() returns a new
        # builder, so each result must be kept or the extension is lost.
        cert = csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        )
        if rfc82name:
            cert = cert.add_extension(
                x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
                critical=True,
            )
        return cert.add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
            ]),
            critical=True
        )

    @staticmethod
    def _make_ca_credentials(name, key_type: Any,
                             issuer: Optional[Credentials] = None,
                             valid_from: timedelta = timedelta(days=-1),
                             valid_to: timedelta = timedelta(days=89),
                             ) -> Credentials:
        """Create CA credentials signed by *issuer*, or self-signed when *issuer* is ``None``."""
        pkey = _private_key(key_type=key_type)
        if issuer is not None:
            issuer_subject = issuer.certificate.subject
            issuer_key = issuer.private_key
            parent = issuer.subject
        else:
            # Root CA: self-issued and signed with its own key.
            issuer_subject = None
            issuer_key = pkey
            parent = None
        subject = TestCA._make_x509_name(org_name=name, parent=parent)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer_subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_ca_usages(csr)
        cert = csr.sign(private_key=issuer_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create server credentials for *domains*, signed by *issuer*."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_client_credentials(name: str,
                                 issuer: Credentials, email: Optional[str],
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create client credentials, with an optional e-mail SAN, signed by *issuer*."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
529