# Designing a better simplifier

This is a notebook talking through some of the considerations in the design of Hypothesis's approach to simplification.

It doesn't perfectly mirror what actually happens in Hypothesis, but it should give some idea of the sort of things that Hypothesis does and why it takes a particular approach.

To limit the scope of this document we are only going to concern ourselves with lists of integers. There are a number of API considerations involved in expanding beyond that point, but most of the algorithmic considerations are the same.

The big difference between lists of integers and the general case is that an individual integer can never be very complex. In particular we will rapidly get to the point where each individual element can usually be simplified in only O(log(n)) calls, where n is its value. When dealing with e.g. lists of lists this is a much more complicated proposition. That may be covered in another notebook.

Our objective here is to minimize the number of times we check the condition. We won't be looking at actual timing performance, because usually the speed of the condition is the bottleneck there (and where it's not, everything is fast enough that we need not worry).
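
One way to measure this, as a minimal sketch: wrap the condition in a function that counts how often it is evaluated. The name `counting` is hypothetical and is not part of Hypothesis or of the cells below.

```python
def counting(constraint):
    """Wrap constraint so that every evaluation is counted.

    Returns the wrapped function and a one-element list holding the
    running count (a list so that the closure can update it in place).
    """
    calls = [0]

    def wrapped(ls):
        calls[0] += 1
        return constraint(ls)

    return wrapped, calls


wrapped, calls = counting(lambda xs: len(xs) >= 2)
results = [wrapped([1]), wrapped([1, 2]), wrapped([])]
print(results, calls[0])
```

The debug function `show_trace` below uses the same idea, and additionally prints each example it is asked to check.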

In [1]:
def greedy_shrink(ls, constraint, shrink):
    """
    This is the "classic" QuickCheck algorithm which takes a shrink function
    which will iterate over simpler versions of an example. We are trying
    to find a local minimum: that is, an example ls such that constraint(ls)
    is True but constraint(t) is False for each t in shrink(ls).
    """
    while True:
        for s in shrink(ls):
            if constraint(s):
                ls = s
                break
        else:
            return ls
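
The "local minimum" property from the docstring can be written down directly. This is only an illustrative sketch: `is_local_minimum`, `drop_one` and `at_least_two` are hypothetical names introduced here, and `drop_one` is a toy shrink function rather than one of the shrinkers developed below.

```python
def is_local_minimum(ls, constraint, shrink):
    """True if ls satisfies constraint but no single shrink step does."""
    return bool(constraint(ls)) and not any(constraint(s) for s in shrink(ls))


def drop_one(ls):
    # Toy shrink function: every list obtained by deleting one element.
    for i in range(len(ls)):
        yield ls[:i] + ls[i + 1:]


def at_least_two(xs):
    return len(xs) >= 2


print(is_local_minimum([7, 7], at_least_two, drop_one))
print(is_local_minimum([7, 7, 7], at_least_two, drop_one))
```

Note that a local minimum depends on the shrink function: `[7, 7]` is minimal for `drop_one`, but a shrinker that also lowers values could go further.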

In [2]:
def shrink1(ls):
    """
    This is our prototype shrink function. It is very bad. It makes the
    mistake of only making very small changes to an example each time.

    Most people write something like this the first time they come to
    implement example shrinking. In particular early Hypothesis very much
    made this mistake.

    What this does:

    For each index, if the value of the index is non-zero we try
    decrementing it by 1.

    We then (regardless of if it's zero) try the list with the value at
    that index deleted.
    """
    for i in range(len(ls)):
        s = list(ls)
        if s[i] > 0:
            s[i] -= 1
            yield list(s)
        del s[i]
        yield list(s)

In [3]:
def show_trace(start, constraint, simplifier):
    """
    This is a debug function. You shouldn't concern yourself with
    its implementation too much.

    What it does is print out every intermediate step in applying a
    simplifier (a function of the form (list, constraint) -> list)
    along with whether it is a successful shrink or not.
    """
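    if start is None:
        # gen_list is defined in a later cell and expects a source of
        # randomness as its argument; the random module itself will do.
        while True:
            start = gen_list(random)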
            if constraint(start):
                break

    shrinks = [0]
    tests = [0]

    def print_shrink(ls):
        tests[0] += 1
        if constraint(ls):
            shrinks[0] += 1
            print("✓", ls)
            return True
        else:
            print("✗", ls)
        return False

    print("✓", start)
    simplifier(start, print_shrink)
    print()
    print("%d shrinks with %d function calls" % (shrinks[0], tests[0]))

In [4]:
from functools import partial

In [5]:
show_trace([5, 5], lambda x: len(x) >= 2, partial(greedy_shrink, shrink=shrink1))

✓ [5, 5]
✓ [4, 5]
✓ [3, 5]
✓ [2, 5]
✓ [1, 5]
✓ [0, 5]
✗ [5]
✓ [0, 4]
✗ [4]
✓ [0, 3]
✗ [3]
✓ [0, 2]
✗ [2]
✓ [0, 1]
✗ [1]
✓ [0, 0]
✗ [0]
✗ [0]

10 shrinks with 17 function calls


That worked reasonably well, but it sure was a lot of function calls for such a small amount of shrinking. What would have happened if we'd started with [100, 100]?
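
To answer that without printing the whole trace, here is a small sketch that only counts the calls. It repeats `greedy_shrink` and `shrink1` so that it runs on its own; `count_calls` and `at_least_two` are hypothetical helpers introduced for this example.

```python
def greedy_shrink(ls, constraint, shrink):
    # Same algorithm as above: restart from every successful shrink.
    while True:
        for s in shrink(ls):
            if constraint(s):
                ls = s
                break
        else:
            return ls


def shrink1(ls):
    # Same candidates as shrink1 above: decrement by one, then delete.
    for i in range(len(ls)):
        s = list(ls)
        if s[i] > 0:
            s[i] -= 1
            yield list(s)
        del s[i]
        yield list(s)


def count_calls(start, constraint, shrink):
    """Return (final example, number of times constraint was called)."""
    calls = [0]

    def counted(ls):
        calls[0] += 1
        return constraint(ls)

    return greedy_shrink(start, counted, shrink), calls[0]


def at_least_two(xs):
    return len(xs) >= 2


small = count_calls([5, 5], at_least_two, shrink1)
large = count_calls([100, 100], at_least_two, shrink1)
print(small, large)  # ([0, 0], 17) ([0, 0], 302)
```

The number of calls grows linearly with the starting values: 17 for `[5, 5]` against 302 for `[100, 100]`.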

In [6]:
def shrink2(ls):
    """
    Here is an improved shrink function. We first try deleting each element
    and then we try making each element smaller, but we try the candidate
    values in increasing order starting from zero rather than stepping down
    from the current value. This means we will still always find the
    smallest value that can go in there, but we will do so much sooner.
    """
    for i in range(len(ls)):
        s = list(ls)
        del s[i]
        yield list(s)

    for i in range(len(ls)):
        for x in range(ls[i]):
            s = list(ls)
            s[i] = x
            yield s

In [7]:
show_trace([5, 5], lambda x: len(x) >= 2, partial(greedy_shrink, shrink=shrink2))

✓ [5, 5]
✗ [5]
✗ [5]
✓ [0, 5]
✗ [5]
✗ [0]
✓ [0, 0]
✗ [0]
✗ [0]

2 shrinks with 8 function calls


This did indeed reduce the number of function calls significantly - we immediately determine that the value in the cell doesn't matter and we can just put zero there. 

But what would have happened if the value *did* matter?

In [8]:
show_trace([1000], lambda x: sum(x) >= 500, partial(greedy_shrink, shrink=shrink2))

✓ [1000]
✗ []
✗ [0]
✗ [1]
✗ [2]
✗ [3]
✗ [4]
✗ [5]
✗ [6]
✗ [7]
✗ [8]
✗ [9]
✗ [10]
✗ [11]
✗ [12]
✗ [13]
✗ [14]
✗ [15]
✗ [16]
✗ [17]
✗ [18]
✗ [19]
✗ [20]
✗ [21]
✗ [22]
✗ [23]
✗ [24]
✗ [25]
✗ [26]
✗ [27]
✗ [28]
✗ [29]
✗ [30]
✗ [31]
✗ [32]
✗ [33]
✗ [34]
✗ [35]
✗ [36]
✗ [37]
✗ [38]
✗ [39]
✗ [40]
✗ [41]
✗ [42]
✗ [43]
✗ [44]
✗ [45]
✗ [46]
✗ [47]
✗ [48]
✗ [49]
✗ [50]
✗ [51]
✗ [52]
✗ [53]
✗ [54]
✗ [55]
✗ [56]
✗ [57]
✗ [58]
✗ [59]
✗ [60]
✗ [61]
✗ [62]
✗ [63]
✗ [64]
✗ [65]
✗ [66]
✗ [67]
✗ [68]
✗ [69]
✗ [70]
✗ [71]
✗ [72]
✗ [73]
✗ [74]
✗ [75]
✗ [76]
✗ [77]
✗ [78]
✗ [79]
✗ [80]
✗ [81]
✗ [82]
✗ [83]
✗ [84]
✗ [85]
✗ [86]
✗ [87]
✗ [88]
✗ [89]
✗ [90]
✗ [91]
✗ [92]
✗ [93]
✗ [94]
✗ [95]
✗ [96]
✗ [97]
✗ [98]
✗ [99]
✗ [100]
✗ [101]
✗ [102]
✗ [103]
✗ [104]
✗ [105]
✗ [106]
✗ [107]
✗ [108]
✗ [109]
✗ [110]
✗ [111]
✗ [112]
✗ [113]
✗ [114]
✗ [115]
✗ [116]
✗ [117]
✗ [118]
✗ [119]
✗ [120]
✗ [121]
✗ [122]
✗ [123]
✗ [124]
✗ [125]
✗ [126]
✗ [127]
✗ [128]
✗ [129]
✗ [130]
✗ [131]
✗ [132]
✗ [133]
✗ [134]
✗ [135]
✗ [136]


Because we're trying every intermediate value, what we have amounts to a linear probe up to the smallest value that will work. If that smallest value is large, this will take a long time. Our shrinking is still O(n), but n is now the size of the smallest value that will work rather than the starting value. This is still pretty suboptimal.
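
The output above is cut off, so here is a sketch that counts the calls for the same run instead of printing them. It repeats `greedy_shrink` and `shrink2` to be runnable on its own; `count_calls` is a hypothetical helper.

```python
def greedy_shrink(ls, constraint, shrink):
    # Same algorithm as above: restart from every successful shrink.
    while True:
        for s in shrink(ls):
            if constraint(s):
                ls = s
                break
        else:
            return ls


def shrink2(ls):
    # Same candidates as shrink2 above: deletions first, then every
    # smaller value for each element, counting up from zero.
    for i in range(len(ls)):
        s = list(ls)
        del s[i]
        yield list(s)

    for i in range(len(ls)):
        for x in range(ls[i]):
            s = list(ls)
            s[i] = x
            yield s


def count_calls(start, constraint, shrink):
    """Return (final example, number of times constraint was called)."""
    calls = [0]

    def counted(ls):
        calls[0] += 1
        return constraint(ls)

    return greedy_shrink(start, counted, shrink), calls[0]


result = count_calls([1000], lambda x: sum(x) >= 500, shrink2)
print(result)  # ([500], 1003)
```

Reaching `[500]` takes 502 calls (the empty list, then 0 through 500), and confirming that nothing smaller works takes another 501.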

What we want to do is try to replace our linear probe with a binary search. What we'll get isn't exactly a binary search, but it's close enough.

In [9]:
def shrink_integer(n):
    """
    Shrinker for individual integers.

    Assumes n is a non-negative integer. We start from zero and first probe
    upwards through the values one below each power of two.

    When the next power of two would take us to or past our target value we
    then binary chop towards it, finishing with n - 1.
    """
    if not n:
        return
    for k in range(64):
        probe = 2**k
        if probe >= n:
            break
        yield probe - 1
    probe //= 2
    while True:
        probe = (probe + n) // 2
        yield probe
        if probe == n - 1:
            break


def shrink3(ls):
    for i in range(len(ls)):
        s = list(ls)
        del s[i]
        yield list(s)
        for x in shrink_integer(ls[i]):
            s = list(ls)
            s[i] = x
            yield s

In [10]:
list(shrink_integer(500))

[0, 1, 3, 7, 15, 31, 63, 127, 255, 378, 439, 469, 484, 492, 496, 498, 499]

This gives us a reasonable distribution of O(log(n)) values in the middle while still making sure we start with 0 and finish with n - 1.
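
As a quick check of the logarithmic growth, the following sketch counts the candidates for a few values of n. It repeats `shrink_integer` so that it runs on its own; `candidate_count` is a hypothetical helper, and the bound of twice the bit length is an observation about this particular function rather than anything Hypothesis guarantees.

```python
def shrink_integer(n):
    # Same generator as above, repeated so this snippet is self-contained.
    if not n:
        return
    for k in range(64):
        probe = 2**k
        if probe >= n:
            break
        yield probe - 1
    probe //= 2
    while True:
        probe = (probe + n) // 2
        yield probe
        if probe == n - 1:
            break


def candidate_count(n):
    """Number of simpler values that shrink_integer proposes for n."""
    return sum(1 for _ in shrink_integer(n))


for n in [1, 2, 3, 500, 1000, 10**6]:
    # Compare against twice the number of bits needed to write n.
    print(n, candidate_count(n), 2 * n.bit_length())
```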

In Hypothesis's actual implementation we also try random values in the probe region in case there's something special about things near powers of two, but we won't worry about that here.

In [11]:
show_trace([1000], lambda x: sum(x) >= 500, partial(greedy_shrink, shrink=shrink3))

✓ [1000]
✗ []
✗ [0]
✗ [1]
✗ [3]
✗ [7]
✗ [15]
✗ [31]
✗ [63]
✗ [127]
✗ [255]
✓ [511]
✗ []
✗ [0]
✗ [1]
✗ [3]
✗ [7]
✗ [15]
✗ [31]
✗ [63]
✗ [127]
✗ [255]
✗ [383]
✗ [447]
✗ [479]
✗ [495]
✓ [503]
✗ []
✗ [0]
✗ [1]
✗ [3]
✗ [7]
✗ [15]
✗ [31]
✗ [63]
✗ [127]
✗ [255]
✗ [379]
✗ [441]
✗ [472]
✗ [487]
✗ [495]
✗ [499]
✓ [501]
✗ []
✗ [0]
✗ [1]
✗ [3]
✗ [7]
✗ [15]
✗ [31]
✗ [63]
✗ [127]
✗ [255]
✗ [378]
✗ [439]
✗ [470]
✗ [485]
✗ [493]
✗ [497]
✗ [499]
✓ [500]
✗ []
✗ [0]
✗ [1]
✗ [3]
✗ [7]
✗ [15]
✗ [31]
✗ [63]
✗ [127]
✗ [255]
✗ [378]
✗ [439]
✗ [469]
✗ [484]
✗ [492]
✗ [496]
✗ [498]
✗ [499]

4 shrinks with 79 function calls


This now runs in a much more reasonable number of function calls.

Now we want to look at how to reduce the number of elements in the list more efficiently. We're currently making the same mistake we did with numbers: only reducing one at a time.

In [12]:
show_trace([2] * 20, lambda x: sum(x) >= 3, partial(greedy_shrink, shrink=shrink3))

✓ [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2, 2]
✓ [2, 2, 2, 2, 2]
✓ [2, 2, 2, 2]
✓ [2, 2, 2]
✓ [2, 2]
✗ [2]
✗ [0, 2]
✓ [1, 2]
✗ [2]
✗ [0, 2]
✗ [1]
✗ [1, 0]
✗ [1, 1]

19 shrinks with 26 function calls


We won't try too hard here, because typically our lists are not *that* long. We will just attempt to start by finding a shortish initial prefix that demonstrates the behaviour:

In [13]:
def shrink_to_prefix(ls):
    i = 1
    while i < len(ls):
        yield ls[:i]
        i *= 2


def delete_individual_elements(ls):
    for i in range(len(ls)):
        s = list(ls)
        del s[i]
        yield list(s)


def shrink_individual_elements(ls):
    for i in range(len(ls)):
        for x in shrink_integer(ls[i]):
            s = list(ls)
            s[i] = x
            yield s


def shrink4(ls):
    yield from shrink_to_prefix(ls)
    yield from delete_individual_elements(ls)
    yield from shrink_individual_elements(ls)
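
To make the prefix step concrete, this sketch lists the prefix lengths that `shrink_to_prefix` proposes. The function is repeated so the snippet runs on its own; `prefix_lengths` is a hypothetical helper.

```python
def shrink_to_prefix(ls):
    # Same generator as above: prefixes of length 1, 2, 4, 8, ...
    i = 1
    while i < len(ls):
        yield ls[:i]
        i *= 2


def prefix_lengths(n):
    """Lengths of the prefixes proposed for a list of length n."""
    return [len(p) for p in shrink_to_prefix(list(range(n)))]


print(prefix_lengths(20))  # [1, 2, 4, 8, 16]
print(prefix_lengths(8))   # [1, 2, 4]
```

Only about log2(len(ls)) prefixes are tried, and neither the empty list nor the full list is among them, so the remaining length reduction is left to `delete_individual_elements`.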

In [14]:
show_trace([2] * 20, lambda x: sum(x) >= 3, partial(greedy_shrink, shrink=shrink4))

✓ [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
✗ [2]
✓ [2, 2]
✗ [2]
✗ [2]
✗ [2]
✗ [0, 2]
✓ [1, 2]
✗ [1]
✗ [2]
✗ [1]
✗ [0, 2]
✗ [1, 0]
✗ [1, 1]

2 shrinks with 13 function calls


The problem we now want to address is the fact that when we're shrinking elements we're only shrinking them one at a time. This means that even though we're only O(log(k)) in each element, we're O(log(k)^n) in the whole list where n is the length of the list. For even very modest k this is bad.

In general we may not be able to fix this, but in practice for a lot of common structures we can exploit similarity to try to do simultaneous shrinking.

Here is our starting example: We start and finish with all identical values. We would like to be able to shortcut through a lot of the uninteresting intermediate examples somehow.

In [15]:
show_trace(
    [20] * 7,
    lambda x: len([t for t in x if t >= 5]) >= 5,
    partial(greedy_shrink, shrink=shrink4),
)

✓ [20, 20, 20, 20, 20, 20, 20]
✗ [20]
✗ [20, 20]
✗ [20, 20, 20, 20]
✓ [20, 20, 20, 20, 20, 20]
✗ [20]
✗ [20, 20]
✗ [20, 20, 20, 20]
✓ [20, 20, 20, 20, 20]
✗ [20]
✗ [20, 20]
✗ [20, 20, 20, 20]
✗ [20, 20, 20, 20]
✗ [20, 20, 20, 20]
✗ [20, 20, 20, 20]
✗ [20, 20, 20, 20]
✗ [20, 20, 20, 20]
✗ [0, 20, 20, 20, 20]
✗ [1, 20, 20, 20, 20]
✗ [3, 20, 20, 20, 20]
✓ [7, 20, 20, 20, 20]
✗ [7]
✗ [7, 20]
✗ [7, 20, 20, 20]
✗ [20, 20, 20, 20]
✗ [7, 20, 20, 20]
✗ [7, 20, 20, 20]
✗ [7, 20, 20, 20]
✗ [7, 20, 20, 20]
✗ [0, 20, 20, 20, 20]
✗ [1, 20, 20, 20, 20]
✗ [3, 20, 20, 20, 20]
✓ [5, 20, 20, 20, 20]
✗ [5]
✗ [5, 20]
✗ [5, 20, 20, 20]
✗ [20, 20, 20, 20]
✗ [5, 20, 20, 20]
✗ [5, 20, 20, 20]
✗ [5, 20, 20, 20]
✗ [5, 20, 20, 20]
✗ [0, 20, 20, 20, 20]
✗ [1, 20, 20, 20, 20]
✗ [3, 20, 20, 20, 20]
✗ [4, 20, 20, 20, 20]
✗ [5, 0, 20, 20, 20]
✗ [5, 1, 20, 20, 20]
✗ [5, 3, 20, 20, 20]
✓ [5, 7, 20, 20, 20]
✗ [5]
✗ [5, 7]
✗ [5, 7, 20, 20]
✗ [7, 20, 20, 20]
✗ [5, 20, 20, 20]
✗ [5, 7, 20, 20]
✗ [5, 7, 20, 20]
✗ [5, 7, 20, 

In [16]:
def shrink_shared(ls):
    """
    Look for all sets of shared indices and try to perform a simultaneous shrink on
    their value, replacing all of them at once.

    In actual Hypothesis we also try replacing only subsets of the values when there
    are more than two shared values, but we won't worry about that here.
    """
    shared_indices = {}
    for i in range(len(ls)):
        shared_indices.setdefault(ls[i], []).append(i)
    for sharing in shared_indices.values():
        if len(sharing) > 1:
            for v in shrink_integer(ls[sharing[0]]):
                s = list(ls)
                for i in sharing:
                    s[i] = v
                yield s


def shrink5(ls):
    yield from shrink_to_prefix(ls)
    yield from delete_individual_elements(ls)
    yield from shrink_shared(ls)
    yield from shrink_individual_elements(ls)
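
A small sketch of what `shrink_shared` proposes, with both functions repeated so that it runs on its own. The input lists are made up for illustration.

```python
def shrink_integer(n):
    # Same generator as above.
    if not n:
        return
    for k in range(64):
        probe = 2**k
        if probe >= n:
            break
        yield probe - 1
    probe //= 2
    while True:
        probe = (probe + n) // 2
        yield probe
        if probe == n - 1:
            break


def shrink_shared(ls):
    # Same generator as above: shrink all equal values together.
    shared_indices = {}
    for i in range(len(ls)):
        shared_indices.setdefault(ls[i], []).append(i)
    for sharing in shared_indices.values():
        if len(sharing) > 1:
            for v in shrink_integer(ls[sharing[0]]):
                s = list(ls)
                for i in sharing:
                    s[i] = v
                yield s


candidates = list(shrink_shared([20, 3, 20]))
print(candidates[0], candidates[-1])  # [0, 3, 0] [19, 3, 19]

# With no repeated values there is nothing to shrink simultaneously.
print(list(shrink_shared([1, 2, 3])))  # []
```

The two copies of 20 move together while the 3 is left alone. A list without duplicates produces no candidates at all.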

In [17]:
show_trace(
    [20] * 7,
    lambda x: len([t for t in x if t >= 5]) >= 5,
    partial(greedy_shrink, shrink=shrink5),
)

✓ [20, 20, 20, 20, 20, 20, 20]
✗ [20]
✗ [20, 20]
✗ [20, 20, 20, 20]
✓ [20, 20, 20, 20, 20, 20]
✗ [20]
✗ [20, 20]
✗ [20, 20, 20, 20]
✓ [20, 20, 20, 20, 20]
✗ [20]
✗ [20, 20]
✗ [20, 20, 20, 20]
✗ [20, 20, 20, 20]
✗ [20, 20, 20, 20]
✗ [20, 20, 20, 20]
✗ [20, 20, 20, 20]
✗ [20, 20, 20, 20]
✗ [0, 0, 0, 0, 0]
✗ [1, 1, 1, 1, 1]
✗ [3, 3, 3, 3, 3]
✓ [7, 7, 7, 7, 7]
✗ [7]
✗ [7, 7]
✗ [7, 7, 7, 7]
✗ [7, 7, 7, 7]
✗ [7, 7, 7, 7]
✗ [7, 7, 7, 7]
✗ [7, 7, 7, 7]
✗ [7, 7, 7, 7]
✗ [0, 0, 0, 0, 0]
✗ [1, 1, 1, 1, 1]
✗ [3, 3, 3, 3, 3]
✓ [5, 5, 5, 5, 5]
✗ [5]
✗ [5, 5]
✗ [5, 5, 5, 5]
✗ [5, 5, 5, 5]
✗ [5, 5, 5, 5]
✗ [5, 5, 5, 5]
✗ [5, 5, 5, 5]
✗ [5, 5, 5, 5]
✗ [0, 0, 0, 0, 0]
✗ [1, 1, 1, 1, 1]
✗ [3, 3, 3, 3, 3]
✗ [4, 4, 4, 4, 4]
✗ [0, 5, 5, 5, 5]
✗ [1, 5, 5, 5, 5]
✗ [3, 5, 5, 5, 5]
✗ [4, 5, 5, 5, 5]
✗ [5, 0, 5, 5, 5]
✗ [5, 1, 5, 5, 5]
✗ [5, 3, 5, 5, 5]
✗ [5, 4, 5, 5, 5]
✗ [5, 5, 0, 5, 5]
✗ [5, 5, 1, 5, 5]
✗ [5, 5, 3, 5, 5]
✗ [5, 5, 4, 5, 5]
✗ [5, 5, 5, 0, 5]
✗ [5, 5, 5, 1, 5]
✗ [5, 5, 5, 3, 5]
✗ [5, 5, 5, 4, 5]

This achieves the desired result. We rapidly progress through all of the intermediate stages. We do still have to perform individual shrinks at the end unfortunately (this is unavoidable), but the size of the elements is much smaller now so it takes less time.

Unfortunately while this solves the problem in this case it's almost useless, because unless you find yourself in the exact right starting position it never does anything.

In [18]:
show_trace(
    [20 + i for i in range(7)],
    lambda x: len([t for t in x if t >= 5]) >= 5,
    partial(greedy_shrink, shrink=shrink5),
)

✓ [20, 21, 22, 23, 24, 25, 26]
✗ [20]
✗ [20, 21]
✗ [20, 21, 22, 23]
✓ [21, 22, 23, 24, 25, 26]
✗ [21]
✗ [21, 22]
✗ [21, 22, 23, 24]
✓ [22, 23, 24, 25, 26]
✗ [22]
✗ [22, 23]
✗ [22, 23, 24, 25]
✗ [23, 24, 25, 26]
✗ [22, 24, 25, 26]
✗ [22, 23, 25, 26]
✗ [22, 23, 24, 26]
✗ [22, 23, 24, 25]
✗ [0, 23, 24, 25, 26]
✗ [1, 23, 24, 25, 26]
✗ [3, 23, 24, 25, 26]
✓ [7, 23, 24, 25, 26]
✗ [7]
✗ [7, 23]
✗ [7, 23, 24, 25]
✗ [23, 24, 25, 26]
✗ [7, 24, 25, 26]
✗ [7, 23, 25, 26]
✗ [7, 23, 24, 26]
✗ [7, 23, 24, 25]
✗ [0, 23, 24, 25, 26]
✗ [1, 23, 24, 25, 26]
✗ [3, 23, 24, 25, 26]
✓ [5, 23, 24, 25, 26]
✗ [5]
✗ [5, 23]
✗ [5, 23, 24, 25]
✗ [23, 24, 25, 26]
✗ [5, 24, 25, 26]
✗ [5, 23, 25, 26]
✗ [5, 23, 24, 26]
✗ [5, 23, 24, 25]
✗ [0, 23, 24, 25, 26]
✗ [1, 23, 24, 25, 26]
✗ [3, 23, 24, 25, 26]
✗ [4, 23, 24, 25, 26]
✗ [5, 0, 24, 25, 26]
✗ [5, 1, 24, 25, 26]
✗ [5, 3, 24, 25, 26]
✓ [5, 7, 24, 25, 26]
✗ [5]
✗ [5, 7]
✗ [5, 7, 24, 25]
✗ [7, 24, 25, 26]
✗ [5, 24, 25, 26]
✗ [5, 7, 25, 26]
✗ [5, 7, 24, 26]
✗ [5, 7, 24, 

So what we're going to do is first try a simplification which *creates* that exact right starting condition. Moreover, it's one that is potentially very useful even when we don't end up with any shared values to shrink.

We're going to use the values in the list as evidence for how complex things need to be. Starting from the smallest, we'll try capping every element of the list at each individual value and see what happens.

As well as being potentially a very rapid shrink, this creates lists with lots of duplicates, which enables the simultaneous shrinking to shine.

In [19]:
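def replace_with_simpler(ls):
    """
    Try capping every element at each value that already occurs in the
    list (other than the maximum), starting from the smallest.
    """
    if not ls:
        return
    values = set(ls)
    values.remove(max(ls))
    values = sorted(values)
    for v in values:
        yield [min(v, x) for x in ls]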


def shrink6(ls):
    yield from shrink_to_prefix(ls)
    yield from delete_individual_elements(ls)
    yield from replace_with_simpler(ls)
    yield from shrink_shared(ls)
    yield from shrink_individual_elements(ls)
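
Here is a sketch of the candidates that the capping step produces for one of the intermediate lists from the earlier trace. The function is repeated, with the same behaviour, so that the snippet runs on its own.

```python
def replace_with_simpler(ls):
    # Same behaviour as above: cap every element at each value that occurs
    # in the list, except the maximum (capping at it changes nothing).
    if not ls:
        return
    values = set(ls)
    values.remove(max(ls))
    for v in sorted(values):
        yield [min(v, x) for x in ls]


capped = list(replace_with_simpler([22, 23, 24, 25, 26]))
for candidate in capped:
    print(candidate)

# A list whose elements are all equal has nothing to cap.
print(list(replace_with_simpler([5, 5])))  # []
```

The first candidate is `[22, 22, 22, 22, 22]`, which consists entirely of duplicates and so is exactly the kind of list that `shrink_shared` can work with.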

In [20]:
show_trace(
    [20 + i for i in range(7)],
    lambda x: len([t for t in x if t >= 5]) >= 5,
    partial(greedy_shrink, shrink=shrink6),
)

✓ [20, 21, 22, 23, 24, 25, 26]
✗ [20]
✗ [20, 21]
✗ [20, 21, 22, 23]
✓ [21, 22, 23, 24, 25, 26]
✗ [21]
✗ [21, 22]
✗ [21, 22, 23, 24]
✓ [22, 23, 24, 25, 26]
✗ [22]
✗ [22, 23]
✗ [22, 23, 24, 25]
✗ [23, 24, 25, 26]
✗ [22, 24, 25, 26]
✗ [22, 23, 25, 26]
✗ [22, 23, 24, 26]
✗ [22, 23, 24, 25]
✓ [22, 22, 22, 22, 22]
✗ [22]
✗ [22, 22]
✗ [22, 22, 22, 22]
✗ [22, 22, 22, 22]
✗ [22, 22, 22, 22]
✗ [22, 22, 22, 22]
✗ [22, 22, 22, 22]
✗ [22, 22, 22, 22]
✗ [0, 0, 0, 0, 0]
✗ [1, 1, 1, 1, 1]
✗ [3, 3, 3, 3, 3]
✓ [7, 7, 7, 7, 7]
✗ [7]
✗ [7, 7]
✗ [7, 7, 7, 7]
✗ [7, 7, 7, 7]
✗ [7, 7, 7, 7]
✗ [7, 7, 7, 7]
✗ [7, 7, 7, 7]
✗ [7, 7, 7, 7]
✗ [0, 0, 0, 0, 0]
✗ [1, 1, 1, 1, 1]
✗ [3, 3, 3, 3, 3]
✓ [5, 5, 5, 5, 5]
✗ [5]
✗ [5, 5]
✗ [5, 5, 5, 5]
✗ [5, 5, 5, 5]
✗ [5, 5, 5, 5]
✗ [5, 5, 5, 5]
✗ [5, 5, 5, 5]
✗ [5, 5, 5, 5]
✗ [0, 0, 0, 0, 0]
✗ [1, 1, 1, 1, 1]
✗ [3, 3, 3, 3, 3]
✗ [4, 4, 4, 4, 4]
✗ [0, 5, 5, 5, 5]
✗ [1, 5, 5, 5, 5]
✗ [3, 5, 5, 5, 5]
✗ [4, 5, 5, 5, 5]
✗ [5, 0, 5, 5, 5]
✗ [5, 1, 5, 5, 5]
✗ [5, 3, 5, 5, 5]
✗ [5, 

Now we're going to start looking at some numbers.

What we'll do is we'll generate 1000 random lists satisfying some predicate, and then simplify them down to the smallest possible examples satisfying those predicates. This lets us verify that these aren't just cherry-picked examples and our methods help in the general case. We fix the set of examples per predicate so that we're comparing like for like.

A more proper statistical treatment would probably be a good idea.

In [21]:
from collections import OrderedDict

conditions = OrderedDict(
    [
        ("length >= 2", lambda xs: len(xs) >= 2),
        ("sum >= 500", lambda xs: sum(xs) >= 500),
        ("sum >= 3", lambda xs: sum(xs) >= 3),
        ("At least 10 by 5", lambda xs: len([t for t in xs if t >= 5]) >= 10),
    ]
)

In [22]:
import random

N_EXAMPLES = 1000

datasets = {}


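def gen_list(rnd):
    # Draw from the rnd passed in (not the global random module) so that
    # each condition's dataset is reproducible from its seed.
    return [rnd.getrandbits(64) for _ in range(rnd.randint(0, 100))]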


def dataset_for(condition):
    if condition in datasets:
        return datasets[condition]
    constraint = conditions[condition]
    dataset = []
    rnd = random.Random(condition)
    while len(dataset) < N_EXAMPLES:
        ls = gen_list(rnd)
        if constraint(ls):
            dataset.append(ls)
    datasets[condition] = dataset
    return dataset


dataset_for("sum >= 3")[1]

[17861213645196285187,
 15609796832515195084,
 8808697621832673046,
 1013319847337885109,
 1252281976438780211,
 15526909770962854196,
 2065337703776048239,
 11654092230944134701,
 5554896851708700201,
 17485190250805381572,
 7700396730246958474,
 402840882133605445,
 5303116940477413125,
 7459257850255946545,
 10349184495871650178,
 4361155591615075311,
 15194020468024244632,
 14428821588688846242,
 5754975712549869618,
 13740966788951413307,
 15209704957418077856,
 12562588328524673262,
 8415556016795311987,
 3993098291779210741,
 16874756914619597640,
 7932421182532982309,
 1080869529149674704,
 13878842261614060122,
 229976195287031921,
 8378461140013520338,
 6189522326946191255,
 16684625600934047114,
 12533448641134015292,
 10459192142175991903,
 15688511015570391481,
 3091340728247101611,
 4034760776171697910,
 6258572097778886531,
 13555449085571665140,
 6727488149749641424,
 7125107819562430884,
 1557872425804423698,
 4810250441100696888,
 10500486959813930693,
 84130006940364

In [23]:
# In order to avoid run-away cases where things will take basically forever
# we cap at 5000 as "you've taken too long. Stop it". Because we're only ever
# showing the worst case scenario we'll just display this as > 5000 if we ever
# hit it and it won't distort statistics.
MAX_COUNT = 5000


class MaximumCountExceeded(Exception):
    pass


def call_counts(condition, simplifier):
    constraint = conditions[condition]
    dataset = dataset_for(condition)
    counts = []

    for ex in dataset:
        counter = 0

        def run_and_count(ls):
            nonlocal counter
            counter += 1
            if counter > MAX_COUNT:
                raise MaximumCountExceeded
            return constraint(ls)

        try:
            simplifier(ex, run_and_count)
            counts.append(counter)
        except MaximumCountExceeded:
            counts.append(MAX_COUNT + 1)
            break
    return counts


def worst_case(condition, simplifier):
    return max(call_counts(condition, simplifier))


worst_case("length >= 2", partial(greedy_shrink, shrink=shrink6))

13

In [24]:
from IPython.display import HTML


def compare_simplifiers(named_simplifiers):
    """
    Given a list of (name, simplifier) pairs, output a table comparing
    the worst case performance of each on our current set of examples.
    """
    html_fragments = []
    html_fragments.append("<table>\n<thead>\n<tr>")
    header = ["Condition"]
    header.extend(name for name, _ in named_simplifiers)
    for h in header:
        html_fragments.append("<th>%s</th>" % (h,))
    html_fragments.append("</tr>\n</thead>\n<tbody>")

    for name in conditions:
        bits = [name.replace(">", "&gt;")]
        for _, simplifier in named_simplifiers:
            value = worst_case(name, simplifier)
            if value <= MAX_COUNT:
                bits.append(str(value))
            else:
                bits.append(" &gt; %d" % (MAX_COUNT,))
        html_fragments.append("<tr>  ")
        html_fragments.append(" ".join("<td>%s</td>" % (b,) for b in bits))
        html_fragments.append("</tr>")
    html_fragments.append("</tbody>\n</table>")
    return HTML("\n".join(html_fragments))

In [25]:
compare_simplifiers(
    [
        (f.__name__[-1], partial(greedy_shrink, shrink=f))
        for f in [shrink2, shrink3, shrink4, shrink5, shrink6]
    ]
)

| Condition | 2 | 3 | 4 | 5 | 6 |
|---|---|---|---|---|---|
| length >= 2 | 106 | 105 | 13 | 13 | 13 |
| sum >= 500 | 1102 | 178 | 80 | 80 | 80 |
| sum >= 3 | 108 | 107 | 9 | 9 | 9 |
| At least 10 by 5 | 535 | 690 | 809 | 877 | 144 |


As you can see from the above table, iterations 2 through 5 were a little ambiguous in that they helped a lot in the cases they were designed to help with but hurt in other cases. 6 however is clearly the best of the lot, being no worse than any of the others on any of the cases and often significantly better.

Rather than continuing to refine our shrink further, we instead look to improvements to how we use shrinking. We'll start by noting a simple optimization: If you look at our traces above, we often checked the same example twice. We're only interested in deterministic conditions, so this isn't useful to do. So we'll start by simply pruning out all duplicates. This should have exactly the same set and order of successful shrinks but will avoid a bunch of redundant work.

In [26]:
def greedy_shrink_with_dedupe(ls, constraint, shrink):
    seen = set()
    while True:
        for s in shrink(ls):
            key = tuple(s)
            if key in seen:
                continue
            seen.add(key)
            if constraint(s):
                ls = s
                break
        else:
            return ls

In [27]:
compare_simplifiers(
    [
        ("Normal", partial(greedy_shrink, shrink=shrink6)),
        ("Deduped", partial(greedy_shrink_with_dedupe, shrink=shrink6)),
    ]
)

| Condition | Normal | Deduped |
|---|---|---|
| length >= 2 | 13 | 6 |
| sum >= 500 | 80 | 35 |
| sum >= 3 | 9 | 6 |
| At least 10 by 5 | 144 | 107 |


As expected, this is a significant improvement in some cases. It is logically impossible that it could ever make things worse, but it's nice that it makes it better.
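
The effect is easy to see on the small `[5, 5]` example from earlier. The sketch below repeats `greedy_shrink`, `greedy_shrink_with_dedupe` and `shrink2` so that it runs on its own; `count_calls` is a hypothetical helper.

```python
def greedy_shrink(ls, constraint, shrink):
    while True:
        for s in shrink(ls):
            if constraint(s):
                ls = s
                break
        else:
            return ls


def greedy_shrink_with_dedupe(ls, constraint, shrink):
    seen = set()
    while True:
        for s in shrink(ls):
            key = tuple(s)
            if key in seen:
                continue  # already checked this example, skip the call
            seen.add(key)
            if constraint(s):
                ls = s
                break
        else:
            return ls


def shrink2(ls):
    for i in range(len(ls)):
        s = list(ls)
        del s[i]
        yield list(s)
    for i in range(len(ls)):
        for x in range(ls[i]):
            s = list(ls)
            s[i] = x
            yield s


def count_calls(simplifier, start, constraint, shrink):
    """Return (final example, number of times constraint was called)."""
    calls = [0]

    def counted(ls):
        calls[0] += 1
        return constraint(ls)

    return simplifier(start, counted, shrink), calls[0]


def at_least_two(xs):
    return len(xs) >= 2


plain = count_calls(greedy_shrink, [5, 5], at_least_two, shrink2)
deduped = count_calls(greedy_shrink_with_dedupe, [5, 5], at_least_two, shrink2)
print(plain, deduped)  # ([0, 0], 8) ([0, 0], 4)
```

Both versions pass through the same successful shrinks and end at `[0, 0]`, but the deduplicating version never re-checks `[5]` or `[0]`.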

So far we've only looked at things where the interaction between elements was fairly light: in the sum cases the values of other elements mattered a bit, but shrinking an integer could never enable other shrinks. Let's look at one where this is not the case, where our condition is that we have at least 10 distinct elements.

In [28]:
show_trace(
    [100 + i for i in range(10)],
    lambda x: len(set(x)) >= 10,
    partial(greedy_shrink, shrink=shrink6),
)

✓ [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
✗ [100]
✗ [100, 101]
✗ [100, 101, 102, 103]
✗ [100, 101, 102, 103, 104, 105, 106, 107]
✗ [101, 102, 103, 104, 105, 106, 107, 108, 109]
✗ [100, 102, 103, 104, 105, 106, 107, 108, 109]
✗ [100, 101, 103, 104, 105, 106, 107, 108, 109]
✗ [100, 101, 102, 104, 105, 106, 107, 108, 109]
✗ [100, 101, 102, 103, 105, 106, 107, 108, 109]
✗ [100, 101, 102, 103, 104, 106, 107, 108, 109]
✗ [100, 101, 102, 103, 104, 105, 107, 108, 109]
✗ [100, 101, 102, 103, 104, 105, 106, 108, 109]
✗ [100, 101, 102, 103, 104, 105, 106, 107, 109]
✗ [100, 101, 102, 103, 104, 105, 106, 107, 108]
✗ [100, 100, 100, 100, 100, 100, 100, 100, 100, 100]
✗ [100, 101, 101, 101, 101, 101, 101, 101, 101, 101]
✗ [100, 101, 102, 102, 102, 102, 102, 102, 102, 102]
✗ [100, 101, 102, 103, 103, 103, 103, 103, 103, 103]
✗ [100, 101, 102, 103, 104, 104, 104, 104, 104, 104]
✗ [100, 101, 102, 103, 104, 105, 105, 105, 105, 105]
✗ [100, 101, 102, 103, 104, 105, 106, 106, 106, 106]
✗ [100, 1

This does not do very well at all.

The reason it doesn't is that we keep trying useless shrinks. For example, none of the shrinks done by shrink\_to\_prefix, replace\_with\_simpler or shrink\_shared will ever do anything useful here.

So let's switch to an approach where we try shrink types until they stop working and then we move on to the next type:

In [29]:
def multicourse_shrink1(ls, constraint):
    seen = set()
    for shrink in [
        shrink_to_prefix,
        replace_with_simpler,
        shrink_shared,
        shrink_individual_elements,
    ]:
        while True:
            for s in shrink(ls):
                key = tuple(s)
                if key in seen:
                    continue
                seen.add(key)
                if constraint(s):
                    ls = s
                    break
            else:
                break
    return ls

In [30]:
show_trace(
    [100 + i for i in range(10)], lambda x: len(set(x)) >= 10, multicourse_shrink1
)

✓ [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
✗ [100]
✗ [100, 101]
✗ [100, 101, 102, 103]
✗ [100, 101, 102, 103, 104, 105, 106, 107]
✗ [100, 100, 100, 100, 100, 100, 100, 100, 100, 100]
✗ [100, 101, 101, 101, 101, 101, 101, 101, 101, 101]
✗ [100, 101, 102, 102, 102, 102, 102, 102, 102, 102]
✗ [100, 101, 102, 103, 103, 103, 103, 103, 103, 103]
✗ [100, 101, 102, 103, 104, 104, 104, 104, 104, 104]
✗ [100, 101, 102, 103, 104, 105, 105, 105, 105, 105]
✗ [100, 101, 102, 103, 104, 105, 106, 106, 106, 106]
✗ [100, 101, 102, 103, 104, 105, 106, 107, 107, 107]
✗ [100, 101, 102, 103, 104, 105, 106, 107, 108, 108]
✓ [0, 101, 102, 103, 104, 105, 106, 107, 108, 109]
✗ [0, 0, 102, 103, 104, 105, 106, 107, 108, 109]
✓ [0, 1, 102, 103, 104, 105, 106, 107, 108, 109]
✗ [0, 1, 0, 103, 104, 105, 106, 107, 108, 109]
✗ [0, 1, 1, 103, 104, 105, 106, 107, 108, 109]
✓ [0, 1, 3, 103, 104, 105, 106, 107, 108, 109]
✗ [0, 0, 3, 103, 104, 105, 106, 107, 108, 109]
✓ [0, 1, 2, 103, 104, 105, 106, 107, 108, 109]

In [31]:
conditions["10 distinct elements"] = lambda xs: len(set(xs)) >= 10

In [32]:
compare_simplifiers(
    [
        ("Single pass", partial(greedy_shrink_with_dedupe, shrink=shrink6)),
        ("Multi pass", multicourse_shrink1),
    ]
)

| Condition | Single pass | Multi pass |
|---|---|---|
| length >= 2 | 6 | 4 |
| sum >= 500 | 35 | 34 |
| sum >= 3 | 6 | 5 |
| At least 10 by 5 | 107 | 58 |
| 10 distinct elements | 623 | 320 |


So that helped, but not as much as we'd have liked. It's saved us about half the calls, when really we wanted to save 90% of the calls.

We're on the right track though. The problem is not that our solution isn't good, it's that it didn't go far enough: We're *still* making an awful lot of useless calls. The problem is that each time we shrink the element at index i we try shrinking the elements at indexes 0 through i - 1, and this will never work. So what we want to do is to break shrinking elements into a separate shrinker for each index:

In [33]:
def simplify_index(i):
    def accept(ls):
        if i >= len(ls):
            return
        for v in shrink_integer(ls[i]):
            s = list(ls)
            s[i] = v
            yield s

    return accept


def shrinkers_for(ls):
    yield shrink_to_prefix
    yield delete_individual_elements
    yield replace_with_simpler
    yield shrink_shared
    for i in range(len(ls)):
        yield simplify_index(i)


def multicourse_shrink2(ls, constraint):
    seen = set()
    for shrink in shrinkers_for(ls):
        while True:
            for s in shrink(ls):
                key = tuple(s)
                if key in seen:
                    continue
                seen.add(key)
                if constraint(s):
                    ls = s
                    break
            else:
                break
    return ls

In [34]:
show_trace(
    [100 + i for i in range(10)], lambda x: len(set(x)) >= 10, multicourse_shrink2
)

✓ [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
✗ [100]
✗ [100, 101]
✗ [100, 101, 102, 103]
✗ [100, 101, 102, 103, 104, 105, 106, 107]
✗ [101, 102, 103, 104, 105, 106, 107, 108, 109]
✗ [100, 102, 103, 104, 105, 106, 107, 108, 109]
✗ [100, 101, 103, 104, 105, 106, 107, 108, 109]
✗ [100, 101, 102, 104, 105, 106, 107, 108, 109]
✗ [100, 101, 102, 103, 105, 106, 107, 108, 109]
✗ [100, 101, 102, 103, 104, 106, 107, 108, 109]
✗ [100, 101, 102, 103, 104, 105, 107, 108, 109]
✗ [100, 101, 102, 103, 104, 105, 106, 108, 109]
✗ [100, 101, 102, 103, 104, 105, 106, 107, 109]
✗ [100, 101, 102, 103, 104, 105, 106, 107, 108]
✗ [100, 100, 100, 100, 100, 100, 100, 100, 100, 100]
✗ [100, 101, 101, 101, 101, 101, 101, 101, 101, 101]
✗ [100, 101, 102, 102, 102, 102, 102, 102, 102, 102]
✗ [100, 101, 102, 103, 103, 103, 103, 103, 103, 103]
✗ [100, 101, 102, 103, 104, 104, 104, 104, 104, 104]
✗ [100, 101, 102, 103, 104, 105, 105, 105, 105, 105]
✗ [100, 101, 102, 103, 104, 105, 106, 106, 106, 106]
✗ [100, 1

This worked great! It saved us a huge number of function calls.

Unfortunately it's wrong. Actually the previous one was wrong too, but this one is more obviously wrong. The problem is that shrinking later elements can unlock more shrinks for earlier elements and we'll never be able to benefit from that here:

In [35]:
show_trace([101, 100], lambda x: len(x) >= 2 and x[0] > x[1], multicourse_shrink2)

✓ [101, 100]
✗ [101]
✗ [100]
✗ [100, 100]
✗ [0, 100]
✗ [1, 100]
✗ [3, 100]
✗ [7, 100]
✗ [15, 100]
✗ [31, 100]
✗ [63, 100]
✗ [82, 100]
✗ [91, 100]
✗ [96, 100]
✗ [98, 100]
✗ [99, 100]
✓ [101, 0]

1 shrinks with 16 function calls


Armed with this example we can also show that the previous one is wrong. Here a later simplification unlocks an earlier one, because shrinking values allows us to delete more elements:

In [36]:
show_trace([5] * 10, lambda x: x and len(x) > max(x), multicourse_shrink1)

✓ [5, 5, 5, 5, 5, 5, 5, 5, 5, 5]
✗ [5]
✗ [5, 5]
✗ [5, 5, 5, 5]
✓ [5, 5, 5, 5, 5, 5, 5, 5]
✓ [0, 0, 0, 0, 0, 0, 0, 0]

2 shrinks with 5 function calls


In [37]:
conditions["First > Second"] = lambda xs: len(xs) >= 2 and xs[0] > xs[1]

In [38]:
# Note: We modify this to mask off the high bits because otherwise the probability of
# hitting the condition at random is too low.
conditions["Size > max & 63"] = lambda xs: xs and len(xs) > (max(xs) & 63)

So what we'll try doing is iterating the multi pass shrinker to a fixed point and seeing what happens:

In [39]:
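def multicourse_shrink3(ls, constraint):
    """
    Run each shrinker until it stops making progress, then repeat the whole
    sequence of shrinkers until a complete sweep leaves ls unchanged.
    """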
    seen = set()
    while True:
        old_ls = ls
        for shrink in shrinkers_for(ls):
            while True:
                for s in shrink(ls):
                    key = tuple(s)
                    if key in seen:
                        continue
                    seen.add(key)
                    if constraint(s):
                        ls = s
                        break
                else:
                    break
        if ls == old_ls:
            return ls
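
The cell above depends on `shrinkers_for`, which is defined earlier in the notebook. As a self-contained sketch of the same fixed-point idea, the version below swaps in two deliberately simple passes. `delete_one`, `lower_one` and `TOY_PASSES` are assumptions made for this illustration and are much cruder than the notebook's shrinkers.

```python
def delete_one(ls):
    """Toy pass (an assumption): try removing each single element."""
    for i in range(len(ls)):
        yield ls[:i] + ls[i + 1:]


def lower_one(ls):
    """Toy pass (an assumption): try setting each element to 0, then x - 1."""
    for i, x in enumerate(ls):
        for smaller in (0, x - 1):
            if 0 <= smaller < x:
                yield ls[:i] + [smaller] + ls[i + 1:]


TOY_PASSES = [delete_one, lower_one]


def fixed_point_shrink(ls, constraint, passes=TOY_PASSES):
    """Run every pass to exhaustion, repeating until a full sweep changes nothing."""
    seen = set()
    while True:
        old_ls = ls
        for shrink in passes:
            while True:
                for s in shrink(ls):
                    key = tuple(s)
                    if key in seen:
                        continue  # never re-check a candidate
                    seen.add(key)
                    if constraint(s):
                        ls = s
                        break  # restart this pass from the smaller list
                else:
                    break  # this pass has nothing left to offer
        if ls == old_ls:
            return ls


print(fixed_point_shrink([101, 100], lambda xs: len(xs) >= 2 and xs[0] > xs[1]))  # [1, 0]
print(fixed_point_shrink([5] * 10, lambda xs: bool(xs) and len(xs) > max(xs)))  # [0]
```

The second sweep is what lets the deletion pass benefit from the values having been lowered to zero in the first sweep.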

In [40]:
show_trace([101, 100], lambda xs: len(xs) >= 2 and xs[0] > xs[1], multicourse_shrink3)

✓ [101, 100]
✗ [101]
✗ [100]
✗ [100, 100]
✗ [0, 100]
✗ [1, 100]
✗ [3, 100]
✗ [7, 100]
✗ [15, 100]
✗ [31, 100]
✗ [63, 100]
✗ [82, 100]
✗ [91, 100]
✗ [96, 100]
✗ [98, 100]
✗ [99, 100]
✓ [101, 0]
✗ [0]
✗ [0, 0]
✓ [1, 0]
✗ [1]

2 shrinks with 20 function calls


In [41]:
show_trace([5] * 10, lambda x: x and len(x) > max(x), multicourse_shrink3)

✓ [5, 5, 5, 5, 5, 5, 5, 5, 5, 5]
✗ [5]
✗ [5, 5]
✗ [5, 5, 5, 5]
✓ [5, 5, 5, 5, 5, 5, 5, 5]
✓ [5, 5, 5, 5, 5, 5, 5]
✓ [5, 5, 5, 5, 5, 5]
✗ [5, 5, 5, 5, 5]
✓ [0, 0, 0, 0, 0, 0]
✓ [0]
✗ []

5 shrinks with 10 function calls


So that worked. Yay!

Let's compare how this does to our single pass implementation.

In [42]:
compare_simplifiers(
    [
        ("Single pass", partial(greedy_shrink_with_dedupe, shrink=shrink6)),
        ("Multi pass", multicourse_shrink3),
    ]
)

Condition,Single pass,Multi pass
length >= 2,6,6
sum >= 500,35,35
sum >= 3,6,6
At least 10 by 5,107,73
10 distinct elements,623,131
First > Second,1481,1445
Size > max & 63,600,> 5000


So the answer is generally favourable but *ouch* that last one.

What's happening there is that later shrinks open up potentially very large improvements for the earlier shrinks. The original greedy algorithm can exploit that much better, because it goes back to the earliest shrinks after every success, while the multi pass algorithm spends a lot of time in the later stages making incremental shrinks.

Let's see another similar example before we try to fix this:

In [43]:
import hashlib

conditions["Messy"] = (
    lambda xs: hashlib.md5(repr(xs).encode("utf-8")).hexdigest()[0] == "0"
)

In [44]:
compare_simplifiers(
    [
        ("Single pass", partial(greedy_shrink_with_dedupe, shrink=shrink6)),
        ("Multi pass", multicourse_shrink3),
    ]
)

Condition,Single pass,Multi pass
length >= 2,6,6
sum >= 500,35,35
sum >= 3,6,6
At least 10 by 5,107,73
10 distinct elements,623,131
First > Second,1481,1445
Size > max & 63,600,> 5000
Messy,1032,> 5000


This one is a bit different. The problem is not that the structure is one we're ill-suited to exploiting; it's that there is no structure at all, so we have no hope of exploiting it. Literally any change at all can unlock earlier shrinks we could have done.
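
To get a feel for how unstructured this condition is, the sketch below counts how often it holds over a batch of single-element lists. If md5 behaves like a random function, the first hex digit should be `"0"` for roughly one input in sixteen, regardless of how "simple" the input looks. The sample set here is an arbitrary choice made for the illustration.

```python
import hashlib


def messy(xs):
    """Same predicate as the "Messy" condition above."""
    return hashlib.md5(repr(xs).encode("utf-8")).hexdigest()[0] == "0"


# Arbitrary sample: the lists [0], [1], ..., [3999].
samples = [[n] for n in range(4000)]
hits = sum(messy(xs) for xs in samples)
fraction = hits / len(samples)

print(f"{hits} of {len(samples)} lists satisfy the condition ({fraction:.3f})")
print(f"for comparison, 1/16 = {1 / 16:.3f}")
```

Since a small change to the list gives an unrelated hash, a failed shrink tells us essentially nothing about its neighbours.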

What we're going to try to do is hybridize the two approaches. If we notice we're performing an awful lot of shrinks we can take that as a hint that we should be trying again from earlier stages.

Here is our first approach. We simply restart the whole process after every two successful shrinks:

In [45]:
MAX_SHRINKS_PER_RUN = 2


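def multicourse_shrink4(ls, constraint):
    """
    Like multicourse_shrink3, but each run through the shrinkers stops after
    MAX_SHRINKS_PER_RUN successful shrinks and starts again from the first one.
    """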
    seen = set()
    while True:
        old_ls = ls
        shrinks_this_run = 0
        for shrink in shrinkers_for(ls):
            while shrinks_this_run < MAX_SHRINKS_PER_RUN:
                for s in shrink(ls):
                    key = tuple(s)
                    if key in seen:
                        continue
                    seen.add(key)
                    if constraint(s):
                        shrinks_this_run += 1
                        ls = s
                        break
                else:
                    break
        if ls == old_ls:
            return ls

In [46]:
compare_simplifiers(
    [
        ("Single pass", partial(greedy_shrink_with_dedupe, shrink=shrink6)),
        ("Multi pass", multicourse_shrink3),
        ("Multi pass with restart", multicourse_shrink4),
    ]
)

Condition,Single pass,Multi pass,Multi pass with restart
length >= 2,6,6,6
sum >= 500,35,35,35
sum >= 3,6,6,6
At least 10 by 5,107,73,90
10 distinct elements,623,131,396
First > Second,1481,1445,1463
Size > max & 63,600,> 5000,> 5000
Messy,1032,> 5000,1423


That works OK, but it's pretty unsatisfying as it loses us most of the benefits of the multi pass shrinking: we're now at most twice as good as the greedy one.

So what we're going to do is bet on the multi pass working and then gradually degrade to the greedy algorithm as it fails to work.

In [47]:
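def multicourse_shrink5(ls, constraint):
    """
    Multi pass shrinking with a per-run budget of successful shrinks. The
    budget starts at 10 and is reduced by 2 after every run.
    """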
    seen = set()
    max_shrinks_per_run = 10
    while True:
        shrinks_this_run = 0
        for shrink in shrinkers_for(ls):
            while shrinks_this_run < max_shrinks_per_run:
                for s in shrink(ls):
                    key = tuple(s)
                    if key in seen:
                        continue
                    seen.add(key)
                    if constraint(s):
                        shrinks_this_run += 1
                        ls = s
                        break
                else:
                    break
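        # Reduce the budget after each run, but never below 1. A budget of 1
        # behaves like the greedy algorithm; a budget of 0 would make the next
        # run do nothing and return before reaching a fixed point.
        max_shrinks_per_run = max(1, max_shrinks_per_run - 2)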
        if not shrinks_this_run:
            return ls
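
The decaying budget is easier to reason about in isolation. The sketch below is one hypothetical way to write the schedule as a generator. `budget_schedule` and its parameters are names invented for this illustration, and the floor of 1 is an assumption chosen so that the final stage matches the greedy algorithm, which restarts after every successful shrink.

```python
from itertools import islice


def budget_schedule(start=10, step=2, floor=1):
    """Yield the shrink budget for each run: start, start - step, ..., floor, floor, ..."""
    budget = start
    while True:
        yield budget
        budget = max(floor, budget - step)


# The first few budgets: generous at first, then settling at one shrink per run.
print(list(islice(budget_schedule(), 8)))  # [10, 8, 6, 4, 2, 1, 1, 1]
```

A schedule that starts high bets on the multi pass approach working, and one that settles at 1 falls back to restarting from the first shrinker after each success.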

In [48]:
show_trace([5] * 10, lambda x: x and len(x) > max(x), multicourse_shrink5)

✓ [5, 5, 5, 5, 5, 5, 5, 5, 5, 5]
✗ [5]
✗ [5, 5]
✗ [5, 5, 5, 5]
✓ [5, 5, 5, 5, 5, 5, 5, 5]
✓ [5, 5, 5, 5, 5, 5, 5]
✓ [5, 5, 5, 5, 5, 5]
✗ [5, 5, 5, 5, 5]
✓ [0, 0, 0, 0, 0, 0]
✓ [0]
✗ []

5 shrinks with 10 function calls


In [49]:
compare_simplifiers(
    [
        ("Single pass", partial(greedy_shrink_with_dedupe, shrink=shrink6)),
        ("Multi pass", multicourse_shrink3),
        ("Multi pass with restart", multicourse_shrink4),
        ("Multi pass with variable restart", multicourse_shrink5),
    ]
)

Condition,Single pass,Multi pass,Multi pass with restart,Multi pass with variable restart
length >= 2,6,6,6,6
sum >= 500,35,35,35,35
sum >= 3,6,6,6,6
At least 10 by 5,107,73,90,73
10 distinct elements,623,131,396,212
First > Second,1481,1445,1463,1168
Size > max & 63,600,> 5000,> 5000,1002
Messy,1032,> 5000,1423,824


This is now more or less the current state of the art. (It's actually a bit different from the Hypothesis state of the art at the time of this writing; I'm planning to merge some of the things I figured out in the course of writing this back in.) We've got something that is able to adaptively take advantage of structure where it is present, but degrades reasonably gracefully back to the more aggressive version that works better on unstructured examples.

Surprisingly, on some examples it seems to even be best of all of them. I think that's more coincidence than truth though.