File: test_base.py

package info (click to toggle)
mdanalysis 2.10.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,696 kB
  • sloc: python: 92,135; ansic: 8,156; makefile: 215; sh: 138
file content (109 lines) | stat: -rw-r--r-- 3,543 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4 fileencoding=utf-8
#
# MDAnalysis --- https://www.mdanalysis.org
# Copyright (c) 2006-2017 The MDAnalysis Development Team and contributors
# (see the file AUTHORS for the full list of names)
#
# Released under the Lesser GNU Public Licence, v2.1 or any higher version
#
# Please cite your use of MDAnalysis in published work:
#
# R. J. Gowers, M. Linke, J. Barnoud, T. J. E. Reddy, M. N. Melo, S. L. Seyler,
# D. L. Dotson, J. Domanski, S. Buchoux, I. M. Kenney, and O. Beckstein.
# MDAnalysis: A Python package for the rapid analysis of molecular dynamics
# simulations. In S. Benthall and S. Rostrup editors, Proceedings of the 15th
# Python in Science Conference, pages 102-109, Austin, TX, 2016. SciPy.
#
# N. Michaud-Agrawal, E. J. Denning, T. B. Woolf, and O. Beckstein.
# MDAnalysis: A Toolkit for the Analysis of Molecular Dynamics Simulations.
# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
#
import numpy as np
import pytest
from numpy.testing import assert_equal
from threadpoolctl import threadpool_info

import MDAnalysis as mda
from MDAnalysisTests.datafiles import PSF, DCD
from MDAnalysis.transformations.base import TransformationBase


class DefaultTransformation(TransformationBase):
    """Transformation that relies entirely on the base-class defaults.

    No ``max_threads`` / ``parallelizable`` arguments are forwarded to
    :class:`TransformationBase`, so whatever defaults it defines apply.
    ``_transform`` records the thread-pool state and shifts all
    coordinates by one.
    """

    def __init__(self):
        super().__init__()

    def _transform(self, ts):
        # Snapshot the thread pools visible while the transform is running
        # so tests can inspect the limits that were in effect.
        self.runtime_info = threadpool_info()
        shifted = ts.positions + 1
        ts.positions = shifted
        return ts


class NoTransform_Transformation(TransformationBase):
    """Transformation that deliberately does not override ``_transform``.

    Used to check that invoking the base-class ``_transform`` raises
    ``NotImplementedError``; threading options are left at the base-class
    defaults.
    """

    def __init__(self):
        super().__init__()


class CustomTransformation(TransformationBase):
    """Transformation that forwards explicit threading options to the base.

    Parameters
    ----------
    max_threads : int, optional
        Forwarded to :class:`TransformationBase` (default ``1``).
    parallelizable : bool, optional
        Forwarded to :class:`TransformationBase` (default ``False``).
    """

    def __init__(self, max_threads=1, parallelizable=False):
        options = dict(max_threads=max_threads, parallelizable=parallelizable)
        super().__init__(**options)

    def _transform(self, ts):
        # Capture the thread-pool configuration in effect during the call;
        # test_thread_limit_apply inspects it afterwards.
        self.runtime_info = threadpool_info()
        shifted = ts.positions + 1
        ts.positions = shifted
        return ts


@pytest.fixture(scope="module")
def u():
    """Universe built from the PSF/DCD test files, shared by this module."""
    universe = mda.Universe(PSF, DCD)
    return universe


def test_default_value():
    """With no arguments, max_threads is unset and parallelizable is True."""
    trans = DefaultTransformation()
    assert trans.max_threads is None
    assert trans.parallelizable is True


def test_no_transform_function(u):
    """Calling the unimplemented base-class ``_transform`` must raise."""
    trans = NoTransform_Transformation()
    timestep = u.trajectory.ts
    with pytest.raises(NotImplementedError, match=r"Only implemented"):
        trans._transform(timestep)


def test_custom_value():
    """CustomTransformation's own defaults end up on the instance."""
    trans = CustomTransformation()
    assert trans.max_threads == 1
    assert trans.parallelizable is False


def test_setting_thread_limit_value():
    """An explicit max_threads argument is stored on the instance."""
    requested = 4
    trans = CustomTransformation(max_threads=requested)
    assert trans.max_threads == requested


def test_thread_limit_apply(u):
    """Check that ``max_threads`` caps thread pools only during the call.

    The transformation records :func:`threadpool_info` from inside
    ``_transform``. Every pool it saw must be limited to the requested two
    threads, and the process-wide thread counts must be unchanged once the
    call has returned.
    """
    default_thread_info = threadpool_info()
    if not default_thread_info:
        # With no threadpoolctl-visible library the per-pool loop below
        # would assert nothing and the test would pass vacuously.
        pytest.skip("threadpoolctl detected no thread-pool libraries")
    default_num_thread_limit_list = [
        thread_info["num_threads"] for thread_info in default_thread_info
    ]

    new_trans = CustomTransformation(max_threads=2)
    # Transform a copy: _transform shifts positions in place, and the
    # module-scoped universe is shared with the other tests in this file.
    _ = new_trans(u.trajectory.ts.copy())
    for thread_info in new_trans.runtime_info:
        assert thread_info["num_threads"] == 2

    # The thread limit must only be applied locally (during the call).
    new_thread_info = threadpool_info()
    new_num_thread_limit_list = [
        thread_info["num_threads"] for thread_info in new_thread_info
    ]
    assert_equal(default_num_thread_limit_list, new_num_thread_limit_list)