# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

import copy
import math
from collections.abc import Callable, Iterable
from typing import Any, overload

from hypothesis import strategies as st
from hypothesis.control import current_build_context
from hypothesis.errors import InvalidArgument
from hypothesis.internal.conjecture import utils as cu
from hypothesis.internal.conjecture.data import ConjectureData
from hypothesis.internal.conjecture.engine import BUFFER_SIZE
from hypothesis.internal.conjecture.junkdrawer import LazySequenceCopy
from hypothesis.internal.conjecture.utils import combine_labels
from hypothesis.internal.filtering import get_integer_predicate_bounds
from hypothesis.internal.reflection import is_identity_function
from hypothesis.strategies._internal.strategies import (
    T3,
    T4,
    T5,
    Ex,
    FilteredStrategy,
    RecurT,
    SampledFromStrategy,
    SearchStrategy,
    T,
    check_strategy,
    filter_not_satisfied,
)
from hypothesis.strategies._internal.utils import cacheable, defines_strategy
from hypothesis.utils.conventions import UniqueIdentifier
from hypothesis.vendor.pretty import (
    ArgLabelsT,
    IDKey,
    _fixeddict_pprinter,
    _tuple_pprinter,
)


class TupleStrategy(SearchStrategy[tuple[Ex, ...]]):
    """Strategy for fixed-length tuples.

    One element strategy is held per tuple position, and position ``i`` of
    every generated tuple is drawn from the ``i``-th element strategy, so the
    positions may be of heterogeneous types.
    """

    def __init__(self, strategies: Iterable[SearchStrategy[Any]]):
        super().__init__()
        # Stored as a tuple so that the strategies can be iterated repeatedly,
        # whatever kind of iterable was passed in.
        self.element_strategies = tuple(strategies)

    def do_validate(self) -> None:
        """Validate each element strategy in positional order."""
        for element in self.element_strategies:
            element.validate()

    def calc_label(self) -> int:
        """Combine the class label with the label of every element strategy."""
        own_label = self.class_label
        element_labels = [element.label for element in self.element_strategies]
        return combine_labels(own_label, *element_labels)

    def __repr__(self) -> str:
        inner = ", ".join(repr(element) for element in self.element_strategies)
        return f"TupleStrategy(({inner}))"

    def calc_has_reusable_values(self, recur: RecurT) -> bool:
        """Return True only if ``recur`` is truthy for every element strategy."""
        for element in self.element_strategies:
            if not recur(element):
                return False
        return True

    def do_draw(self, data: ConjectureData) -> tuple[Ex, ...]:
        """Draw one value per element strategy and return them as a tuple.

        Each draw happens while tracking an ``arg[i]`` label.  If any labels
        were collected, a tuple pretty-printer built from them is registered
        for the resulting object in the current build context.
        """
        build_context = current_build_context()
        labels: ArgLabelsT = {}
        drawn = []
        for position, element in enumerate(self.element_strategies):
            with build_context.track_arg_label(f"arg[{position}]") as label:
                drawn.append(data.draw(element))
            labels |= label

        value = tuple(drawn)
        if labels:
            printers = build_context.known_object_printers[IDKey(value)]
            printers.append(_tuple_pprinter(labels))
        return value

    def calc_is_empty(self, recur: RecurT) -> bool:
        """Return True as soon as ``recur`` is truthy for any element strategy."""
        for element in self.element_strategies:
            if recur(element):
                return True
        return False


# Typing-only overloads for ``tuples``.  Calls with zero to five positional
# strategies are given a precisely typed, fixed-length tuple; the final overload
# is the fallback for any other number of arguments and yields
# ``tuple[Any, ...]``.  The bodies are placeholders (``...``), which is why each
# stub carries a ``# pragma: no cover`` marker.
@overload
def tuples() -> SearchStrategy[tuple[()]]:  # pragma: no cover
    ...


