File: test_dmah.py

package info (click to toggle)
rdma-core 61.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 13,124 kB
  • sloc: ansic: 176,798; python: 15,496; sh: 2,742; perl: 1,465; makefile: 73
file content (137 lines) | stat: -rw-r--r-- 5,940 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# SPDX-License-Identifier: (GPL-2.0 OR Linux-OpenIB)
# Copyright (c) 2025 NVIDIA Corporation . All rights reserved. See COPYING file

from pyverbs.libibverbs_enums import ibv_access_flags, ibv_wr_opcode, ibv_odp_transport_cap_bits, \
    ibv_tph_mem_type
from tests.base import PyverbsAPITestCase, RCResources, RDMATestCase
from pyverbs.pyverbs_error import PyverbsError, PyverbsRDMAError
import pyverbs.device as d
from pyverbs.pd import PD
from pyverbs.mr import MREx, DMAHandle
from pyverbs.qp import QPAttr
import tests.utils as u


class DMAHandleTest(PyverbsAPITestCase):
    """API-level tests for DMAHandle creation and its use with MREx."""

    @staticmethod
    def _require_dmah_support(ctx):
        """
        Create and immediately release a DMAHandle with default attributes.
        The negative tests below expect a PyverbsError from DMAHandle creation.
        PyverbsRDMAError (e.g. EOPNOTSUPP from a device without DMAHandle
        support) is itself a PyverbsError. Without this probe, an unsupported
        device would satisfy those assertRaises checks and the tests would pass
        instead of failing here, where u.skip_unsupported can skip them.
        :param ctx: Device Context to allocate the probe DMAHandle on
        :return: None
        """
        with DMAHandle(ctx, u.create_dmah_init_attr()):
            pass

    @u.skip_unsupported
    def test_dmah_with_mrex(self):
        """Verify DMAHandle can be used during MREx registration."""
        with d.Context(name=self.dev_name) as ctx:
            with PD(ctx) as pd:
                attr = u.create_dmah_init_attr()
                with DMAHandle(ctx, attr) as dmah:
                    length = u.get_mr_length()
                    access = ibv_access_flags.IBV_ACCESS_LOCAL_WRITE
                    # Registration succeeding (and the MR being released
                    # cleanly) is the whole check.
                    with MREx(pd, length, access, dmah=dmah):
                        pass

    @u.skip_unsupported
    def test_dmah_invalid_ph(self):
        """Verify DMAHandle creation fails with an invalid PH value (max PH value is 3)."""
        with d.Context(name=self.dev_name) as ctx:
            self._require_dmah_support(ctx)
            attr = u.create_dmah_init_attr(ph=4)
            with self.assertRaises(PyverbsError):
                DMAHandle(ctx, attr)

    @u.skip_unsupported
    def test_dmah_persistent_memory(self):
        """Attempt to create a DMAHandle targeting persistent memory; expect success."""
        with d.Context(name=self.dev_name) as ctx:
            attr = u.create_dmah_init_attr(tph_mem_type=ibv_tph_mem_type.IBV_TPH_MEM_TYPE_PM)
            with DMAHandle(ctx, attr):
                pass

    @u.skip_unsupported
    def test_dmah_invalid_mem_type(self):
        """Pass an unsupported TPH memory-type and verify creation is rejected."""
        with d.Context(name=self.dev_name) as ctx:
            self._require_dmah_support(ctx)
            attr = u.create_dmah_init_attr(tph_mem_type=0xFE)
            with self.assertRaises(PyverbsError):
                DMAHandle(ctx, attr)

    @u.skip_unsupported
    def test_dmah_inval_cpu_id(self):
        """Attempt to create DMAHandle with invalid CPU ID (0xffff). Expect failure."""
        with d.Context(name=self.dev_name) as ctx:
            self._require_dmah_support(ctx)
            attr = u.create_dmah_init_attr(cpu_id=0xffff, ph=3)
            with self.assertRaises(PyverbsError):
                DMAHandle(ctx, attr)

    @u.skip_unsupported
    def test_dmah_mrex_odp_bad_flow(self):
        """Attempt to register ODP-capable MREx with DMAHandle.
        Expect failure since ODP isn't supported with DMAHandle."""
        with d.Context(name=self.dev_name) as ctx:
            # Check ODP support; skip if not available
            odp_cap = (ibv_odp_transport_cap_bits.IBV_ODP_SUPPORT_SEND |
                       ibv_odp_transport_cap_bits.IBV_ODP_SUPPORT_RECV)
            u.odp_supported(ctx, 'rc', odp_cap)
            with PD(ctx) as pd:
                attr = u.create_dmah_init_attr()
                length = u.get_mr_length()
                access = (ibv_access_flags.IBV_ACCESS_LOCAL_WRITE |
                          ibv_access_flags.IBV_ACCESS_ON_DEMAND)
                # Created outside assertRaises so an unsupported DMAHandle
                # surfaces to u.skip_unsupported; the context manager
                # guarantees the handle is released.
                with DMAHandle(ctx, attr) as dmah:
                    # Only the registration itself is expected to fail.
                    with self.assertRaises(PyverbsError):
                        MREx(pd, length, access, dmah=dmah)


