import io
import os
import unittest

try:
    import hypothesis
    import hypothesis.strategies as strategies
except ImportError:
    raise unittest.SkipTest("hypothesis not available")

import zstandard as zstd

from .common import random_input_data


@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
class TestDecompressor_stream_reader_fuzzing(unittest.TestCase):
    """Fuzz ``ZstdDecompressor.stream_reader()`` with varying read patterns.

    Each test compresses a sample input, reads it back through a
    decompressing stream reader using hypothesis-chosen sizes, and checks
    that the decompressed bytes equal the original input.
    """

    @staticmethod
    def _compress(original, level, streaming):
        """Compress ``original`` into a single zstd frame and return its bytes.

        When ``streaming`` is true the frame is produced by a compressor
        ``stream_writer()`` terminated with ``FLUSH_FRAME``; otherwise the
        one-shot ``ZstdCompressor.compress()`` API is used.
        """
        cctx = zstd.ZstdCompressor(level=level)

        if not streaming:
            return cctx.compress(original)

        buffer = io.BytesIO()
        writer = cctx.stream_writer(buffer)
        writer.write(original)
        writer.flush(zstd.FLUSH_FRAME)
        return buffer.getvalue()

    @staticmethod
    def _read_until_eof(read, next_size):
        """Call ``read(next_size())`` repeatedly and return the joined output.

        An empty result ends the loop only when the requested size was
        non-zero: a zero-size read returns ``b""`` without signalling end of
        stream, so it is recorded and reading continues.
        """
        chunks = []
        while True:
            size = next_size()
            chunk = read(size)
            if not chunk and size:
                return b"".join(chunks)
            chunks.append(chunk)

    @hypothesis.settings(
        suppress_health_check=[
            hypothesis.HealthCheck.large_base_example,
            hypothesis.HealthCheck.too_slow,
        ]
    )
    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        streaming=strategies.booleans(),
        source_read_size=strategies.integers(1, 1048576),
        read_sizes=strategies.data(),
    )
    def test_stream_source_read_variance(
        self, original, level, streaming, source_read_size, read_sizes
    ):
        source = io.BytesIO(self._compress(original, level, streaming))
        dctx = zstd.ZstdDecompressor()

        # A fresh read() size (which may be -1 or 0) is drawn for every call.
        with dctx.stream_reader(source, read_size=source_read_size) as reader:
            data = self._read_until_eof(
                reader.read,
                lambda: read_sizes.draw(strategies.integers(-1, 131072)),
            )

        self.assertEqual(data, original)

    # Similar to above except we have a constant read() size.
    @hypothesis.settings(
        suppress_health_check=[hypothesis.HealthCheck.large_base_example]
    )
    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        streaming=strategies.booleans(),
        source_read_size=strategies.integers(1, 1048576),
        read_size=strategies.integers(-1, 131072),
    )
    def test_stream_source_read_size(
        self, original, level, streaming, source_read_size, read_size
    ):
        # With a constant size of 0 the read loop would never terminate
        # (empty zero-size reads are not treated as EOF), so use 1 instead.
        if read_size == 0:
            read_size = 1

        source = io.BytesIO(self._compress(original, level, streaming))
        dctx = zstd.ZstdDecompressor()

        with dctx.stream_reader(source, read_size=source_read_size) as reader:
            data = self._read_until_eof(reader.read, lambda: read_size)

        self.assertEqual(data, original)

    @hypothesis.settings(
        suppress_health_check=[
            hypothesis.HealthCheck.large_base_example,
            hypothesis.HealthCheck.too_slow,
        ]
    )
    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        streaming=strategies.booleans(),
        source_read_size=strategies.integers(1, 1048576),
        read_sizes=strategies.data(),
    )
    def test_buffer_source_read_variance(
        self, original, level, streaming, source_read_size, read_sizes
    ):
        frame = self._compress(original, level, streaming)
        dctx = zstd.ZstdDecompressor()

        # The reader consumes the in-memory frame bytes directly.
        with dctx.stream_reader(frame, read_size=source_read_size) as reader:
            data = self._read_until_eof(
                reader.read,
                lambda: read_sizes.draw(strategies.integers(-1, 131072)),
            )

        self.assertEqual(data, original)

    # Similar to above except we have a constant read() size.
    @hypothesis.settings(
        suppress_health_check=[hypothesis.HealthCheck.large_base_example]
    )
    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        streaming=strategies.booleans(),
        source_read_size=strategies.integers(1, 1048576),
        read_size=strategies.integers(-1, 131072),
    )
    def test_buffer_source_constant_read_size(
        self, original, level, streaming, source_read_size, read_size
    ):
        # With a constant size of 0 the read loop would never terminate
        # (empty zero-size reads are not treated as EOF), so read everything.
        if read_size == 0:
            read_size = -1

        frame = self._compress(original, level, streaming)
        dctx = zstd.ZstdDecompressor()

        with dctx.stream_reader(frame, read_size=source_read_size) as reader:
            data = self._read_until_eof(reader.read, lambda: read_size)

        self.assertEqual(data, original)

    @hypothesis.settings(
        suppress_health_check=[hypothesis.HealthCheck.large_base_example]
    )
    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        streaming=strategies.booleans(),
        source_read_size=strategies.integers(1, 1048576),
    )
    def test_stream_source_readall(
        self, original, level, streaming, source_read_size
    ):
        source = io.BytesIO(self._compress(original, level, streaming))
        dctx = zstd.ZstdDecompressor()

        with dctx.stream_reader(source, read_size=source_read_size) as reader:
            data = reader.readall()

        self.assertEqual(data, original)

    @hypothesis.settings(
        suppress_health_check=[
            hypothesis.HealthCheck.large_base_example,
            hypothesis.HealthCheck.too_slow,
        ]
    )
    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        streaming=strategies.booleans(),
        source_read_size=strategies.integers(1, 1048576),
        read_sizes=strategies.data(),
    )
    def test_stream_source_read1_variance(
        self, original, level, streaming, source_read_size, read_sizes
    ):
        source = io.BytesIO(self._compress(original, level, streaming))
        dctx = zstd.ZstdDecompressor()

        with dctx.stream_reader(source, read_size=source_read_size) as reader:
            data = self._read_until_eof(
                reader.read1,
                lambda: read_sizes.draw(strategies.integers(-1, 131072)),
            )

        self.assertEqual(data, original)

    @hypothesis.settings(
        suppress_health_check=[
            hypothesis.HealthCheck.large_base_example,
            hypothesis.HealthCheck.too_slow,
        ]
    )
    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        streaming=strategies.booleans(),
        source_read_size=strategies.integers(1, 1048576),
        read_sizes=strategies.data(),
    )
    def test_stream_source_readinto1_variance(
        self, original, level, streaming, source_read_size, read_sizes
    ):
        source = io.BytesIO(self._compress(original, level, streaming))
        dctx = zstd.ZstdDecompressor()

        chunks = []
        with dctx.stream_reader(source, read_size=source_read_size) as reader:
            while True:
                # Buffers are always non-empty, so a count of 0 means EOF.
                buffer = bytearray(
                    read_sizes.draw(strategies.integers(1, 131072))
                )
                count = reader.readinto1(buffer)

                if not count:
                    break

                chunks.append(bytes(buffer[:count]))

        self.assertEqual(b"".join(chunks), original)

    @hypothesis.settings(
        suppress_health_check=[
            hypothesis.HealthCheck.data_too_large,
            hypothesis.HealthCheck.large_base_example,
            hypothesis.HealthCheck.too_slow,
        ]
    )
    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        source_read_size=strategies.integers(1, 1048576),
        seek_amounts=strategies.data(),
        read_sizes=strategies.data(),
    )
    def test_relative_seeks(
        self, original, level, source_read_size, seek_amounts, read_sizes
    ):
        frame = self._compress(original, level, streaming=False)
        dctx = zstd.ZstdDecompressor()

        with dctx.stream_reader(frame, read_size=source_read_size) as reader:
            while True:
                # Only forward (non-negative) relative seeks are generated.
                amount = seek_amounts.draw(strategies.integers(0, 16384))
                reader.seek(amount, os.SEEK_CUR)

                offset = reader.tell()
                read_amount = read_sizes.draw(strategies.integers(1, 16384))
                chunk = reader.read(read_amount)

                if not chunk:
                    break

                # Each chunk must match the original at the reported offset.
                self.assertEqual(original[offset : offset + len(chunk)], chunk)

    @hypothesis.settings(
        suppress_health_check=[
            hypothesis.HealthCheck.large_base_example,
            hypothesis.HealthCheck.too_slow,
        ]
    )
    @hypothesis.given(
        chunks=strategies.lists(
            strategies.sampled_from(random_input_data()),
            min_size=2,
            max_size=10,
        ),
        level=strategies.integers(min_value=1, max_value=5),
        source_read_size=strategies.integers(1, 1048576),
        read_sizes=strategies.data(),
    )
    def test_multiple_frames(self, chunks, level, source_read_size, read_sizes):
        cctx = zstd.ZstdCompressor(level=level)
        buffer = io.BytesIO()
        writer = cctx.stream_writer(buffer)

        # Emit every input chunk as its own frame through a single writer.
        for chunk in chunks:
            writer.write(chunk)
            writer.flush(zstd.FLUSH_FRAME)

        dctx = zstd.ZstdDecompressor()
        buffer.seek(0)

        with dctx.stream_reader(
            buffer, read_size=source_read_size, read_across_frames=True
        ) as reader:
            data = self._read_until_eof(
                reader.read,
                lambda: read_sizes.draw(strategies.integers(-1, 16384)),
            )

        # Reading across frames must yield all inputs concatenated.
        self.assertEqual(b"".join(chunks), data)