@overload
def tuples(__a1: SearchStrategy[Ex]) -> SearchStrategy[tuple[Ex]]:  # pragma: no cover
    ...


@overload
def tuples(
    __a1: SearchStrategy[Ex], __a2: SearchStrategy[T]
) -> SearchStrategy[tuple[Ex, T]]:  # pragma: no cover
    ...


@overload
def tuples(
    __a1: SearchStrategy[Ex], __a2: SearchStrategy[T], __a3: SearchStrategy[T3]
) -> SearchStrategy[tuple[Ex, T, T3]]:  # pragma: no cover
    ...


@overload
def tuples(
    __a1: SearchStrategy[Ex],
    __a2: SearchStrategy[T],
    __a3: SearchStrategy[T3],
    __a4: SearchStrategy[T4],
) -> SearchStrategy[tuple[Ex, T, T3, T4]]:  # pragma: no cover
    ...


@overload
def tuples(
    __a1: SearchStrategy[Ex],
    __a2: SearchStrategy[T],
    __a3: SearchStrategy[T3],
    __a4: SearchStrategy[T4],
    __a5: SearchStrategy[T5],
) -> SearchStrategy[tuple[Ex, T, T3, T4, T5]]:  # pragma: no cover
    ...


# Fallback overload: any number of strategies, element types not tracked.
@overload
def tuples(
    *args: SearchStrategy[Any],
) -> SearchStrategy[tuple[Any, ...]]:  # pragma: no cover
    ...


@cacheable
@defines_strategy()
def tuples(*args: SearchStrategy[Any]) -> SearchStrategy[tuple[Any, ...]]:
    """Return a strategy for fixed-length tuples built from ``args``.

    The generated tuple has one entry per argument, and the entry at index i
    is drawn from args[i].

    e.g. tuples(integers(), integers()) would generate a tuple of length
    two with both values an integer.

    Examples from this strategy shrink by shrinking their component parts.
    """
    # Check every argument before constructing the strategy.
    for strategy in args:
        check_strategy(strategy)

    return TupleStrategy(args)