class DmaHandleMRExRC(RCResources):
    """
    RC traffic resources whose data MR is an MREx bound to a DMAHandle.

    The DMAHandle is allocated from this resource's own device context right
    before the MREx is registered, and is stored in ``self.dmah``.
    """

    def __init__(self, dev_name, ib_port, gid_index,
                 mr_access=ibv_access_flags.IBV_ACCESS_LOCAL_WRITE, msg_size=1024):
        # Both attributes are set before the base constructor runs; presumably
        # it builds the MR (via create_mr) during init — confirm in RCResources.
        self.dmah = None
        self.mr_access = mr_access
        super().__init__(dev_name=dev_name, ib_port=ib_port,
                         gid_index=gid_index, msg_size=msg_size)

    def create_dmah(self):
        """Allocate a default-attribute DMAHandle on self.ctx and keep it in self.dmah."""
        init_attr = u.create_dmah_init_attr()
        self.dmah = DMAHandle(self.ctx, init_attr)

    @u.skip_unsupported
    def create_mr(self):
        """Allocate the DMAHandle, then register an MREx of msg_size bytes bound to it."""
        self.create_dmah()
        pd, size, flags = self.pd, self.msg_size, self.mr_access
        self.mr = MREx(pd, size, flags, dmah=self.dmah)

    def create_qp_attr(self):
        """Return QP attributes granting local write, remote write and remote atomic access."""
        attr = QPAttr(port_num=self.ib_port)
        attr.qp_access_flags = (ibv_access_flags.IBV_ACCESS_LOCAL_WRITE
                                | ibv_access_flags.IBV_ACCESS_REMOTE_WRITE
                                | ibv_access_flags.IBV_ACCESS_REMOTE_ATOMIC)
        return attr


class DmaHandleTrafficTest(RDMATestCase):
    """RC traffic tests over MREx buffers registered with a DMAHandle."""

    def setUp(self):
        super().setUp()
        self.iters = 10
        # Populated by create_players() in each test.
        self.server = self.client = self.traffic_args = None

    def _create_dmah_players(self, access, msg_size=1024):
        """Create server/client DmaHandleMRExRC players with the given MR access and size."""
        self.create_players(DmaHandleMRExRC, mr_access=access, msg_size=msg_size)

    def test_dmah_mrex_rc_send(self):
        """Run RC send/recv traffic over a DMAHandle-registered MREx."""
        self._create_dmah_players(ibv_access_flags.IBV_ACCESS_LOCAL_WRITE)
        u.traffic(**self.traffic_args)

    def test_dmah_mrex_rc_rdma_write(self):
        """Run RC RDMA Write traffic over a DMAHandle-registered MREx."""
        flags = (ibv_access_flags.IBV_ACCESS_LOCAL_WRITE
                 | ibv_access_flags.IBV_ACCESS_REMOTE_WRITE)
        self._create_dmah_players(flags)
        u.rdma_traffic(**self.traffic_args, send_op=ibv_wr_opcode.IBV_WR_RDMA_WRITE)

    def test_dmah_mrex_rc_atomic(self):
        """Run RC atomic fetch&add traffic over a DMAHandle-backed 8-byte MREx."""
        flags = (ibv_access_flags.IBV_ACCESS_LOCAL_WRITE
                 | ibv_access_flags.IBV_ACCESS_REMOTE_ATOMIC
                 | ibv_access_flags.IBV_ACCESS_REMOTE_WRITE)
        self._create_dmah_players(flags, msg_size=8)
        u.atomic_traffic(**self.traffic_args,
                         send_op=ibv_wr_opcode.IBV_WR_ATOMIC_FETCH_AND_ADD)