import unittest

from test.support import import_helper, threading_helper
from test.support.threading_helper import run_concurrently

import os
import string
import tempfile
import threading

from collections import Counter

# Load mmap through the test-support helper instead of a plain import.
# NOTE(review): import_helper.import_module is expected to skip the tests
# when the module is unavailable -- confirm against test.support docs.
mmap = import_helper.import_module("mmap")

# Number of worker threads handed to run_concurrently() by every test.
NTHREADS = 10
# Passed as the first (fileno) argument of mmap.mmap() in every test; per the
# mmap documentation, -1 requests an anonymous mapping not backed by a file.
ANONYMOUS_MEM = -1


@threading_helper.requires_working_threading()
class MmapTests(unittest.TestCase):
    """Exercise one shared mmap object from NTHREADS threads at once.

    Every test creates a single anonymous mapping, passes it to
    ``run_concurrently`` together with a worker function, and then checks
    an invariant that would be violated if concurrent mmap operations
    corrupted the mapping's data or its internal position.
    """

    def test_read_and_read_byte(self):
        # All threads consume the mapping through its shared file position,
        # so every byte must be handed out to exactly one reader.
        ascii_uppercase = string.ascii_uppercase.encode()
        # Choose a total mmap size that evenly divides across threads and the
        # read pattern (3 bytes per loop).
        mmap_size = 3 * NTHREADS * len(ascii_uppercase)
        num_bytes_to_read_per_thread = mmap_size // NTHREADS
        bytes_read_from_mmap = []

        def read(mm_obj):
            nread = 0
            while nread < num_bytes_to_read_per_thread:
                # One single-byte read followed by a two-byte read; both
                # results are stored in the shared list as integers.
                b = mm_obj.read_byte()
                bytes_read_from_mmap.append(b)
                b = mm_obj.read(2)
                bytes_read_from_mmap.extend(b)
                nread += 3

        with mmap.mmap(ANONYMOUS_MEM, mmap_size) as mm_obj:
            # Fill the whole mapping with repeated copies of the alphabet.
            for _ in range(mmap_size // len(ascii_uppercase)):
                mm_obj.write(ascii_uppercase)

            mm_obj.seek(0)
            run_concurrently(
                worker_func=read,
                args=(mm_obj,),
                nthreads=NTHREADS,
            )

        self.assertEqual(len(bytes_read_from_mmap), mmap_size)
        # Count each letter/byte to verify read correctness
        counter = Counter(bytes_read_from_mmap)
        self.assertEqual(len(counter), len(ascii_uppercase))
        # Each letter/byte should be read (3 * NTHREADS) times
        for letter in ascii_uppercase:
            self.assertEqual(counter[letter], 3 * NTHREADS)

    def test_readline(self):
        # Each of the num_lines lines must be returned to exactly one thread.
        num_lines = 1000
        lines_read_from_mmap = []
        expected_lines = []

        def readline(mm_obj):
            for _ in range(num_lines // NTHREADS):
                line = mm_obj.readline()
                lines_read_from_mmap.append(line)

        # The longest line (b"999\n") is 4 bytes; reserving 5 bytes per line
        # guarantees that all num_lines lines fit in the mapping.
        with mmap.mmap(ANONYMOUS_MEM, num_lines * 5) as mm_obj:
            for i in range(num_lines):
                line = b"%d\n" % i
                mm_obj.write(line)
                expected_lines.append(line)

            mm_obj.seek(0)
            run_concurrently(
                worker_func=readline,
                args=(mm_obj,),
                nthreads=NTHREADS,
            )

        self.assertEqual(len(lines_read_from_mmap), num_lines)
        # Every line should be read once by threads; order is non-deterministic
        # Sort numerically by integer value
        lines_read_from_mmap.sort(key=int)
        self.assertEqual(lines_read_from_mmap, expected_lines)

    def test_write_and_write_byte(self):
        # Threads write through the shared file position; no write may be
        # lost or overwritten by another thread's write.
        thread_letters = list(string.ascii_uppercase)
        self.assertLessEqual(NTHREADS, len(thread_letters))
        per_thread_write_loop = 100

        def write(mm_obj):
            # Each thread picks a unique letter to write
            thread_letter = thread_letters.pop(0)
            thread_bytes = (thread_letter * 2).encode()
            for _ in range(per_thread_write_loop):
                # 1 byte via write_byte() + 2 bytes via write() per iteration
                mm_obj.write_byte(thread_bytes[0])
                mm_obj.write(thread_bytes)

        # The mapping is sized to hold exactly what all threads write.
        with mmap.mmap(
            ANONYMOUS_MEM, per_thread_write_loop * 3 * NTHREADS
        ) as mm_obj:
            run_concurrently(
                worker_func=write,
                args=(mm_obj,),
                nthreads=NTHREADS,
            )
            mm_obj.seek(0)
            data = mm_obj.read()
            self.assertEqual(len(data), NTHREADS * per_thread_write_loop * 3)
            counter = Counter(data)
            self.assertEqual(len(counter), NTHREADS)
            # Each thread letter should be written `per_thread_write_loop` * 3
            for letter in counter:
                self.assertEqual(counter[letter], per_thread_write_loop * 3)

    def test_move(self):
        # All threads copy the first half of the mapping onto the second half
        # one byte at a time.
        ascii_uppercase = string.ascii_uppercase.encode()
        num_letters = len(ascii_uppercase)

        def move(mm_obj):
            for i in range(num_letters):
                # Copy 1 byte from the first half to the second half.
                # mmap.move() takes its arguments as (dest, src, count).
                mm_obj.move(num_letters + i, i, 1)

        with mmap.mmap(ANONYMOUS_MEM, 2 * num_letters) as mm_obj:
            mm_obj.write(ascii_uppercase)
            run_concurrently(
                worker_func=move,
                args=(mm_obj,),
                nthreads=NTHREADS,
            )
            # The first half is only ever read, so however the threads
            # interleave, both halves must now contain the alphabet.
            self.assertEqual(mm_obj[:], ascii_uppercase * 2)

    def test_seek_and_tell(self):
        # Relative seeks from several threads must all be accounted for in
        # the shared file position.
        seek_per_thread = 10

        def seek(mm_obj):
            self.assertTrue(mm_obj.seekable())
            for _ in range(seek_per_thread):
                before_seek = mm_obj.tell()
                mm_obj.seek(1, os.SEEK_CUR)
                # Other threads may have moved the position further, but it
                # can only have grown.
                self.assertLess(before_seek, mm_obj.tell())

        with mmap.mmap(ANONYMOUS_MEM, 1024) as mm_obj:
            run_concurrently(
                worker_func=seek,
                args=(mm_obj,),
                nthreads=NTHREADS,
            )
            # Each thread seeks from current position, the end position should
            # be the sum of all seeks from all threads.
            self.assertEqual(mm_obj.tell(), NTHREADS * seek_per_thread)

    def test_slice_update_and_slice_read(self):
        # Whole-buffer slice assignment and slice read must not observe a
        # half-written buffer.
        thread_letters = list(string.ascii_uppercase)
        self.assertLessEqual(NTHREADS, len(thread_letters))

        def slice_update_and_slice_read(mm_obj):
            # Each thread picks a unique letter to write
            thread_letter = thread_letters.pop(0)
            thread_bytes = (thread_letter * 1024).encode()
            for _ in range(100):
                mm_obj[:] = thread_bytes
                read_bytes = mm_obj[:]
                # Read bytes should be all the same letter, showing no
                # interleaving
                self.assertTrue(all_same(read_bytes))

        with mmap.mmap(ANONYMOUS_MEM, 1024) as mm_obj:
            run_concurrently(
                worker_func=slice_update_and_slice_read,
                args=(mm_obj,),
                nthreads=NTHREADS,
            )

    def test_item_update_and_item_read(self):
        # Single-item writes to distinct indexes must not disturb each other.
        thread_indexes = list(range(NTHREADS))

        def item_update_and_item_read(mm_obj):
            # Each thread picks a unique index to write
            thread_index = thread_indexes.pop()
            for i in range(100):
                mm_obj[thread_index] = i
                self.assertEqual(mm_obj[thread_index], i)

            # Read values set by other threads, all values
            # should be less than '100'
            for val in mm_obj:
                self.assertLess(int.from_bytes(val), 100)

        # One byte more than there are threads: the last byte is never
        # assigned by any worker.
        with mmap.mmap(ANONYMOUS_MEM, NTHREADS + 1) as mm_obj:
            run_concurrently(
                worker_func=item_update_and_item_read,
                args=(mm_obj,),
                nthreads=NTHREADS,
            )

    def test_close_and_closed(self):
        # Closing the same mapping from several threads (and once more when
        # the ``with`` block exits) must be harmless.
        def close_mmap(mm_obj):
            mm_obj.close()
            self.assertTrue(mm_obj.closed)

        with mmap.mmap(ANONYMOUS_MEM, 1) as mm_obj:
            run_concurrently(
                worker_func=close_mmap,
                args=(mm_obj,),
                nthreads=NTHREADS,
            )

    def test_find_and_rfind(self):
        # Each thread writes its own unique pattern and must then find it at
        # exactly one location, searching from either direction.
        per_thread_loop = 10

        def find_and_rfind(mm_obj):
            # The quotes delimit the identifier so that one thread's pattern
            # can never be a substring of another thread's pattern.
            pattern = b'Thread-Ident:"%d"' % threading.get_ident()
            mm_obj.write(pattern)
            for _ in range(per_thread_loop):
                found_at = mm_obj.find(pattern, 0)
                self.assertNotEqual(found_at, -1)
                # Should not find it after the `found_at`
                self.assertEqual(mm_obj.find(pattern, found_at + 1), -1)
                found_at_rev = mm_obj.rfind(pattern, 0)
                self.assertEqual(found_at, found_at_rev)
                # Should not find it after the `found_at`
                self.assertEqual(mm_obj.rfind(pattern, found_at + 1), -1)

        with mmap.mmap(ANONYMOUS_MEM, 1024) as mm_obj:
            run_concurrently(
                worker_func=find_and_rfind,
                args=(mm_obj,),
                nthreads=NTHREADS,
            )

    def test_mmap_export_as_memoryview(self):
        """
        Each thread creates a memoryview and updates the internal state of the
        mmap object.
        """
        buffer_size = 42

        def create_memoryview_from_mmap(mm_obj):
            # Keep every view referenced so the exports stay alive until the
            # close() attempt below.
            memoryviews = []
            for _ in range(100):
                mv = memoryview(mm_obj)
                memoryviews.append(mv)
                self.assertEqual(len(mv), buffer_size)
                self.assertEqual(mv[:7], b"CPython")

            # Cannot close the mmap while it is exported as buffers
            with self.assertRaisesRegex(
                BufferError, "cannot close exported pointers exist"
            ):
                mm_obj.close()

        with mmap.mmap(ANONYMOUS_MEM, buffer_size) as mm_obj:
            mm_obj.write(b"CPython")
            run_concurrently(
                worker_func=create_memoryview_from_mmap,
                args=(mm_obj,),
                nthreads=NTHREADS,
            )
            # Implicit mm_obj.close() verifies all exports (memoryviews) are
            # properly freed.


def all_same(lst):
    """Return True if every item of *lst* equals its first item.

    An empty sequence has nothing to compare and therefore yields True.
    """
    for element in lst:
        if element == lst[0]:
            continue
        # First mismatch decides the answer; no need to look further.
        return False
    return True


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
