File: test_itemsets_with_ids.py

package info (click to toggle)
python-efficient-apriori 2.0.5-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,736 kB
  • sloc: python: 889; sh: 10; makefile: 10
file content (119 lines) | stat: -rw-r--r-- 3,597 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Tests for algorithms related to association rules.
"""

import pytest
import itertools
import random

from efficient_apriori.itemsets import itemsets_from_transactions, ItemsetCount


def generate_transactions(num_transactions, unique_items, items_row=(1, 100), seed=None):
    """
    Generate synthetic transactions.

    Parameters
    ----------
    num_transactions : int
        How many transactions to yield.
    unique_items : int
        Items are drawn from ``range(unique_items)``.
    items_row : tuple of (int, int)
        Inclusive lower and upper bound for the number of items drawn per
        transaction. The drawn size is capped at ``unique_items``.
    seed : object, optional
        Seed passed to ``random.seed``. ``None`` (the default) seeds from
        system entropy / time; every other value, including ``0``, gives a
        reproducible sequence.

    Yields
    ------
    list
        A transaction: distinct items sampled without replacement.
    """
    # ``random.seed(None)`` is the same as ``random.seed()``, so the seed can
    # be forwarded as-is. Testing truthiness instead would discard falsy but
    # perfectly valid seeds such as 0.
    random.seed(seed)

    items = list(range(unique_items))

    for _ in range(num_transactions):
        items_this_row = random.randint(*items_row)
        # A sample can never be larger than the population it is drawn from
        yield random.sample(items, k=min(unique_items, items_this_row))


def itemsets_from_transactions_naive(transactions, min_support):
    """
    Naive algorithm used for testing only.

    Every combination of the unique items is enumerated, by increasing
    length, and its support is counted by scanning all transactions. The
    enumeration stops at the first length that yields no itemset with
    sufficient support.

    Parameters
    ----------
    transactions : iterable of iterables
        The transactions; each one is an iterable of hashable, mutually
        sortable items.
    min_support : float
        Minimum fraction of transactions that must contain an itemset for
        the itemset to be reported.

    Returns
    -------
    tuple
        ``(large_itemsets, num_transactions)``. ``large_itemsets`` maps an
        itemset length ``k`` to a dictionary, ordered by key, mapping each
        sorted item tuple to an ``ItemsetCount`` holding the number of
        supporting transactions (``itemset_count``) and their positional
        indices (``members``).
    """
    # Convert every transaction to a set exactly once; the subset test below
    # runs for every (combination, transaction) pair
    transaction_sets = [set(t) for t in transactions]
    num_transactions = len(transaction_sets)

    # Get the unique items from every transaction
    unique_items = set().union(*transaction_sets)

    # Create an output dictionary
    large_itemsets = {}

    # For every possible combination length
    for k in range(1, len(unique_items) + 1):
        itemsets_of_length_k = {}

        # For every possible combination
        for combination in itertools.combinations(unique_items, k):
            candidate = set(combination)

            # Naively count how many transactions contain the combination
            counts = ItemsetCount()
            for i, t in enumerate(transaction_sets):
                if candidate <= t:
                    counts.itemset_count += 1
                    counts.members.add(i)

            # If the support reaches the minimum support, add it.
            # num_transactions is non-zero here: the loop only runs when at
            # least one transaction contributed an item.
            if (counts.itemset_count / num_transactions) >= min_support:
                itemsets_of_length_k[tuple(sorted(combination))] = counts

        # No itemset of this length has enough support, so stop enumerating
        if not itemsets_of_length_k:
            return large_itemsets, num_transactions

        # Store the itemsets ordered by their (sorted) item tuple
        large_itemsets[k] = {key: itemsets_of_length_k[key] for key in sorted(itemsets_of_length_k)}

    return large_itemsets, num_transactions


def _random_test_case():
    """
    Build one random ``(transactions, min_support)`` pair for the tests.
    """
    # The draws happen in this exact order: number of transactions, number of
    # unique items, upper bound on items per row, then the minimum support
    num_transactions = random.randint(5, 25)
    unique_items = random.randint(1, 8)
    max_items_per_row = random.randint(2, 8)
    transactions = list(generate_transactions(num_transactions, unique_items, (1, max_items_per_row)))
    min_support = random.randint(1, 4) / 10
    return transactions, min_support


# 500 randomly generated cases shared by the parametrized tests below
input_data = [_random_test_case() for _ in range(500)]


@pytest.mark.parametrize("transactions, min_support", input_data)
def test_itemsets_from_transactions_stochastic(transactions, min_support):
    """
    Check the library against the naive reference implementation on random inputs.
    """
    actual, _ = itemsets_from_transactions(list(transactions), min_support, output_transaction_ids=True)
    expected, _ = itemsets_from_transactions_naive(list(transactions), min_support)

    # Look at every key present in either result, so that a key missing from
    # one side also makes the test fail
    keys_in_either = set(actual.keys()) | set(expected.keys())
    for key in keys_in_either:
        assert actual[key] == expected[key]


@pytest.mark.parametrize("transactions, min_support", input_data)
def test_itemsets_max_length(transactions, min_support):
    """
    Test that nothing larger than max length is returned.
    """
    max_len = random.randint(1, 5)
    result, _ = itemsets_from_transactions(
        list(transactions), min_support, max_length=max_len, output_transaction_ids=True
    )

    # The keys of the result must not exceed the requested maximum length
    for itemset_length in result.keys():
        assert itemset_length <= max_len

    # Every member recorded for an itemset must be an integer
    for itemsets in result.values():
        for itemset_count in itemsets.values():
            for member in itemset_count.members:
                assert isinstance(member, int)

if __name__ == "__main__":
    # When executed as a script, run pytest on the current directory with
    # doctest collection enabled and verbose output
    pytest_args = [".", "--doctest-modules", "-v"]
    pytest.main(args=pytest_args)