File: TestDeviceAndThreads.py

package info (click to toggle)
kokkos 4.7.01-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 16,636 kB
  • sloc: cpp: 223,676; sh: 2,446; makefile: 2,437; python: 91; fortran: 4; ansic: 2
file content (121 lines) | stat: -rw-r--r-- 4,046 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#@HEADER
# ************************************************************************

#                        Kokkos v. 4.0
#       Copyright (2022) National Technology & Engineering
#               Solutions of Sandia, LLC (NTESS).

# Under the terms of Contract DE-NA0003525 with NTESS,
# the U.S. Government retains certain rights in this software.

# Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.

# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# ************************************************************************
# @HEADER

import unittest
import subprocess
import platform
import os

# These are CMake generator expressions ($<TARGET_FILE_DIR:...> and
# $<TARGET_FILE_NAME:...>); presumably this script is run through CMake's
# file(GENERATE) so they expand to the real directory and file name of the
# test executable -- TODO confirm against the CMake build rules.
PREFIX = "$<TARGET_FILE_DIR:Kokkos_CoreUnitTest_DeviceAndThreads>"
EXECUTABLE = "$<TARGET_FILE_NAME:Kokkos_CoreUnitTest_DeviceAndThreads>"
# Full path of the executable that GetFlag launches.
COMMAND = f"{PREFIX}/{EXECUTABLE}"


def GetFlag(flag, *extra_args):
    """Run the DeviceAndThreads test executable and return its integer output.

    The executable at COMMAND is invoked as ``COMMAND flag *extra_args``
    (no shell), and its standard output is parsed with ``int()``.

    Args:
        flag: Name of the value to query, e.g. "num_threads" or "device_id".
        *extra_args: Extra command-line arguments forwarded verbatim,
            typically ``--kokkos-*`` options.

    Returns:
        The integer printed by the executable on stdout.

    Raises:
        RuntimeError: If the executable exits with a non-zero status. The
            message names the command and exit code and includes stderr, so a
            crash with empty stderr is still diagnosable. RuntimeError derives
            from Exception, so ``except Exception`` handlers keep working.
        ValueError: If stdout is not a single integer.
    """
    proc = subprocess.run([COMMAND, flag, *extra_args], capture_output=True)
    if proc.returncode != 0:
        # errors="replace" keeps a non-UTF-8 stderr from raising
        # UnicodeDecodeError and masking the real failure.
        stderr = proc.stderr.decode("utf-8", errors="replace").strip()
        raise RuntimeError(
            "{} {} failed with exit code {}: {}".format(
                COMMAND,
                " ".join([flag, *extra_args]),
                proc.returncode,
                stderr or "<no stderr output>"))
    return int(proc.stdout)

def _GetCoreCount():
    """Return the host core count reported by a platform-specific command.

    macOS uses ``sysctl -n hw.physicalcpu_max``, Linux uses ``nproc --all``
    and any other system uses ``wmic cpu get NumberOfCores``. Note that
    ``nproc --all`` counts processing units, so on Linux the value may include
    SMT siblings rather than being a strict physical-core count.

    Raises:
        RuntimeError: If the command exits non-zero or prints no integer.
    """
    name = platform.system()
    if name == 'Darwin':
        args = ['sysctl', '-n', 'hw.physicalcpu_max']
    elif name == 'Linux':
        args = ['nproc', '--all']
    else:
        args = ['wmic', 'cpu', 'get', 'NumberOfCores']

    result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        raise RuntimeError(
            "{} failed with exit code {}: {}".format(
                " ".join(args),
                result.returncode,
                result.stderr.decode('utf-8', errors='replace').strip()))
    output = result.stdout.decode('utf-8', errors='replace')
    # wmic prints a "NumberOfCores" header line before the value(s), with one
    # value per processor on multi-socket hosts, so int(output) on the whole
    # text fails there. Summing the integer tokens handles that and the
    # single-number output of sysctl/nproc alike.
    counts = [int(token) for token in output.split() if token.isdecimal()]
    if not counts:
        raise RuntimeError(
            "could not parse a core count from {!r} output: {!r}".format(
                " ".join(args), output))
    return sum(counts)