class ListStrategy(SearchStrategy[list[Ex]]):
    """A strategy for lists which takes a strategy for its elements and the
    allowed lengths, and generates lists with the correct size and contents."""

    # Callables whose result on a list is truthy exactly when the list is
    # non-empty.  Filtering by one of these is rewritten to ``min_size >= 1``
    # instead of being applied as a rejection filter.
    _nonempty_filters: tuple[Callable[[Any], Any], ...] = (bool, len, tuple, list)

    def __init__(
        self,
        elements: SearchStrategy[Ex],
        min_size: int = 0,
        max_size: float | int | None = math.inf,
    ):
        """Create a list strategy.

        Args:
            elements: strategy used to draw each list element.
            min_size: smallest permitted length; a falsy value means 0.
            max_size: largest permitted length; ``None`` means unbounded.

        Raises:
            InvalidArgument: if the minimum size exceeds ``BUFFER_SIZE``.
        """
        super().__init__()
        self.min_size = min_size or 0
        self.max_size = max_size if max_size is not None else math.inf
        assert 0 <= self.min_size <= self.max_size
        self.average_size = self._average_size_for(self.min_size, self.max_size)
        self.element_strategy = elements
        # Compare the normalised value, so that a falsy ``min_size`` which was
        # mapped to 0 above cannot fail this comparison.  This check comes last
        # because the error message uses ``repr(self)``, which reads the
        # attributes assigned above.
        if self.min_size > BUFFER_SIZE:
            raise InvalidArgument(
                f"{self!r} can never generate an example, because min_size is larger "
                "than Hypothesis supports.  Including it is at best slowing down your "
                "tests for no benefit; at worst making them fail (maybe flakily) with "
                "a HealthCheck error."
            )

    @staticmethod
    def _average_size_for(min_size: int, max_size: float | int) -> float | int:
        """Return the target average length for the given size bounds.

        For ``0 <= min_size <= max_size`` the result always lies within
        ``[min_size, max_size]``.
        """
        return min(
            max(min_size * 2, min_size + 5),
            0.5 * (min_size + max_size),
        )

    def calc_label(self) -> int:
        """Combine the class label with the element strategy's label."""
        return combine_labels(self.class_label, self.element_strategy.label)

    def do_validate(self) -> None:
        """Validate the element strategy and reject unsatisfiable sizes.

        Raises:
            InvalidArgument: if this strategy is empty, or if the element
                strategy is empty while a finite, positive ``max_size`` was
                requested.
        """
        self.element_strategy.validate()
        if self.is_empty:
            raise InvalidArgument(
                "Cannot create non-empty lists with elements drawn from "
                f"strategy {self.element_strategy!r} because it has no values."
            )
        if self.element_strategy.is_empty and 0 < self.max_size < math.inf:
            raise InvalidArgument(
                f"Cannot create a collection of max_size={self.max_size!r}, "
                "because no elements can be drawn from the element strategy "
                f"{self.element_strategy!r}"
            )

    def calc_is_empty(self, recur: RecurT) -> bool:
        """Return whether this strategy can produce no values at all.

        When ``min_size`` is 0 the empty list is always available; otherwise
        the answer is whatever ``recur`` reports for the element strategy.
        """
        if self.min_size == 0:
            return False
        return recur(self.element_strategy)

    def do_draw(self, data: ConjectureData) -> list[Ex]:
        """Draw a list, letting ``cu.many`` decide when to stop adding elements."""
        if self.element_strategy.is_empty:
            assert self.min_size == 0
            return []

        elements = cu.many(
            data,
            min_size=self.min_size,
            max_size=self.max_size,
            average_size=self.average_size,
        )
        result = []
        while elements.more():
            result.append(data.draw(self.element_strategy))
        return result

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}({self.element_strategy!r}, "
            f"min_size={self.min_size:_}, max_size={self.max_size:_})"
        )

    def filter(self, condition: Callable[[list[Ex]], Any]) -> SearchStrategy[list[Ex]]:
        """Filter generated lists, rewriting length conditions into size bounds.

        Two kinds of condition are recognised and applied to a shallow copy of
        this strategy rather than as a rejection filter:

        * a non-emptiness check (one of ``_nonempty_filters`` or an identity
          function) raises ``min_size`` to at least 1;
        * a condition for which ``get_integer_predicate_bounds`` reports
          bounds on ``len`` tightens ``min_size`` and/or ``max_size``.  If it
          also returns a predicate, the condition is additionally applied as
          an ordinary filter on top of the tightened copy.

        Anything else, including bounds that would be unsatisfiable, falls
        back to the generic ``SearchStrategy.filter``.
        """
        if condition in self._nonempty_filters or is_identity_function(condition):
            assert self.max_size >= 1, "Always-empty is special cased in st.lists()"
            if self.min_size >= 1:
                return self
            new = copy.copy(self)
            new.min_size = 1
            # The copied average was computed for min_size == 0 and may be
            # below the new minimum (e.g. 0.5 when max_size == 1), so it must
            # be recomputed just like in the bounds rewrite below.
            new.average_size = self._average_size_for(new.min_size, new.max_size)
            return new

        constraints, pred = get_integer_predicate_bounds(condition)
        if constraints.get("len") and (
            "min_value" in constraints or "max_value" in constraints
        ):
            new = copy.copy(self)
            new.min_size = max(
                self.min_size, constraints.get("min_value", self.min_size)
            )
            new.max_size = min(
                self.max_size, constraints.get("max_value", self.max_size)
            )
            # Unsatisfiable filters are easiest to understand without rewriting.
            if new.min_size > new.max_size:
                return SearchStrategy.filter(self, condition)
            # Recompute average size; this is cheaper than making it into a property.
            new.average_size = self._average_size_for(new.min_size, new.max_size)
            if pred is None:
                return new
            return SearchStrategy.filter(new, condition)

        return SearchStrategy.filter(self, condition)


