• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import binascii
8import os
9
10import pytest
11
12from cryptography.exceptions import (
13    AlreadyFinalized,
14    InvalidKey,
15    UnsupportedAlgorithm,
16)
17from cryptography.hazmat.backends.interfaces import ScryptBackend
18from cryptography.hazmat.primitives.kdf.scrypt import Scrypt, _MEM_LIMIT
19
20from tests.utils import load_nist_vectors, load_vectors_from_file
21
# Scrypt known-answer vectors, parsed with load_nist_vectors; they feed the
# parametrized test_derive / test_verify cases in TestScrypt.
_vector_path = os.path.join("KDF", "scrypt.txt")
vectors = load_vectors_from_file(_vector_path, load_nist_vectors)
25
26
27def _skip_if_memory_limited(memory_limit, params):
28    # Memory calc adapted from OpenSSL (URL split over 2 lines, thanks PEP8)
29    # https://github.com/openssl/openssl/blob/6286757141a8c6e14d647ec733634a
30    # e0c83d9887/crypto/evp/scrypt.c#L189-L221
31    blen = int(params["p"]) * 128 * int(params["r"])
32    vlen = 32 * int(params["r"]) * (int(params["n"]) + 2) * 4
33    memory_required = blen + vlen
34    if memory_limit < memory_required:
35        pytest.skip(
36            "Test exceeds Scrypt memory limit. "
37            "This is likely a 32-bit platform."
38        )
39
40
def test_memory_limit_skip():
    """_skip_if_memory_limited skips only when the limit is too small."""
    params = {"p": 16, "r": 64, "n": 1024}

    # A 1000-byte budget is far below what these parameters require.
    with pytest.raises(pytest.skip.Exception):
        _skip_if_memory_limited(1000, params)

    # 2 GiB covers them comfortably, so this call must not skip.
    _skip_if_memory_limited(2 ** 31, params)