@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
class TestDecompressor_stream_writer_fuzzing(unittest.TestCase):
    """Fuzz ``ZstdDecompressor.stream_writer()`` with varying input pieces."""

    @hypothesis.settings(
        suppress_health_check=[
            hypothesis.HealthCheck.large_base_example,
            hypothesis.HealthCheck.too_slow,
        ]
    )
    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        write_size=strategies.integers(min_value=1, max_value=8192),
        input_sizes=strategies.data(),
    )
    def test_write_size_variance(
        self, original, level, write_size, input_sizes
    ):
        compressed = io.BytesIO(
            zstd.ZstdCompressor(level=level).compress(original)
        )
        output = io.BytesIO()

        def next_piece():
            # Draw a new input size for every read of the compressed frame.
            return compressed.read(
                input_sizes.draw(strategies.integers(1, 4096))
            )

        writer = zstd.ZstdDecompressor().stream_writer(
            output, write_size=write_size, closefd=False
        )
        with writer as decompressor:
            # iter() stops as soon as the compressed stream returns b"".
            for piece in iter(next_piece, b""):
                decompressor.write(piece)

        self.assertEqual(output.getvalue(), original)


@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
class TestDecompressor_copy_stream_fuzzing(unittest.TestCase):
    """Fuzz ``ZstdDecompressor.copy_stream()`` over read/write buffer sizes."""

    @hypothesis.settings(
        suppress_health_check=[
            hypothesis.HealthCheck.large_base_example,
            hypothesis.HealthCheck.too_slow,
        ]
    )
    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        read_size=strategies.integers(min_value=1, max_value=8192),
        write_size=strategies.integers(min_value=1, max_value=8192),
    )
    def test_read_write_size_variance(
        self, original, level, read_size, write_size
    ):
        compressed = zstd.ZstdCompressor(level=level).compress(original)
        ifh = io.BytesIO(compressed)
        ofh = io.BytesIO()

        # Decompress the whole frame from ifh into ofh in one call.
        zstd.ZstdDecompressor().copy_stream(
            ifh, ofh, read_size=read_size, write_size=write_size
        )

        self.assertEqual(ofh.getvalue(), original)