def GetNumThreads(max_threads):
    """Yield the thread counts that the num_threads test should exercise.

    Without hwloc the candidates are 1..5. With hwloc they are 1 followed by
    1x..7x the host core count. Candidates are yielded in order until the
    first one that reaches ``max_threads``, and ``max_threads`` itself is
    always yielded last.

    Args:
        max_threads: Upper bound, as reported by the executable's
            "max_threads" flag.

    Yields:
        int thread counts below ``max_threads``, then ``max_threads``.
    """
    if GetFlag("hwloc_enabled"):
        # The core count is only needed here, so the platform command is not
        # run at all (and need not be available) when hwloc is disabled.
        cores = _GetCoreCount()
        looplist = [1] + [i * cores for i in range(1, 8)]
    else:
        looplist = [1, 2, 3, 4, 5]

    for x in looplist:
        if x >= max_threads:
            break
        yield x
    yield max_threads

class KokkosInitializationTestCase(unittest.TestCase):
    """Check that Kokkos initialization settings round-trip through the
    DeviceAndThreads executable (queried via GetFlag)."""

    def _assert_boolean_option(self, flag, option):
        """Check a 0/1 setting: its default is 0, and --kokkos-<option>=0 and
        --kokkos-<option>=1 are reported back unchanged by ``flag``."""
        self.assertEqual(0, GetFlag(flag))
        for value in (0, 1):
            with self.subTest(option=option, value=value):
                self.assertEqual(
                    value,
                    GetFlag(flag, "--kokkos-{}={}".format(option, value)))

    def test_num_threads(self):
        """--kokkos-num-threads=N is reported back as num_threads == N."""
        max_threads = GetFlag("max_threads")
        if max_threads == 1:
            self.skipTest("no host parallel backend enabled")
        for num_threads in GetNumThreads(max_threads):
            # subTest reports every mismatching thread count instead of
            # stopping at the first one.
            with self.subTest(num_threads=num_threads):
                self.assertEqual(
                    num_threads,
                    GetFlag(
                        "num_threads",
                        "--kokkos-num-threads={}".format(num_threads)))

    def test_num_devices(self):
        """num_devices is -1 without a device backend, otherwise at least 1."""
        if "KOKKOS_VISIBLE_DEVICES" in os.environ:
            self.skipTest("KOKKOS_VISIBLE_DEVICES environment variable is set")
        num_devices = GetFlag("num_devices")
        self.assertNotEqual(num_devices, 0)
        if num_devices == -1:
            self.skipTest("no device backend enabled")
        self.assertGreaterEqual(num_devices, 1)

    def test_device_id(self):
        """The default device id is in range and --kokkos-device-id is honored."""
        if "KOKKOS_VISIBLE_DEVICES" in os.environ:
            self.skipTest("KOKKOS_VISIBLE_DEVICES environment variable is set")
        num_devices = GetFlag("num_devices")
        if num_devices == -1:
            self.assertEqual(-1, GetFlag("device_id"))
            self.skipTest("no device backend enabled")
        # Query the default once; every GetFlag call launches a process.
        default_device_id = GetFlag("device_id")
        # Valid ids are 0..num_devices-1 (the same range the loop below
        # exercises), so the upper bound must be strict.
        self.assertGreaterEqual(default_device_id, 0)
        self.assertLess(default_device_id, num_devices)
        for device_id in range(num_devices):
            with self.subTest(device_id=device_id):
                self.assertEqual(
                    device_id,
                    GetFlag(
                        "device_id",
                        "--kokkos-device-id={}".format(device_id)))

    def test_disable_warnings(self):
        """--kokkos-disable-warnings defaults to 0 and accepts 0/1."""
        self._assert_boolean_option("disable_warnings", "disable-warnings")

    def test_tune_internals(self):
        """--kokkos-tune-internals defaults to 0 and accepts 0/1."""
        self._assert_boolean_option("tune_internals", "tune-internals")


if __name__ == '__main__':
    # Hand control to unittest's command-line runner, which collects the
    # TestCase classes of this module (module='__main__' is its default,
    # spelled out here for clarity).
    unittest.main(module='__main__')