# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Most of this work is copyright (C) 2013-2020 David R. MacIver
# (david@drmaciver.com), but it contains contributions by others. See
# CONTRIBUTING.rst for a full list of people who may hold copyright, and
# consult the git log if you need to determine who owns an individual
# contribution.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
#
# END HEADER

from hypothesis.errors import InvalidArgument
from hypothesis.internal import charmap
from hypothesis.internal.conjecture.utils import biased_coin, integer_range
from hypothesis.internal.intervalsets import IntervalSet
from hypothesis.strategies._internal.strategies import SearchStrategy


class OneCharStringStrategy(SearchStrategy):
    """A strategy which generates single character strings of text type.

    The permitted codepoints are computed once, at construction time, by
    ``charmap.query`` from the category / character / codepoint filters and
    stored in an ``IntervalSet``.  Each draw selects an index into that set
    (after a shrink-friendly reordering, see ``rewrite_integer``) and turns
    the codepoint found there into a one-character string with ``chr``.
    """

    def __init__(
        self,
        whitelist_categories=None,
        blacklist_categories=None,
        blacklist_characters=None,
        min_codepoint=None,
        max_codepoint=None,
        whitelist_characters=None,
    ):
        # Sanity checks: every requested category must be one charmap knows
        # about (presumably already validated by the public API -- confirm).
        for requested in (whitelist_categories, blacklist_categories):
            assert set(requested or ()).issubset(charmap.categories())

        intervals = charmap.query(
            include_categories=whitelist_categories,
            exclude_categories=blacklist_categories,
            min_codepoint=min_codepoint,
            max_codepoint=max_codepoint,
            include_characters=whitelist_characters,
            exclude_characters=blacklist_characters,
        )

        if not intervals:
            # Report only the arguments the caller actually supplied, in a
            # fixed order, so the user can see which filters conflict.
            supplied = [
                ("whitelist_categories", whitelist_categories),
                ("blacklist_categories", blacklist_categories),
                ("whitelist_characters", whitelist_characters),
                ("blacklist_characters", blacklist_characters),
                ("min_codepoint", min_codepoint),
                ("max_codepoint", max_codepoint),
            ]
            summary = ", ".join(
                "%s=%r" % (name, value)
                for name, value in supplied
                if value is not None
            )
            raise InvalidArgument(
                "No characters are allowed to be generated by this "
                "combination of arguments: " + summary
            )

        self.intervals = IntervalSet(intervals)
        # Presumably the index of the first allowed codepoint at or above "0"
        # (len(self.intervals) when there is none) -- see
        # IntervalSet.index_above.
        self.zero_point = self.intervals.index_above(ord("0"))
        # Same lookup for "Z", clamped so it is always a valid index.
        last_index = len(self.intervals) - 1
        self.Z_point = min(self.intervals.index_above(ord("Z")), last_index)

    def do_draw(self, data):
        """Draw a single-character string from the allowed codepoints.

        When there are more than 256 allowed codepoints, ``biased_coin`` with
        p=0.2 decides whether to draw from indices 256 and up.  Otherwise the
        index comes from the first 256.  With 256 or fewer codepoints the
        index is drawn from the whole range.  The index is then passed
        through ``rewrite_integer`` before the codepoint lookup.
        """
        count = len(self.intervals)
        if count <= 256:
            index = integer_range(data, 0, count - 1)
        elif biased_coin(data, 0.2):
            index = integer_range(data, 256, count - 1)
        else:
            index = integer_range(data, 0, 255)
        return chr(self.intervals[self.rewrite_integer(index)])

    def rewrite_integer(self, i):
        """Map a drawn index ``i`` to the index actually used for generation.

        Shrinking moves ``i`` towards zero, and we would like the smallest
        values of ``i`` to correspond to simple ASCII characters.  If the
        allowed indices are laid out as ``abc0yyyZ...``, this permutes them
        into ``0yyyZcba...``:

        * ``[0, n]``, where ``n = Z_point - zero_point``, maps onto
          ``[zero_point, Z_point]`` in order;
        * ``[n + 1, Z_point]`` maps onto ``[zero_point - 1, 0]`` in reverse,
          so codepoints below ``zero_point`` shrink upwards towards it;
        * any ``i`` above ``Z_point`` is returned unchanged.
        """
        if i > self.Z_point:
            return i
        offset = self.Z_point - self.zero_point
        if i <= offset:
            result = i + self.zero_point
        else:
            result = self.zero_point - (i - offset)
            assert result < self.zero_point
        assert 0 <= result <= self.Z_point
        return result


class FixedSizeBytes(SearchStrategy):
    """A strategy which draws a fixed number of bytes from the data stream.

    Every draw requests ``self.size`` bytes via ``data.draw_bytes`` and
    returns them as an immutable ``bytes`` object.
    """

    def __init__(self, size):
        # Number of bytes requested from ``draw_bytes`` on every draw.
        self.size = size

    def do_draw(self, data):
        drawn = data.draw_bytes(self.size)
        # Normalize whatever ``draw_bytes`` hands back (its concrete type is
        # not visible here -- possibly a bytearray) into plain ``bytes``.
        return bytes(drawn)
