1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
import shutil
import sys
import unittest
from io import BytesIO
from pathlib import Path
from secrets import token_bytes, token_urlsafe
from tempfile import TemporaryDirectory
# Select the zstd implementation: on Python 3.14+ it is imported from the
# standard library (`compression.zstd`, plus the stdlib `tarfile`/`zipfile`);
# on older versions everything, including `tarfile` and `zipfile`, comes from
# the `backports.zstd` package instead.
if sys.version_info >= (3, 14):
    from compression import zstd
    import tarfile
    import zipfile
else:
    from backports import zstd
    from backports.zstd import register_shutil, tarfile, zipfile

    # Presumably registers the zstd archive formats with `shutil` (the
    # "zstdtar" format and ".tar.zst" unpacking used by the tests) — confirm
    # against the backports.zstd documentation. Must run after the import.
    register_shutil()
if sys.version_info >= (3, 11):
    from typing import assert_type
else:

    def assert_type(val, typ):
        """Runtime stand-in for ``typing.assert_type`` (added in Python 3.11).

        Like the real function, this performs no runtime check and returns
        *val* unchanged; the assertion is only meaningful to static type
        checkers, which understand the ``typing`` import above.
        """
        return val
# These tests are simple smoke checks of the main use cases, ensuring they
# work both with the backport and with the stdlib module imported on 3.14+.
class TestCompat(unittest.TestCase):
    """Smoke tests for the zstd APIs reached through the version-gated imports.

    The ``assert_type`` calls only matter to static type checkers; at runtime
    they do not check anything.
    """

    # Zstandard frame magic number 0xFD2FB528, stored little-endian
    # (RFC 8878, section 3.1.1).
    _ZSTD_MAGIC = bytes.fromhex("28 b5 2f fd")

    def test_compress_decompress(self) -> None:
        """One-shot ``compress``/``decompress`` round-trips bytes."""
        raw = token_bytes(1_000)
        assert_type(raw, bytes)
        compressed = zstd.compress(raw)
        assert_type(compressed, bytes)
        decompressed = zstd.decompress(compressed)
        assert_type(decompressed, bytes)
        self.assertEqual(decompressed, raw)

    def test_zstdfile(self) -> None:
        """``ZstdFile`` round-trips bytes through an in-memory file object."""
        raw = token_bytes(1_000)
        fobj = BytesIO()
        with zstd.ZstdFile(fobj, "w") as fzstd:
            fzstd.write(raw)
        # Compressed output actually reached the underlying file object.
        self.assertGreater(fobj.tell(), 0)
        fobj.seek(0)
        with zstd.ZstdFile(fobj) as fzstd:
            data = fzstd.read()
            assert_type(data, bytes)
        self.assertEqual(data, raw)
        # The reader consumed the underlying file object.
        self.assertGreater(fobj.tell(), 0)

    def test_open(self) -> None:
        """``zstd.open`` with short modes ("w" / default) round-trips bytes."""
        raw = token_bytes(1_000)
        fobj = BytesIO()
        with zstd.open(fobj, "w") as fzstd:
            fzstd.write(raw)
        self.assertGreater(fobj.tell(), 0)
        fobj.seek(0)
        with zstd.open(fobj) as fzstd:
            data = fzstd.read()
            assert_type(data, bytes)
        self.assertEqual(data, raw)
        self.assertGreater(fobj.tell(), 0)

    def test_open_binary(self) -> None:
        """``zstd.open`` with explicit binary modes round-trips bytes."""
        raw = token_bytes(1_000)
        fobj = BytesIO()
        with zstd.open(fobj, "wb") as fzstd:
            fzstd.write(raw)
        self.assertGreater(fobj.tell(), 0)
        fobj.seek(0)
        with zstd.open(fobj, "rb") as fzstd:
            data = fzstd.read()
            assert_type(data, bytes)
        self.assertEqual(data, raw)
        self.assertGreater(fobj.tell(), 0)

    def test_open_text(self) -> None:
        """``zstd.open`` in text mode round-trips a ``str``."""
        raw = token_urlsafe(1_000)
        fobj = BytesIO()
        # Explicit encoding so the test does not depend on the locale.
        with zstd.open(fobj, "wt", encoding="utf-8") as fzstd:
            fzstd.write(raw)
        self.assertGreater(fobj.tell(), 0)
        fobj.seek(0)
        with zstd.open(fobj, "rt", encoding="utf-8") as fzstd:
            data = fzstd.read()
            assert_type(data, str)
        self.assertEqual(data, raw)
        self.assertGreater(fobj.tell(), 0)

    def test_tarfile(self) -> None:
        """A ``w:zst`` tar archive round-trips via ``tarfile`` and ``shutil``."""
        raw = token_bytes(1_000)
        raw_name = token_urlsafe(10)
        with TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "archive.tar.zst"
            with tarfile.open(path, "w:zst") as tf:
                ti = tarfile.TarInfo(raw_name)
                ti.size = len(raw)
                tf.addfile(ti, BytesIO(raw))
            # The archive on disk is really Zstandard-framed.
            self.assertEqual(path.read_bytes()[:4], self._ZSTD_MAGIC)
            with tarfile.open(path) as tf:
                self.assertEqual(tf.getnames(), [raw_name])
                extracted = tf.extractfile(raw_name)
                assert extracted is not None  # for type checkers
                with extracted as fobj:
                    self.assertEqual(fobj.read(), raw)
            shutil.unpack_archive(path, tmpdir)
            self.assertEqual((Path(tmpdir) / raw_name).read_bytes(), raw)

    def test_zipfile(self) -> None:
        """A ZIP_ZSTANDARD member round-trips via ``zipfile`` and ``shutil``."""
        raw = token_bytes(1_000)
        raw_name = token_urlsafe(10)
        with TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "archive.zip"
            with zipfile.ZipFile(path, "w") as zf:
                zf.writestr(raw_name, raw, zipfile.ZIP_ZSTANDARD)
            with zipfile.ZipFile(path) as zf:
                self.assertEqual(zf.namelist(), [raw_name])
                # The member is recorded as Zstandard-compressed.
                self.assertEqual(
                    zf.getinfo(raw_name).compress_type, zipfile.ZIP_ZSTANDARD
                )
                self.assertEqual(zf.read(raw_name), raw)
            shutil.unpack_archive(path, tmpdir)
            self.assertEqual((Path(tmpdir) / raw_name).read_bytes(), raw)

    def test_shutil_make_archive(self) -> None:
        """``shutil.make_archive`` supports the ``zstdtar`` format end to end."""
        raw = token_bytes(1_000)
        raw_name = token_urlsafe(10)
        with TemporaryDirectory() as tmpdir:
            path_src = Path(tmpdir) / "src"
            path_src.mkdir()
            (path_src / raw_name).write_bytes(raw)
            path_dst = Path(tmpdir) / "archive"
            # Use the path make_archive reports instead of reconstructing it.
            archive = Path(
                shutil.make_archive(path_dst.as_posix(), "zstdtar", path_src)
            )
            self.assertEqual(archive.name, "archive.tar.zst")
            with archive.open("rb") as f:
                self.assertEqual(f.read(4), self._ZSTD_MAGIC)
            # Round-trip: unpacking restores the original file contents.
            path_out = Path(tmpdir) / "out"
            path_out.mkdir()
            shutil.unpack_archive(archive, path_out)
            self.assertEqual((path_out / raw_name).read_bytes(), raw)
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|