1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
|
"""Test the behaviour of the ProcessPool class."""
import os
from typing import Iterable, Tuple
import pytest
from pytest_mock import MockerFixture
from cobra.util import ProcessPool
def dummy_initializer(*args: object) -> Tuple[object, ...]:
    """Accept arbitrary initialization arguments and return them unchanged.

    Serves as a no-op worker initializer. Each positional argument may be of
    any type (the tests pass ints, strings, lists and dicts), so ``*args`` is
    annotated as ``object`` rather than ``Iterable``.
    """
    return args
def square(num: int) -> int:
    """Multiply ``num`` by itself and hand back the product."""
    product = num * num
    return product
def summation(*args: int) -> int:
    """Return the sum of all integer arguments.

    Each positional argument is a single integer (``starmap`` unpacks every
    iterable into separate arguments), so ``*args`` is annotated as ``int``.
    With no arguments, the result is ``0``.
    """
    return sum(args)
@pytest.mark.skipif("SKIP_MP" in os.environ, reason="unsafe for parallel execution")
@pytest.mark.parametrize(
    "attributes",
    [
        {},
        {"processes": 2},
        {"initializer": dummy_initializer},
        {"initializer": dummy_initializer, "initargs": (1, "2", [3], {"a": 4})},
        {"maxtasksperchild": 1},
    ],
)
def test_init(attributes: dict) -> None:
    """Test that a process pool can be initialized with each of its arguments.

    The pool is entered as a context manager so that it is shut down when the
    test finishes. Without this, the worker processes started by the
    constructor would be leaked.
    """
    with ProcessPool(**attributes):
        pass
@pytest.mark.skipif("SKIP_MP" in os.environ, reason="unsafe for parallel execution")
def test_close(mocker: MockerFixture) -> None:
    """Test that the composed pool is closed as well.

    The composed pool is replaced by a mock for the assertion. The real pool
    is kept in a local variable so it can be shut down afterwards; otherwise
    its worker processes would outlive the test.
    """
    pool = ProcessPool(processes=3)
    real_pool = pool._pool
    mock = mocker.patch.object(pool, "_pool", autospec=True)
    try:
        pool.close()
        mock.close.assert_called_once()
    finally:
        # The mock received the close() call, so stop the real workers here.
        real_pool.terminate()
        real_pool.join()
@pytest.mark.skipif("SKIP_MP" in os.environ, reason="unsafe for parallel execution")
def test_with_context(mocker: MockerFixture) -> None:
    """Test that the composed pool's context is managed as well.

    The composed pool is replaced by a mock for the assertions. The real pool
    is kept in a local variable so it can be shut down afterwards; otherwise
    its worker processes would outlive the test.
    """
    pool = ProcessPool(processes=3)
    real_pool = pool._pool
    mock = mocker.patch.object(pool, "_pool", autospec=True)
    try:
        with pool:
            pass
        mock.__enter__.assert_called_once()
        mock.__exit__.assert_called_once()
    finally:
        # The mock received __enter__/__exit__, so stop the real workers here.
        real_pool.terminate()
        real_pool.join()
@pytest.mark.skipif("SKIP_MP" in os.environ, reason="unsafe for parallel execution")
def test_apply() -> None:
    """Check that a single call dispatched to the pool returns its result."""
    with ProcessPool(processes=3) as pool:
        squared = pool.apply(square, (3,))
        assert squared == 9
@pytest.mark.skipif("SKIP_MP" in os.environ, reason="unsafe for parallel execution")
def test_apply_async() -> None:
    """Check that an asynchronously dispatched call yields its result on ``get``."""
    with ProcessPool(processes=3) as pool:
        pending = pool.apply_async(square, (3,))
        assert pending.get() == 9
@pytest.mark.skipif("SKIP_MP" in os.environ, reason="unsafe for parallel execution")
def test_map() -> None:
    """Test that a function can be mapped over an iterable of values.

    The inputs are distinct, so the assertion also checks that there is one
    result per input and that results come back in input order.
    """
    with ProcessPool(processes=3) as pool:
        assert pool.map(square, range(6)) == [0, 1, 4, 9, 16, 25]
@pytest.mark.skipif("SKIP_MP" in os.environ, reason="unsafe for parallel execution")
def test_map_async() -> None:
    """Test that a function can be mapped over an iterable of values asynchronously.

    The inputs are distinct, so the assertion also checks that there is one
    result per input and that results come back in input order.
    """
    with ProcessPool(processes=3) as pool:
        assert pool.map_async(square, range(6)).get() == [0, 1, 4, 9, 16, 25]
@pytest.mark.skipif("SKIP_MP" in os.environ, reason="unsafe for parallel execution")
def test_imap() -> None:
    """Test that mapped function results can be iterated in input order.

    Unlike ``imap_unordered``, ``imap`` yields results in the order of the
    inputs. Distinct inputs make that ordering observable.
    """
    with ProcessPool(processes=3) as pool:
        assert list(pool.imap(square, range(6))) == [0, 1, 4, 9, 16, 25]
@pytest.mark.skipif("SKIP_MP" in os.environ, reason="unsafe for parallel execution")
def test_imap_unordered() -> None:
    """Test that mapped function results can be iterated in any order.

    Results are sorted before comparison because their order is arbitrary.
    Distinct inputs still ensure that every input is mapped exactly once.
    """
    with ProcessPool(processes=3) as pool:
        assert sorted(pool.imap_unordered(square, range(6))) == [0, 1, 4, 9, 16, 25]
@pytest.mark.skipif("SKIP_MP" in os.environ, reason="unsafe for parallel execution")
def test_starmap() -> None:
    """Test that a function can be starmapped over many iterables.

    Each iterable is unpacked into the arguments of one call. The argument
    sets are distinct, so the assertion checks per-call results and order.
    """
    with ProcessPool(processes=3) as pool:
        assert pool.starmap(summation, [range(3), range(5), range(10)]) == [3, 10, 45]
@pytest.mark.skipif("SKIP_MP" in os.environ, reason="unsafe for parallel execution")
def test_starmap_async() -> None:
    """Test that a function can be starmapped over many iterables asynchronously.

    Each iterable is unpacked into the arguments of one call. The argument
    sets are distinct, so the assertion checks per-call results and order.
    """
    with ProcessPool(processes=3) as pool:
        result = pool.starmap_async(summation, [range(3), range(5), range(10)])
        assert result.get() == [3, 10, 45]
|