@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
class TestDecompressor_decompressobj_fuzzing(unittest.TestCase):
    """Fuzz ``ZstdDecompressor.decompressobj()`` with randomly sized input."""

    @staticmethod
    def _pieces(data, sizes):
        """Yield consecutive slices of ``data`` with hypothesis-drawn lengths.

        A length in [1, 4096] is drawn from ``sizes`` (a hypothesis
        ``data()`` object) before every read, including the final read that
        returns ``b""`` and ends the iteration.
        """
        source = io.BytesIO(data)
        while True:
            piece = source.read(sizes.draw(strategies.integers(1, 4096)))
            if not piece:
                return
            yield piece

    @hypothesis.settings(
        suppress_health_check=[
            hypothesis.HealthCheck.large_base_example,
            hypothesis.HealthCheck.too_slow,
        ]
    )
    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        chunk_sizes=strategies.data(),
    )
    def test_random_input_sizes(self, original, level, chunk_sizes):
        cctx = zstd.ZstdCompressor(level=level)
        frame = cctx.compress(original)

        dctx = zstd.ZstdDecompressor()
        dobj = dctx.decompressobj()

        chunks = [
            dobj.decompress(piece)
            for piece in self._pieces(frame, chunk_sizes)
        ]

        self.assertEqual(b"".join(chunks), original)

    @hypothesis.settings(
        suppress_health_check=[
            hypothesis.HealthCheck.large_base_example,
            hypothesis.HealthCheck.too_slow,
        ]
    )
    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        write_size=strategies.integers(
            min_value=1,
            max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
        ),
        chunk_sizes=strategies.data(),
    )
    def test_random_output_sizes(
        self, original, level, write_size, chunk_sizes
    ):
        cctx = zstd.ZstdCompressor(level=level)
        frame = cctx.compress(original)

        dctx = zstd.ZstdDecompressor()
        dobj = dctx.decompressobj(write_size=write_size)

        chunks = [
            dobj.decompress(piece)
            for piece in self._pieces(frame, chunk_sizes)
        ]

        self.assertEqual(b"".join(chunks), original)

    # Same health-check suppressions as the other tests in this class.
    @hypothesis.settings(
        suppress_health_check=[
            hypothesis.HealthCheck.large_base_example,
            hypothesis.HealthCheck.too_slow,
        ]
    )
    @hypothesis.given(
        chunks=strategies.lists(
            strategies.sampled_from(random_input_data()),
            min_size=2,
            max_size=10,
        ),
        level=strategies.integers(min_value=1, max_value=5),
        write_size=strategies.integers(
            min_value=1,
            max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
        ),
        read_sizes=strategies.data(),
    )
    def test_read_across_frames_false(
        self, chunks, level, write_size, read_sizes
    ):
        cctx = zstd.ZstdCompressor(level=level)
        # Concatenate one independent frame per input chunk.
        compressed = b"".join(cctx.compress(chunk) for chunk in chunks)

        dctx = zstd.ZstdDecompressor()
        dobj = dctx.decompressobj(
            write_size=write_size, read_across_frames=False
        )

        decompressed = io.BytesIO()

        for piece in self._pieces(compressed, read_sizes):
            try:
                decompressed.write(dobj.decompress(piece))
            except zstd.ZstdError as e:
                # The only tolerated error is the one reporting reuse of the
                # decompressobj; it ends feeding. Anything else propagates.
                if e.args[0] == "cannot use a decompressobj multiple times":
                    break
                raise

        # Without reading across frames, only the first frame is decoded.
        self.assertEqual(decompressed.getvalue(), chunks[0])

    @hypothesis.settings(
        suppress_health_check=[
            hypothesis.HealthCheck.large_base_example,
        ]
    )
    @hypothesis.given(
        chunks=strategies.lists(
            strategies.sampled_from(random_input_data()),
            min_size=2,
            max_size=10,
        ),
        level=strategies.integers(min_value=1, max_value=5),
        write_size=strategies.integers(
            min_value=1,
            max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
        ),
        read_sizes=strategies.data(),
    )
    def test_read_across_frames_true(
        self, chunks, level, write_size, read_sizes
    ):
        cctx = zstd.ZstdCompressor(level=level)
        # Concatenate one independent frame per input chunk.
        compressed = b"".join(cctx.compress(chunk) for chunk in chunks)

        dctx = zstd.ZstdDecompressor()
        dobj = dctx.decompressobj(
            write_size=write_size, read_across_frames=True
        )

        decompressed = b"".join(
            dobj.decompress(piece)
            for piece in self._pieces(compressed, read_sizes)
        )

        # Reading across frames must yield all inputs concatenated.
        self.assertEqual(decompressed, b"".join(chunks))


