"""Unit tests for the scrypt bindings: encrypt/decrypt and hash."""
import unittest as testm
from binascii import a2b_hex, b2a_hex
from csv import reader
from os import urandom
from os.path import abspath, dirname, sep
import scrypt
class TestScrypt(testm.TestCase):
    """Exercise argument handling of ``scrypt.encrypt`` and ``scrypt.decrypt``.

    Precomputed ciphertexts are read from ``ciphertexts.csv`` located next to
    this file. The tests skip row 0 (presumably a header row) and read
    column 0 as the expected plaintext, column 1 as the password and
    column 5 as the hex-encoded ciphertext.
    """

    def setUp(self):
        """Create the plaintext/password fixtures and load the CSV rows."""
        self.input = "message"
        self.password = "password"
        self.longinput = str(urandom(100000))
        self.five_minutes = 300.0
        self.five_seconds = 5.0
        self.one_byte = 1  # in Bytes
        self.one_megabyte = 1024 * 1024  # in Bytes
        self.ten_megabytes = 10 * self.one_megabyte
        base_dir = dirname(abspath(__file__)) + sep
        # ``with`` closes the file even when CSV parsing raises, which the
        # previous manual open()/close() pair did not guarantee.
        with open(base_dir + "ciphertexts.csv") as cvf:
            self.ciphertexts = list(reader(cvf, dialect="excel"))
        # Row 1 is the first data row; the decrypt tests expect it to
        # decrypt to ``self.input`` under ``self.password``.
        self.ciphertext = a2b_hex(self.ciphertexts[1][5].encode("ascii"))

    def test_encrypt_decrypt(self):
        """Test encrypt for simple encryption and decryption."""
        s = scrypt.encrypt(self.input, self.password, 0.1)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt(self):
        """Test encrypt takes input and password positionally and returns
        ciphertext that is 128 bytes longer than the input."""
        s = scrypt.encrypt(self.input, self.password)
        self.assertEqual(len(s), 128 + len(self.input))

    def test_encrypt_input_and_password_as_keywords(self):
        """Test encrypt for input and password accepted as keywords."""
        s = scrypt.encrypt(password=self.password, input=self.input)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_missing_input_keyword_argument(self):
        """Test encrypt raises TypeError if keyword argument missing input."""
        with self.assertRaises(TypeError):
            scrypt.encrypt(password=self.password)

    def test_encrypt_missing_password_positional_argument(self):
        """Test encrypt raises TypeError if second positional argument missing
        (password)"""
        with self.assertRaises(TypeError):
            scrypt.encrypt(self.input)

    def test_encrypt_missing_both_required_positional_arguments(self):
        """Test encrypt raises TypeError if both positional arguments missing
        (input and password)"""
        with self.assertRaises(TypeError):
            scrypt.encrypt()

    def test_encrypt_maxtime_positional(self):
        """Test encrypt maxtime accepts maxtime at position 3."""
        s = scrypt.encrypt(self.input, self.password, 0.01)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxtime_key(self):
        """Test encrypt maxtime accepts maxtime as keyword argument."""
        s = scrypt.encrypt(self.input, self.password, maxtime=0.01)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxmem_positional(self):
        """Test encrypt maxmem accepts 4th positional argument and exactly
        (1 megabyte) of storage to use for V array."""
        s = scrypt.encrypt(self.input, self.password, 0.01, self.one_megabyte)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxmem_undersized(self):
        """Test encrypt maxmem accepts (< 1 megabyte) of storage to use for V
        array."""
        s = scrypt.encrypt(self.input, self.password, 0.01, self.one_byte)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxmem_in_normal_range(self):
        """Test encrypt maxmem accepts (> 1 megabyte) of storage to use for V
        array."""
        s = scrypt.encrypt(self.input, self.password, 0.01, self.ten_megabytes)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxmem_keyword_argument(self):
        """Test encrypt maxmem accepts exactly (1 megabyte) of storage to use
        for V array when passed as a keyword."""
        s = scrypt.encrypt(
            self.input, self.password, maxmem=self.one_megabyte, maxtime=0.01
        )
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxmemfrac_positional(self):
        """Test encrypt maxmemfrac accepts 5th positional argument of 1/16
        total memory for V array."""
        s = scrypt.encrypt(self.input, self.password, 0.01, 0, 0.0625)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxmemfrac_keyword_argument(self):
        """Test encrypt maxmemfrac accepts keyword argument of 1/16 total
        memory for V array."""
        s = scrypt.encrypt(
            self.input, self.password, maxmemfrac=0.0625, maxtime=0.01
        )
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_long_input(self):
        """Test encrypt accepts long input for encryption."""
        s = scrypt.encrypt(self.longinput, self.password, 0.1)
        self.assertEqual(len(s), 128 + len(self.longinput))

    def test_encrypt_raises_error_on_invalid_keyword(self):
        """Test encrypt raises TypeError if invalid keyword used in argument."""
        with self.assertRaises(TypeError):
            scrypt.encrypt(self.input, self.password, nonsense="Raise error")

    def test_decrypt_from_csv_ciphertexts(self):
        """Test decrypt function with precalculated combinations."""
        for row in self.ciphertexts[1:]:
            h = scrypt.decrypt(a2b_hex(row[5].encode("ascii")), row[1])
            self.assertEqual(h.encode("ascii"), row[0].encode("ascii"))

    def test_decrypt_maxtime_positional(self):
        """Test decrypt function accepts maxtime as third positional argument."""
        m = scrypt.decrypt(self.ciphertext, self.password, self.five_seconds)
        self.assertEqual(m, self.input)

    def test_decrypt_maxtime_keyword_argument(self):
        """Test decrypt function accepts maxtime keyword argument."""
        m = scrypt.decrypt(
            maxtime=1.0, input=self.ciphertext, password=self.password
        )
        self.assertEqual(m, self.input)

    def test_decrypt_maxmem_positional(self):
        """Test decrypt function accepts maxmem as fourth positional argument."""
        m = scrypt.decrypt(
            self.ciphertext, self.password, self.five_minutes, self.ten_megabytes
        )
        self.assertEqual(m, self.input)

    def test_decrypt_maxmem_keyword_argument(self):
        """Test decrypt function accepts maxmem keyword argument."""
        m = scrypt.decrypt(
            maxmem=self.ten_megabytes, input=self.ciphertext, password=self.password
        )
        self.assertEqual(m, self.input)

    def test_decrypt_maxmemfrac_positional(self):
        """Test decrypt function accepts maxmemfrac as fifth positional
        argument."""
        m = scrypt.decrypt(
            self.ciphertext,
            self.password,
            self.five_minutes,
            self.one_megabyte,
            0.0625,
        )
        self.assertEqual(m, self.input)

    def test_decrypt_maxmemfrac_keyword_argument(self):
        """Test decrypt function accepts maxmemfrac keyword argument."""
        # NOTE(review): 0.625 differs from the 0.0625 (1/16) used by the other
        # maxmemfrac tests -- confirm whether this value is intentional.
        m = scrypt.decrypt(
            maxmemfrac=0.625, input=self.ciphertext, password=self.password
        )
        self.assertEqual(m, self.input)

    def test_decrypt_raises_error_on_too_little_time(self):
        """Test decrypt raises scrypt.error if insufficient time is allowed
        for ciphertext decryption."""
        s = scrypt.encrypt(self.input, self.password, 0.1)
        with self.assertRaises(scrypt.error):
            scrypt.decrypt(s, self.password, 0.01)
