File: test_concurrency.py

package info (click to toggle)
simplebayes 3.2.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 496 kB
  • sloc: python: 3,322; makefile: 165; sh: 24
file content (49 lines) | stat: -rw-r--r-- 1,696 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import math
import threading
from concurrent.futures import ThreadPoolExecutor

from simplebayes import SimpleBayes


def test_parallel_train_and_score_completes():
    """Train and score one shared classifier from eight worker threads.

    Every task trains "tech" with a five-token sample and then scores a
    query. After all 50 tasks have run, the "tech" tally must account for
    every token trained, and classification must favour "tech".
    """
    rounds = 50
    tokens_per_sample = 5  # each sample below contributes five tokens
    expected_tally = rounds * tokens_per_sample

    classifier = SimpleBayes()

    def worker(task_id: int) -> None:
        sample = f"python fastapi service sample {task_id}"
        classifier.train("tech", sample)
        classifier.score("python service")

    with ThreadPoolExecutor(max_workers=8) as executor:
        # Drain the result iterator so an exception raised inside any worker
        # is re-raised here instead of being silently dropped.
        for _ in executor.map(worker, range(rounds)):
            pass

    outcome = classifier.classify_result("python service")
    assert outcome.category == "tech"
    assert outcome.score > 0

    assert classifier.get_summaries()["tech"].token_tally == expected_tally
    assert classifier.tally("tech") == expected_tally


def test_parallel_classify_during_mutation():
    """Classify and summarise while another thread trains and untrains.

    One writer repeatedly trains "alpha" with three tokens and then untrains
    one of them, a net gain of two tokens per iteration. Meanwhile two
    readers call ``classify`` and ``get_summaries``. Afterwards the readers
    must not have raised, and the tallies must reflect exactly the writes
    made.
    """
    iterations = 50
    readers = 2

    classifier = SimpleBayes()
    classifier.train("alpha", "one two three")
    classifier.train("beta", "four five six")

    # Every task blocks on this barrier before its loop. Without it, the
    # cheap writer loop can finish before the readers are even scheduled,
    # and the test would pass without any real read/write overlap. The
    # timeout turns a stuck barrier into BrokenBarrierError instead of a hang.
    start = threading.Barrier(1 + readers)

    def mutate() -> None:
        start.wait(timeout=10)
        for _ in range(iterations):
            classifier.train("alpha", "one two three")
            classifier.untrain("alpha", "one")

    def classify() -> None:
        start.wait(timeout=10)
        for _ in range(iterations):
            _ = classifier.classify("one five")
            _ = classifier.get_summaries()

    tasks = [mutate] + [classify] * readers
    # The pool needs at least one worker per barrier party, or the barrier
    # could never be satisfied.
    with ThreadPoolExecutor(max_workers=len(tasks)) as pool:
        futures = [pool.submit(task) for task in tasks]
        for future in futures:
            # Re-raise any exception (including BrokenBarrierError) from a task.
            future.result()

    # Seed of three tokens, plus a net (3 trained - 1 untrained) per iteration.
    expected_alpha = 3 + iterations * (3 - 1)
    expected_beta = 3

    assert classifier.tally("alpha") == expected_alpha
    assert classifier.tally("beta") == expected_beta
    summaries = classifier.get_summaries()
    assert summaries["alpha"].token_tally == expected_alpha
    assert summaries["beta"].token_tally == expected_beta
    alpha = summaries["alpha"]
    assert math.isclose(
        alpha.prob_in_cat + alpha.prob_not_in_cat,
        1.0,
        rel_tol=0.0,
        abs_tol=1e-12,
    )