File: test_cuda_hook.c

package info (click to toggle)
mpich 4.0.2-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 423,384 kB
  • sloc: ansic: 1,088,434; cpp: 71,364; javascript: 40,763; f90: 22,829; sh: 17,463; perl: 14,773; xml: 14,418; python: 10,265; makefile: 9,246; fortran: 8,008; java: 4,355; asm: 324; ruby: 176; lisp: 19; php: 8; sed: 4
file content (130 lines) | stat: -rw-r--r-- 3,249 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/**
 * Copyright (C) Mellanox Technologies Ltd. 2020.  ALL RIGHTS RESERVED.
 *
 * See file LICENSE for terms.
 */

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <ucp/api/ucp.h>
#include <ucm/api/ucm.h>
#include <cuda_runtime.h>
#include <getopt.h>
#include <cuda.h>


/**
 * UCM memory-event callback used by this test.
 *
 * Prints one line describing an allocate/free memory-type event and bumps the
 * integer counter that @a arg points to. Any other event type is reported and
 * does not touch the counter.
 *
 * @param event_type  Type of the event being delivered.
 * @param event       Event payload; only the mem_type member is read.
 * @param arg         Pointer to the int event counter supplied at registration.
 */
static void event_cb(ucm_event_type_t event_type, ucm_event_t *event, void *arg)
{
    int *counter = arg;
    const char *action;

    switch (event_type) {
    case UCM_EVENT_MEM_TYPE_ALLOC:
        action = "allocate";
        break;
    case UCM_EVENT_MEM_TYPE_FREE:
        action = "free";
        break;
    default:
        /* Not one of the two event types this test counts */
        printf("unexpected memory event type %d\n", event_type);
        return;
    }

    printf("%s %s address %p size %zu\n", action,
           ucs_memory_type_names[event->mem_type.mem_type],
           event->mem_type.address, event->mem_type.size);
    *counter += 1;
}

/**
 * Perform one 4096-byte allocate/free cycle through the CUDA driver API.
 *
 * Initializes the driver, creates a context on device 0, calls cuMemAlloc()
 * followed by cuMemFree(), and finally destroys the context. Every failure is
 * reported on stdout; a failure during setup stops the sequence early.
 */
static void alloc_driver_api()
{
    CUdeviceptr dptr = 0;
    CUcontext context;
    CUdevice device;
    CUresult res;

    res = cuInit(0);
    if (res != CUDA_SUCCESS) {
        printf("cuInit() failed: %d\n", res);
        return;
    }

    res = cuDeviceGet(&device, 0);
    if (res != CUDA_SUCCESS) {
        printf("cuDeviceGet(0) failed: %d\n", res);
        return;
    }

    res = cuCtxCreate(&context, 0, device);
    if (res != CUDA_SUCCESS) {
        printf("cuCtxCreate() failed: %d\n", res);
        return;
    }

    res = cuMemAlloc(&dptr, 4096);
    /* Widen the device address explicitly so that the argument type matches
     * the conversion specifier on every ABI */
    printf("cuMemAlloc() returned 0x%llx result %d\n",
           (unsigned long long)dptr, res);
    if (res == CUDA_SUCCESS) {
        /* Only release what was actually allocated */
        res = cuMemFree(dptr);
        if (res != CUDA_SUCCESS) {
            printf("cuMemFree() failed: %d\n", res);
        }
    }

    /* cuCtxDetach() is deprecated; a context obtained from cuCtxCreate() is
     * released with cuCtxDestroy() */
    res = cuCtxDestroy(context);
    if (res != CUDA_SUCCESS) {
        printf("cuCtxDestroy() failed: %d\n", res);
    }
}

/**
 * Perform one 4096-byte allocate/free cycle through the CUDA runtime API.
 *
 * Calls cudaMalloc(), prints the returned pointer and result code, then hands
 * the pointer to cudaFree() regardless of the allocation result.
 */
static void alloc_runtime_api()
{
    cudaError_t status;
    void *buffer;

    buffer = NULL;
    status = cudaMalloc(&buffer, 4096);
    printf("cudaMalloc() returned %p result %d\n", buffer, status);

    cudaFree(buffer);
}

/**
 * Test entry point: verify that CUDA allocate/free calls generate UCM
 * memory-type events.
 *
 * Parses the command line (-d selects the driver API, otherwise the runtime
 * API is used), creates a UCP context, installs event_cb() for memory-type
 * alloc/free events, runs one allocate/free cycle and counts the events seen.
 *
 * @return 0 if at least the expected number of events was observed; -1 on an
 *         unknown option, on UCP context creation failure, on event handler
 *         installation failure, or when too few events were observed.
 */
int main(int argc, char **argv)
{
    static const ucm_event_type_t memtype_events = UCM_EVENT_MEM_TYPE_ALLOC |
                                                   UCM_EVENT_MEM_TYPE_FREE;
    static const int num_expected_events         = 2;
    ucp_context_h context;
    ucs_status_t status;
    ucp_params_t params;
    int use_driver_api;
    int num_events;
    int c;

    use_driver_api = 0;
    while ((c = getopt(argc, argv, "d")) != -1) {
        switch (c) {
        case 'd':
            use_driver_api = 1;
            break;
        default:
            printf("Usage: test_cuda_hook [options]\n");
            printf("Options are:\n");
            printf("  -d :   Use Cuda driver API (Default: use runtime API)\n");
            printf("\n");
            return -1;
        }
    }

    params.field_mask = UCP_PARAM_FIELD_FEATURES;
    params.features   = UCP_FEATURE_TAG | UCP_FEATURE_STREAM;
    status            = ucp_init(&params, NULL, &context);
    if (status != UCS_OK) {
        printf("failed to create context\n");
        return -1;
    }

    num_events = 0;
    status     = ucm_set_event_handler(memtype_events, 1000, event_cb,
                                       &num_events);
    if (status != UCS_OK) {
        /* Without the handler no event can ever be counted; report the real
         * cause instead of a misleading "got 0 events" result */
        printf("failed to set memory event handler: %d\n", (int)status);
        ucp_cleanup(context);
        return -1;
    }

    if (use_driver_api) {
        alloc_driver_api();
    } else {
        alloc_runtime_api();
    }

    /* Stop counting before tearing down the context so that only the
     * allocate/free cycle above contributes to num_events */
    ucm_unset_event_handler(memtype_events, event_cb, &num_events);
    printf("got %d/%d memory events\n", num_events, num_expected_events);

    ucp_cleanup(context);

    return (num_events >= num_expected_events) ? 0 : -1;
}