46
47
@pytest.mark.requires_backend_interface(interface=ScryptBackend)
class TestScrypt(object):
    """Tests for the Scrypt KDF.

    The class marker ties these tests to the ``ScryptBackend`` interface
    (presumably skipping backends that do not provide it).
    """

    @pytest.mark.parametrize("params", vectors)
    def test_derive(self, backend, params):
        """derive() must reproduce each known-answer vector."""
        _skip_if_memory_limited(_MEM_LIMIT, params)
        password = params["password"]
        work_factor = int(params["n"])
        block_size = int(params["r"])
        parallelization_factor = int(params["p"])
        length = int(params["length"])
        salt = params["salt"]
        derived_key = params["derived_key"]

        scrypt = Scrypt(
            salt,
            length,
            work_factor,
            block_size,
            parallelization_factor,
            backend,
        )
        # The vector stores the expected key hex-encoded.
        assert binascii.hexlify(scrypt.derive(password)) == derived_key

    def test_unsupported_backend(self):
        """A backend lacking Scrypt support raises UnsupportedAlgorithm."""
        work_factor = 1024
        block_size = 8
        parallelization_factor = 16
        length = 64
        salt = b"NaCl"
        backend = object()

        with pytest.raises(UnsupportedAlgorithm):
            Scrypt(
                salt,
                length,
                work_factor,
                block_size,
                parallelization_factor,
                backend,
            )

    def test_salt_not_bytes(self, backend):
        """A non-bytes salt is rejected with TypeError at construction."""
        work_factor = 1024
        block_size = 8
        parallelization_factor = 16
        length = 64
        salt = 1

        with pytest.raises(TypeError):
            Scrypt(
                salt,
                length,
                work_factor,
                block_size,
                parallelization_factor,
                backend,
            )

    def test_scrypt_malloc_failure(self, backend):
        """An impossible allocation surfaces from derive() as MemoryError."""
        password = b"NaCl"
        # Deliberately enormous N and r so the working memory cannot be
        # allocated.
        work_factor = 1024 ** 3
        block_size = 589824
        parallelization_factor = 16
        length = 64
        salt = b"NaCl"

        scrypt = Scrypt(
            salt,
            length,
            work_factor,
            block_size,
            parallelization_factor,
            backend,
        )

        with pytest.raises(MemoryError):
            scrypt.derive(password)

    def test_password_not_bytes(self, backend):
        """derive() rejects a non-bytes password with TypeError."""
        password = 1
        work_factor = 1024
        block_size = 8
        parallelization_factor = 16
        length = 64
        salt = b"NaCl"

        scrypt = Scrypt(
            salt,
            length,
            work_factor,
            block_size,
            parallelization_factor,
            backend,
        )

        with pytest.raises(TypeError):
            scrypt.derive(password)

    def test_buffer_protocol(self, backend):
        """derive() accepts a bytearray password, not only bytes."""
        password = bytearray(b"password")
        work_factor = 256
        block_size = 8
        parallelization_factor = 16
        length = 10
        salt = b"NaCl"

        scrypt = Scrypt(
            salt,
            length,
            work_factor,
            block_size,
            parallelization_factor,
            backend,
        )

        assert scrypt.derive(password) == b"\xf4\x92\x86\xb2\x06\x0c\x848W\x87"

    @pytest.mark.parametrize("params", vectors)
    def test_verify(self, backend, params):
        """verify() accepts each known-answer key and returns None."""
        _skip_if_memory_limited(_MEM_LIMIT, params)
        password = params["password"]
        work_factor = int(params["n"])
        block_size = int(params["r"])
        parallelization_factor = int(params["p"])
        length = int(params["length"])
        salt = params["salt"]
        derived_key = params["derived_key"]

        scrypt = Scrypt(
            salt,
            length,
            work_factor,
            block_size,
            parallelization_factor,
            backend,
        )
        assert scrypt.verify(password, binascii.unhexlify(derived_key)) is None

    def test_invalid_verify(self, backend):
        """verify() raises InvalidKey when the expected key is wrong.

        Two cases are covered:

        * ``derived_key`` is only 26 bytes while ``length`` is 64, so it
          exercises a length mismatch.
        * A full-length key with one bit flipped exercises a content
          mismatch, which the short key alone never reaches.
        """
        password = b"password"
        work_factor = 1024
        block_size = 8
        parallelization_factor = 16
        length = 64
        salt = b"NaCl"
        derived_key = b"fdbabe1c9d3472007856e7190d01e9fe7c6ad7cbc8237830e773"

        def new_scrypt():
            # An instance may derive only once (see test_already_finalized),
            # so every derive/verify below uses a fresh one.
            return Scrypt(
                salt,
                length,
                work_factor,
                block_size,
                parallelization_factor,
                backend,
            )

        # Wrong length.
        with pytest.raises(InvalidKey):
            new_scrypt().verify(password, binascii.unhexlify(derived_key))

        # Right length, wrong contents: flip the low bit of the final byte
        # of the genuine output. bytearray keeps this Python 2/3 compatible.
        tampered = bytearray(new_scrypt().derive(password))
        tampered[-1] ^= 0x01
        with pytest.raises(InvalidKey):
            new_scrypt().verify(password, bytes(tampered))

    def test_already_finalized(self, backend):
        """A second derive() on the same instance raises AlreadyFinalized."""
        password = b"password"
        work_factor = 1024
        block_size = 8
        parallelization_factor = 16
        length = 64
        salt = b"NaCl"

        scrypt = Scrypt(
            salt,
            length,
            work_factor,
            block_size,
            parallelization_factor,
            backend,
        )
        scrypt.derive(password)
        with pytest.raises(AlreadyFinalized):
            scrypt.derive(password)

    def test_invalid_n(self, backend):
        """n below 2 or not a power of two raises ValueError."""
        # n is less than 2
        with pytest.raises(ValueError):
            Scrypt(b"NaCl", 64, 1, 8, 16, backend)

        # n is not a power of 2
        with pytest.raises(ValueError):
            Scrypt(b"NaCl", 64, 3, 8, 16, backend)

    def test_invalid_r(self, backend):
        """A block size (r) of 0 raises ValueError."""
        with pytest.raises(ValueError):
            Scrypt(b"NaCl", 64, 2, 0, 16, backend)

    def test_invalid_p(self, backend):
        """A parallelization factor (p) of 0 raises ValueError."""
        with pytest.raises(ValueError):
            Scrypt(b"NaCl", 64, 2, 8, 0, backend)
243