1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# SPDX-License-Identifier: LGPL-2.1-or-later
from multiprocessing import Barrier, Process
import os
from threading import Condition, Thread
from drgn.helpers.linux.pid import (
find_pid,
find_task,
for_each_pid,
for_each_task,
for_each_task_in_group,
)
from tests.linux_kernel import LinuxKernelTestCase
class TestPid(LinuxKernelTestCase):
    """Tests for the helpers in drgn.helpers.linux.pid."""

    def test_find_pid(self):
        """find_pid() returns an object whose numbers[0].nr is our PID."""
        pid = os.getpid()
        self.assertEqual(find_pid(self.prog, pid).numbers[0].nr, pid)

    def test_for_each_pid(self):
        """for_each_pid() yields an entry whose numbers[0].nr is our PID."""
        pid = os.getpid()
        self.assertTrue(
            any(
                pid_struct.numbers[0].nr == pid
                for pid_struct in for_each_pid(self.prog)
            )
        )

    def test_find_task(self):
        """find_task() returns a task whose pid and comm match this process."""
        pid = os.getpid()
        with open("/proc/self/comm", "rb") as f:
            # Drop the trailing newline that /proc/self/comm appends.
            comm = f.read()[:-1]
        task = find_task(self.prog, pid)
        self.assertEqual(task.pid, pid)
        self.assertEqual(task.comm.string_(), comm)

    def test_for_each_task(self):
        """for_each_task() includes this process and its live children."""
        NUM_PROCS = 12
        # Each child blocks on the barrier until the parent (the last party)
        # also waits, so every child is still alive while the task list is
        # walked.
        barrier = Barrier(NUM_PROCS + 1)
        procs = [Process(target=barrier.wait) for _ in range(NUM_PROCS)]
        # Only processes that were actually started can be terminated or
        # joined; calling terminate() on an unstarted Process raises and would
        # mask the original error.
        started = []
        try:
            for proc in procs:
                proc.start()
                started.append(proc)
            pids = {task.pid.value_() for task in for_each_task(self.prog)}
            for proc in procs:
                self.assertIn(proc.pid, pids)
            self.assertIn(os.getpid(), pids)
            barrier.wait()
        except BaseException:
            # Break the barrier so waiting children return, and kill them so a
            # failed test doesn't leave them blocked forever.
            barrier.abort()
            for proc in started:
                proc.terminate()
            raise
        finally:
            # Reap the children so they don't linger as zombies.
            for proc in started:
                proc.join()

    def test_for_each_task_in_group(self):
        """for_each_task_in_group(include_self=False) yields our other threads
        but not the calling task."""
        NUM_THREADS = 12
        condition = Condition()
        # Guarded by condition. Waiting on a predicate rather than a bare
        # wait() avoids a lost wakeup: a thread that only reaches the wait
        # after notify_all() has run still sees done == True and exits.
        done = False
        this_task = find_task(self.prog, os.getpid())

        def thread_func():
            with condition:
                condition.wait_for(lambda: done)

        # Only threads that were actually started can be joined.
        threads = []
        try:
            for _ in range(NUM_THREADS):
                thread = Thread(target=thread_func)
                thread.start()
                threads.append(thread)
            actual = {
                t.pid.value_()
                for t in for_each_task_in_group(this_task, include_self=False)
            }
            for thread in threads:
                self.assertIn(thread.native_id, actual)
            self.assertNotIn(os.getpid(), actual)
        finally:
            with condition:
                done = True
                condition.notify_all()
            for thread in threads:
                thread.join()
|