@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
class TestDecompressor_read_to_iter_fuzzing(unittest.TestCase):
    """Fuzz ``ZstdDecompressor.read_to_iter()`` over read/write sizes."""

    @hypothesis.given(
        original=strategies.sampled_from(random_input_data()),
        level=strategies.integers(min_value=1, max_value=5),
        read_size=strategies.integers(min_value=1, max_value=4096),
        write_size=strategies.integers(min_value=1, max_value=4096),
    )
    def test_read_write_size_variance(
        self, original, level, read_size, write_size
    ):
        compressed = io.BytesIO(
            zstd.ZstdCompressor(level=level).compress(original)
        )

        # Consume every chunk the iterator produces and join them.
        decompressed = b"".join(
            zstd.ZstdDecompressor().read_to_iter(
                compressed, read_size=read_size, write_size=write_size
            )
        )

        self.assertEqual(decompressed, original)


@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
@unittest.skipUnless(
    "multi_decompress_to_buffer" in zstd.backend_features,
    "multi_decompress_to_buffer not available",
)
class TestDecompressor_multi_decompress_to_buffer_fuzzing(unittest.TestCase):
    """Fuzz ``multi_decompress_to_buffer()`` over thread counts and dicts."""

    @hypothesis.given(
        original=strategies.lists(
            strategies.sampled_from(random_input_data()),
            min_size=1,
            max_size=1024,
        ),
        threads=strategies.integers(min_value=1, max_value=8),
        use_dict=strategies.booleans(),
    )
    def test_data_equivalence(self, original, threads, use_dict):
        kwargs = {}
        if use_dict:
            # Build the shared dictionary from the first input.
            kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0])

        cctx = zstd.ZstdCompressor(
            level=1, write_content_size=True, write_checksum=True, **kwargs
        )

        frames_buffer = cctx.multi_compress_to_buffer(original, threads=-1)

        dctx = zstd.ZstdDecompressor(**kwargs)

        # Decompress straight from the compressor's output buffer, using the
        # hypothesis-chosen thread count (previously drawn but never used).
        result = dctx.multi_decompress_to_buffer(frames_buffer, threads=threads)
        self.assertEqual([segment.tobytes() for segment in result], original)

        # Same again, but from a plain list of per-frame bytes objects.
        frames_list = [f.tobytes() for f in frames_buffer]
        result = dctx.multi_decompress_to_buffer(frames_list, threads=threads)
        self.assertEqual([segment.tobytes() for segment in result], original)
