1# This file is dual licensed under the terms of the Apache License, Version 2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 3# for complete details. 4 5from __future__ import absolute_import, division, print_function 6 7import base64 8import struct 9 10import six 11 12from cryptography import utils 13from cryptography.exceptions import UnsupportedAlgorithm 14from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa 15 16 17def load_ssh_public_key(data, backend): 18 key_parts = data.split(b' ', 2) 19 20 if len(key_parts) < 2: 21 raise ValueError( 22 'Key is not in the proper format or contains extra data.') 23 24 key_type = key_parts[0] 25 26 if key_type == b'ssh-rsa': 27 loader = _load_ssh_rsa_public_key 28 elif key_type == b'ssh-dss': 29 loader = _load_ssh_dss_public_key 30 elif key_type in [ 31 b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521', 32 ]: 33 loader = _load_ssh_ecdsa_public_key 34 else: 35 raise UnsupportedAlgorithm('Key type is not supported.') 36 37 key_body = key_parts[1] 38 39 try: 40 decoded_data = base64.b64decode(key_body) 41 except TypeError: 42 raise ValueError('Key is not in the proper format.') 43 44 inner_key_type, rest = _ssh_read_next_string(decoded_data) 45 46 if inner_key_type != key_type: 47 raise ValueError( 48 'Key header and key body contain different key type values.' 49 ) 50 51 return loader(key_type, rest, backend) 52 53 54def _load_ssh_rsa_public_key(key_type, decoded_data, backend): 55 e, rest = _ssh_read_next_mpint(decoded_data) 56 n, rest = _ssh_read_next_mpint(rest) 57 58 if rest: 59 raise ValueError('Key body contains extra bytes.') 60 61 return rsa.RSAPublicNumbers(e, n).public_key(backend) 62 63 64def _load_ssh_dss_public_key(key_type, decoded_data, backend): 65 p, rest = _ssh_read_next_mpint(decoded_data) 66 q, rest = _ssh_read_next_mpint(rest) 67 g, rest = _ssh_read_next_mpint(rest) 68 y, rest = _ssh_read_next_mpint(rest) 69 70 if rest: 71 raise ValueError('Key body contains extra bytes.') 72 73 parameter_numbers = dsa.DSAParameterNumbers(p, q, g) 74 public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers) 75 76 return public_numbers.public_key(backend) 77 78 79def _load_ssh_ecdsa_public_key(expected_key_type, decoded_data, backend): 80 curve_name, rest = _ssh_read_next_string(decoded_data) 81 data, rest = _ssh_read_next_string(rest) 82 83 if expected_key_type != b"ecdsa-sha2-" + curve_name: 84 raise ValueError( 85 'Key header and key body contain different key type values.' 86 ) 87 88 if rest: 89 raise ValueError('Key body contains extra bytes.') 90 91 curve = { 92 b"nistp256": ec.SECP256R1, 93 b"nistp384": ec.SECP384R1, 94 b"nistp521": ec.SECP521R1, 95 }[curve_name]() 96 97 if six.indexbytes(data, 0) != 4: 98 raise NotImplementedError( 99 "Compressed elliptic curve points are not supported" 100 ) 101 102 return ec.EllipticCurvePublicKey.from_encoded_point(curve, data) 103 104 105def _ssh_read_next_string(data): 106 """ 107 Retrieves the next RFC 4251 string value from the data. 108 109 While the RFC calls these strings, in Python they are bytes objects. 110 """ 111 if len(data) < 4: 112 raise ValueError("Key is not in the proper format") 113 114 str_len, = struct.unpack('>I', data[:4]) 115 if len(data) < str_len + 4: 116 raise ValueError("Key is not in the proper format") 117 118 return data[4:4 + str_len], data[4 + str_len:] 119 120 121def _ssh_read_next_mpint(data): 122 """ 123 Reads the next mpint from the data. 124 125 Currently, all mpints are interpreted as unsigned. 126 """ 127 mpint_data, rest = _ssh_read_next_string(data) 128 129 return ( 130 utils.int_from_bytes(mpint_data, byteorder='big', signed=False), rest 131 ) 132 133 134def _ssh_write_string(data): 135 return struct.pack(">I", len(data)) + data 136 137 138def _ssh_write_mpint(value): 139 data = utils.int_to_bytes(value) 140 if six.indexbytes(data, 0) & 0x80: 141 data = b"\x00" + data 142 return _ssh_write_string(data) 143