# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Most of this work is copyright (C) 2013-2020 David R. MacIver
# (david@drmaciver.com), but it contains contributions by others. See
# CONTRIBUTING.rst for a full list of people who may hold copyright, and
# consult the git log if you need to determine who owns an individual
# contribution.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
#
# END HEADER

from itertools import islice
from random import Random

from hypothesis import HealthCheck, Verbosity, settings, strategies as st
from hypothesis.errors import UnsatisfiedAssumption
from hypothesis.internal.conjecture.data import ConjectureData, Status
from hypothesis.internal.conjecture.dfa.lstar import LStar
from hypothesis.internal.conjecture.engine import BUFFER_SIZE, ConjectureRunner
from hypothesis.internal.conjecture.junkdrawer import uniform

# Memo of trained LStar learners, keyed by the strategy object they model,
# so that each strategy is only trained once per process.
LEARNERS: dict = {}


def learner_for(strategy):
    """Return an LStar learner modelling the valid buffers for ``strategy``.

    A buffer is accepted by the learner's membership predicate when replaying
    it through ``strategy`` gives a status of at least ``Status.VALID``, records
    no discards, and consumes exactly that buffer (no unused trailing bytes).

    Trained learners are memoised in ``LEARNERS``, so repeated calls with the
    same strategy object return the same learner without retraining.
    """
    cached = LEARNERS.get(strategy)
    if cached is not None:
        return cached

    def test_function(data):
        # Any successful draw counts as "interesting", so the runner just
        # hunts for (and shrinks to) the smallest buffer the strategy accepts.
        try:
            data.draw(strategy)
        except UnsatisfiedAssumption:
            data.mark_invalid()
        data.mark_interesting()

    quiet_settings = settings(
        database=None,
        verbosity=Verbosity.quiet,
        suppress_health_check=HealthCheck.all(),
    )
    runner = ConjectureRunner(
        test_function,
        settings=quiet_settings,
        random=Random(0),
        ignore_limits=True,
    )

    def is_exact_valid_buffer(buf):
        outcome = runner.cached_test_function(buf)
        return (
            outcome.status >= Status.VALID
            and not outcome.has_discards
            and outcome.buffer == buf
        )

    learner = LStar(is_exact_valid_buffer)

    runner.run()

    # test_function has a single interesting origin, so exactly one example.
    (minimal,) = runner.interesting_examples.values()

    # Seed the learner with the smallest example and a few short random
    # prefixes of it. This is ad hoc; the aim is for the DFA to grasp what
    # the smallest example looks like and not invent spurious loops near its
    # start state.
    learner.learn(minimal.buffer)
    for prefix_length in (1, 2, 3):
        for _ in range(5):
            learner.learn(uniform(runner.random, prefix_length) + minimal.buffer)

    # Keep probing with random buffers, padded with zero bytes, until a whole
    # round of probes leaves the learner's generation unchanged.
    while True:
        generation_before = learner.generation
        for _ in range(10):
            probe = uniform(runner.random, len(minimal.buffer)) + bytes(BUFFER_SIZE)
            learner.learn(probe)
            replay = runner.cached_test_function(probe)
            if replay.status >= Status.VALID:
                # Also teach it the prefix the strategy actually consumed.
                learner.learn(replay.buffer)
        if learner.generation == generation_before:
            break

    LEARNERS[strategy] = learner
    return learner


def iter_values(strategy, unique_by=lambda s: s):
    """Yield the distinct values ``strategy`` can generate, smallest first.

    Candidate buffers come from the learned DFA in what should be
    shortlex-ascending order. When the DFA proposes a buffer that the real
    membership check rejects, the learner is corrected with that buffer and
    enumeration starts over from the beginning of the DFA's language.

    Several buffers may decode to the same value, so values are deduplicated
    on ``unique_by(value)``. For unhashable values, pass a key function that
    returns something hashable (``repr``, for example).
    """
    learner = learner_for(strategy)
    seen = set()

    restart = True
    while restart:
        restart = False
        for candidate in learner.dfa.all_matching_strings():
            if not learner.member(candidate):
                # The DFA over-approximated: refine it, then re-enumerate.
                learner.learn(candidate)
                restart = True
                break
            value = ConjectureData.for_buffer(candidate).draw(strategy)
            key = unique_by(value)
            if key not in seen:
                seen.add(key)
                yield value


def test_characters_start_with_the_digits():
    # The ten smallest characters() values, in enumeration order, are the
    # ASCII digits "0" through "9".
    first_ten = list(islice(iter_values(st.characters()), 10))
    assert first_ten == list("0123456789")