class TestScryptHash(testm.TestCase):
    """Exercise argument handling and parameter validation of ``scrypt.hash``.

    Precomputed vectors are read from ``hashvectors.csv`` located next to this
    file. The tests skip row 0 (presumably a header row) and pass columns
    0-4 to ``scrypt.hash`` as password, salt, N, r and p; column 5 holds the
    expected hex digest.
    """

    def setUp(self):
        """Create the input/salt fixtures and load the CSV hash vectors."""
        self.input = "message"
        self.password = "password"
        self.salt = "NaCl"
        base_dir = dirname(abspath(__file__)) + sep
        # ``with`` closes the file even when CSV parsing raises, which the
        # previous manual open()/close() pair did not guarantee.
        with open(base_dir + "hashvectors.csv") as hvf:
            self.hashes = list(reader(hvf, dialect="excel"))

    def test_hash_vectors_from_csv(self):
        """Test hash function with precalculated combinations."""
        for row in self.hashes[1:]:
            h = scrypt.hash(row[0], row[1], int(row[2]), int(row[3]), int(row[4]))
            hhex = b2a_hex(h)
            self.assertEqual(hhex, row[5].encode("utf-8"))

    def test_hash_buflen_keyword(self):
        """Test hash takes keyword valid buflen."""
        h64 = scrypt.hash(self.input, self.salt, buflen=64)
        h128 = scrypt.hash(self.input, self.salt, buflen=128)
        self.assertEqual(len(h64), 64)
        self.assertEqual(len(h128), 128)

    def test_hash_n_positional(self):
        """Test hash accepts valid N in position 3."""
        h = scrypt.hash(self.input, self.salt, 256)
        self.assertEqual(len(h), 64)

    def test_hash_n_keyword(self):
        """Test hash takes keyword valid N."""
        h = scrypt.hash(N=256, password=self.input, salt=self.salt)
        self.assertEqual(len(h), 64)

    def test_hash_r_positional(self):
        """Test hash accepts valid r in position 4."""
        h = scrypt.hash(self.input, self.salt, 256, 16)
        self.assertEqual(len(h), 64)

    def test_hash_r_keyword(self):
        """Test hash takes keyword valid r."""
        h = scrypt.hash(r=16, password=self.input, salt=self.salt)
        self.assertEqual(len(h), 64)

    def test_hash_p_positional(self):
        """Test hash accepts valid p in position 5."""
        h = scrypt.hash(self.input, self.salt, 256, 8, 2)
        self.assertEqual(len(h), 64)

    def test_hash_p_keyword(self):
        """Test hash takes keyword valid p."""
        h = scrypt.hash(p=4, password=self.input, salt=self.salt)
        self.assertEqual(len(h), 64)

    def test_hash_raises_error_on_p_equals_zero(self):
        """Test hash raises scrypt error on illegal parameter value (p = 0)"""
        with self.assertRaises(scrypt.error):
            scrypt.hash(self.input, self.salt, p=0)

    def test_hash_raises_error_on_negative_p(self):
        """Test hash raises scrypt error on illegal parameter value (p < 0)"""
        with self.assertRaises(scrypt.error):
            scrypt.hash(self.input, self.salt, p=-1)

    def test_hash_raises_error_on_r_equals_zero(self):
        """Test hash raises scrypt error on illegal parameter value (r = 0)"""
        with self.assertRaises(scrypt.error):
            scrypt.hash(self.input, self.salt, r=0)

    def test_hash_raises_error_on_negative_r(self):
        """Test hash raises scrypt error on illegal parameter value (r < 0)"""
        with self.assertRaises(scrypt.error):
            scrypt.hash(self.input, self.salt, r=-1)

    def test_hash_raises_error_r_p_over_limit(self):
        """Test hash raises scrypt error when r multiplied by p reaches the
        2**30 limit (here r * p == 2 * 2**29 == 2**30)."""
        with self.assertRaises(scrypt.error):
            scrypt.hash(self.input, self.salt, r=2, p=2**29)

    def test_hash_raises_error_n_not_power_of_two(self):
        """Test hash raises scrypt error when parameter N is not a power of
        two {2, 4, 8, 16, etc}"""
        with self.assertRaises(scrypt.error):
            scrypt.hash(self.input, self.salt, N=3)

    def test_hash_raises_error_n_under_limit(self):
        """Test hash raises scrypt error when parameter N is 1 or negative."""
        with self.assertRaises(scrypt.error):
            scrypt.hash(self.input, self.salt, N=1)
        with self.assertRaises(scrypt.error):
            scrypt.hash(self.input, self.salt, N=-1)

    def test_hash_valid_input(self):
        """Test hash accepts bytes input with N=2, r=8, p=1 and returns 64
        bytes."""
        # these parameters must be valid
        h = scrypt.hash(b"password", salt=b"salt", N=2, r=8, p=1)
        self.assertEqual(len(h), 64)
# Run the test cases in this module when the file is executed as a script;
# importing the module does not trigger a test run.
if __name__ == "__main__":
    testm.main()
|