class UniqueListStrategy(ListStrategy[Ex]):
    """List strategy whose elements are pairwise distinct under every key.

    ``keys`` is a tuple of functions; an element is accepted only if, for each
    key function, its key has not been produced by an earlier element.  When
    ``tuple_suffixes`` is given, each accepted element is extended into a
    tuple by appending values drawn from that strategy.
    """

    def __init__(
        self,
        elements: SearchStrategy[Ex],
        min_size: int,
        max_size: float | int | None,
        # TODO: keys are guaranteed to be Hashable, not just Any, but this makes
        # other things harder to type
        keys: tuple[Callable[[Ex], Any], ...],
        tuple_suffixes: SearchStrategy[tuple[Ex, ...]] | None,
    ):
        super().__init__(elements, min_size, max_size)
        self.keys = keys
        self.tuple_suffixes = tuple_suffixes

    def do_draw(self, data: ConjectureData) -> list[Ex]:
        """Draw a list whose elements are unique with respect to ``self.keys``."""
        if self.element_strategy.is_empty:
            assert self.min_size == 0
            return []

        size_control = cu.many(
            data,
            min_size=self.min_size,
            max_size=self.max_size,
            average_size=self.average_size,
        )
        # One set of already-used key values per key function.
        seen_per_key: tuple[set[Ex], ...] = tuple(set() for _ in self.keys)
        # Really a list[Ex].  It is annotated as list[Any] because, when
        # self.tuple_suffixes is present, Ex is a tuple[T, ...] (the element
        # strategy is then a TupleStrategy), and mypy cannot tell that the
        # concrete tuple appended below is an Ex.
        drawn: list[Any] = []

        # The uniqueness check is expressed as a FilteredStrategy instead of a
        # draw-then-reject loop: some strategies generate specially when they
        # are under a filter, and FilteredStrategy can merge several filters.
        def not_yet_in_unique_list(val: Ex) -> bool:  # type: ignore # covariant type param
            return all(
                key(val) not in seen
                for key, seen in zip(self.keys, seen_per_key, strict=True)
            )

        filtered = FilteredStrategy(
            self.element_strategy, conditions=(not_yet_in_unique_list,)
        )
        while size_control.more():
            candidate = filtered.do_filtered_draw(data)
            if candidate is filter_not_satisfied:
                size_control.reject(
                    f"Aborted test because unable to satisfy {filtered!r}"
                )
                continue

            assert not isinstance(candidate, UniqueIdentifier)
            # Record the keys of the bare element, before any suffix is added.
            for key, seen in zip(self.keys, seen_per_key, strict=True):
                seen.add(key(candidate))
            if self.tuple_suffixes is not None:
                candidate = (candidate, *data.draw(self.tuple_suffixes))  # type: ignore
            drawn.append(candidate)

        assert self.max_size >= len(drawn) >= self.min_size
        return drawn


