• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import base64
8import struct
9
10import six
11
12from cryptography import utils
13from cryptography.exceptions import UnsupportedAlgorithm
14from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa
15
16
17def load_ssh_public_key(data, backend):
18    key_parts = data.split(b' ', 2)
19
20    if len(key_parts) < 2:
21        raise ValueError(
22            'Key is not in the proper format or contains extra data.')
23
24    key_type = key_parts[0]
25
26    if key_type == b'ssh-rsa':
27        loader = _load_ssh_rsa_public_key
28    elif key_type == b'ssh-dss':
29        loader = _load_ssh_dss_public_key
30    elif key_type in [
31        b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521',
32    ]:
33        loader = _load_ssh_ecdsa_public_key
34    else:
35        raise UnsupportedAlgorithm('Key type is not supported.')
36
37    key_body = key_parts[1]
38
39    try:
40        decoded_data = base64.b64decode(key_body)
41    except TypeError:
42        raise ValueError('Key is not in the proper format.')
43
44    inner_key_type, rest = _ssh_read_next_string(decoded_data)
45
46    if inner_key_type != key_type:
47        raise ValueError(
48            'Key header and key body contain different key type values.'
49        )
50
51    return loader(key_type, rest, backend)
52
53
def _load_ssh_rsa_public_key(key_type, decoded_data, backend):
    """
    Build an RSA public key from an ``ssh-rsa`` key body.

    ``decoded_data`` is the blob left after the leading key-type string has
    been consumed: the public exponent followed by the modulus, each an
    mpint. ``key_type`` is part of the common loader signature and unused.
    """
    exponent, remainder = _ssh_read_next_mpint(decoded_data)
    modulus, remainder = _ssh_read_next_mpint(remainder)

    # The two integers must account for the entire body.
    if remainder:
        raise ValueError('Key body contains extra bytes.')

    numbers = rsa.RSAPublicNumbers(exponent, modulus)
    return numbers.public_key(backend)
62
63
def _load_ssh_dss_public_key(key_type, decoded_data, backend):
    """
    Build a DSA public key from an ``ssh-dss`` key body.

    ``decoded_data`` holds four consecutive mpints, in order p, q, g and y.
    ``key_type`` is part of the common loader signature and unused.
    """
    remainder = decoded_data
    values = []
    for _ in range(4):
        value, remainder = _ssh_read_next_mpint(remainder)
        values.append(value)
    p, q, g, y = values

    # The four integers must account for the entire body.
    if remainder:
        raise ValueError('Key body contains extra bytes.')

    return dsa.DSAPublicNumbers(
        y, dsa.DSAParameterNumbers(p, q, g)
    ).public_key(backend)
77
78
def _load_ssh_ecdsa_public_key(expected_key_type, decoded_data, backend):
    """
    Build an EC public key from an ``ecdsa-sha2-*`` key body.

    ``decoded_data`` holds two strings: the curve identifier (e.g.
    ``nistp256``) and the encoded curve point. ``backend`` is part of the
    common loader signature and is not used here.

    :raises ValueError: If the curve identifier does not match
        ``expected_key_type``, the body has trailing bytes, or the point
        string is empty.
    :raises NotImplementedError: If the point does not start with the
        uncompressed-point marker byte ``0x04``.
    """
    curve_name, rest = _ssh_read_next_string(decoded_data)
    data, rest = _ssh_read_next_string(rest)

    if expected_key_type != b"ecdsa-sha2-" + curve_name:
        raise ValueError(
            'Key header and key body contain different key type values.'
        )

    if rest:
        raise ValueError('Key body contains extra bytes.')

    curve = {
        b"nistp256": ec.SECP256R1,
        b"nistp384": ec.SECP384R1,
        b"nistp521": ec.SECP521R1,
    }[curve_name]()

    # Without this check an empty point string makes indexbytes() raise
    # IndexError; report it like every other malformed key instead.
    if not data:
        raise ValueError("Key is not in the proper format")

    if six.indexbytes(data, 0) != 4:
        raise NotImplementedError(
            "Compressed elliptic curve points are not supported"
        )

    return ec.EllipticCurvePublicKey.from_encoded_point(curve, data)
103
104
105def _ssh_read_next_string(data):
106    """
107    Retrieves the next RFC 4251 string value from the data.
108
109    While the RFC calls these strings, in Python they are bytes objects.
110    """
111    if len(data) < 4:
112        raise ValueError("Key is not in the proper format")
113
114    str_len, = struct.unpack('>I', data[:4])
115    if len(data) < str_len + 4:
116        raise ValueError("Key is not in the proper format")
117
118    return data[4:4 + str_len], data[4 + str_len:]
119
120
def _ssh_read_next_mpint(data):
    """
    Split the leading RFC 4251 ``mpint`` off the front of ``data``.

    The mpint is read as a length-prefixed string whose payload is decoded
    as a big-endian integer. For now every value is treated as unsigned, so
    a set high bit never produces a negative result.

    Returns ``(integer, remainder)``.
    """
    raw, remainder = _ssh_read_next_string(data)
    value = utils.int_from_bytes(raw, byteorder='big', signed=False)
    return value, remainder
132
133
134def _ssh_write_string(data):
135    return struct.pack(">I", len(data)) + data
136
137
def _ssh_write_mpint(value):
    """
    Encode a non-negative integer as an RFC 4251 ``mpint``.

    The payload is the big-endian byte form from ``utils.int_to_bytes``.
    A ``0x00`` byte is prepended when its most significant bit is set, so
    the value is not read back as negative. Zero is encoded with an empty
    payload, as RFC 4251 section 5 requires.

    :raises ValueError: If ``value`` is negative; the readers in this
        module treat every mpint as unsigned.
    """
    if value < 0:
        raise ValueError("negative mpint not allowed")
    if value == 0:
        # RFC 4251: "The value zero MUST be stored as a string with zero
        # bytes of data."
        return _ssh_write_string(b"")

    data = utils.int_to_bytes(value)
    if six.indexbytes(data, 0) & 0x80:
        data = b"\x00" + data
    return _ssh_write_string(data)
143