class UniqueSampledListStrategy(UniqueListStrategy):
    """Unique-list strategy for the case where elements come from a
    ``SampledFromStrategy``.

    Uniqueness is obtained by popping candidates out of a lazy copy of the
    sampled elements, so no underlying element is considered twice.
    """

    def do_draw(self, data: ConjectureData) -> list[Ex]:
        """Draw a unique list by sampling without replacement."""
        assert isinstance(self.element_strategy, SampledFromStrategy)

        size_control = cu.many(
            data,
            min_size=self.min_size,
            max_size=self.max_size,
            average_size=self.average_size,
        )
        # One set of already-used key values per key function.
        seen_per_key: tuple[set[Ex], ...] = tuple(set() for _ in self.keys)
        drawn: list[Any] = []

        pool = LazySequenceCopy(self.element_strategy.elements)

        # Stop when the pool is exhausted, even if more elements are wanted.
        while pool and size_control.more():
            index = data.draw_integer(0, len(pool) - 1)
            candidate = self.element_strategy._transform(pool.pop(index))

            # The key checks only run for candidates that passed _transform.
            is_new = candidate is not filter_not_satisfied and all(
                key(candidate) not in seen
                for key, seen in zip(self.keys, seen_per_key, strict=True)
            )
            if not is_new:
                size_control.reject(
                    "UniqueSampledListStrategy filter not satisfied or value already seen"
                )
                continue

            for key, seen in zip(self.keys, seen_per_key, strict=True):
                seen.add(key(candidate))
            if self.tuple_suffixes is not None:
                candidate = (candidate, *data.draw(self.tuple_suffixes))
            drawn.append(candidate)

        assert self.max_size >= len(drawn) >= self.min_size
        return drawn


class FixedDictStrategy(SearchStrategy[dict[Any, Any]]):
    """Strategy for dicts whose possible keys are known in advance.

    ``mapping`` supplies a value strategy for every key that is always
    present; ``optional`` (if not None) supplies value strategies for keys
    that may or may not be included in a generated dict.

    e.g. {'foo' : some_int_strategy} would generate dicts with the single
    key 'foo' mapping to some integer.
    """

    def __init__(
        self,
        mapping: dict[Any, SearchStrategy[Any]],
        *,
        optional: dict[Any, SearchStrategy[Any]] | None,
    ):
        super().__init__()
        dict_type = type(mapping)
        self.mapping = mapping
        keys = tuple(mapping.keys())
        # A tuple strategy over the required values, mapped back into the same
        # dict type as ``mapping``.
        required_values = st.tuples(*[mapping[k] for k in keys])
        self.fixed = required_values.map(
            lambda value: dict_type(zip(keys, value, strict=True))
        )
        self.optional = optional

    def do_draw(self, data: ConjectureData) -> dict[Any, Any]:
        """Draw a dict with all required keys and a drawn subset of optional ones.

        The result is an instance of ``type(self.mapping)``.  Each value is
        drawn while tracking an argument label named after ``str(key)``; if
        any labels were collected, a pretty-printer built from them is
        registered for the result in the current build context.
        """
        build_context = current_build_context()
        labels: ArgLabelsT = {}
        result = type(self.mapping)()

        def draw_entry(key: Any, strategy: SearchStrategy[Any]) -> None:
            # Draw result[key] under a label, then merge in the tracked label.
            with build_context.track_arg_label(str(key)) as label:
                result[key] = data.draw(strategy)
            labels.update(label)

        for key, strategy in self.mapping.items():
            draw_entry(key, strategy)

        if self.optional is not None:
            # Optional keys whose strategy is empty can never be included.
            candidates = [
                key for key, strategy in self.optional.items() if not strategy.is_empty
            ]
            chooser = cu.many(
                data,
                min_size=0,
                max_size=len(candidates),
                average_size=len(candidates) / 2,
            )
            while chooser.more():
                picked = data.draw_integer(0, len(candidates) - 1)
                # Move the chosen key to the end so it can be popped cheaply.
                candidates[picked], candidates[-1] = (
                    candidates[-1],
                    candidates[picked],
                )
                key = candidates.pop()
                draw_entry(key, self.optional[key])

        if labels:
            printers = build_context.known_object_printers[IDKey(result)]
            printers.append(_fixeddict_pprinter(labels, self.mapping))
        return result

    def calc_is_empty(self, recur: RecurT) -> bool:
        """Return what ``recur`` reports for the required-keys strategy."""
        return recur(self.fixed)

    def __repr__(self) -> str:
        if self.optional is None:
            return f"fixed_dictionaries({self.mapping!r})"
        return f"fixed_dictionaries({self.mapping!r}, optional={self.optional!r})"
