1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
|
#include <hip/hip_runtime_api.h> // for hip functions
#include <rocsolver/rocsolver.h> // for all the rocsolver C interfaces and type declarations
#include <stdio.h> // for printf
#include <stdlib.h> // for malloc
// Example: Compute the QR Factorization of a matrix asynchronously on the GPU using the hipGraph API
// Build the 3x3 example matrix and return it as a newly malloc'd array in
// column-major order.
//
// On return, *M_out, *N_out and *lda_out hold the row count, column count and
// leading dimension of the returned array (they are set even if the allocation
// fails). The caller owns the returned array and must free() it.
// Returns NULL if the host allocation fails.
double *create_example_matrix(rocblas_int *M_out,
                              rocblas_int *N_out,
                              rocblas_int *lda_out) {
    // a *very* small example input; not a very efficient use of the API
    static const double A[3][3] = { { 12, -51,   4},
                                    {  6, 167, -68},
                                    { -4,  24, -41} };
    const rocblas_int M = 3;
    const rocblas_int N = 3;
    const rocblas_int lda = 3;
    *M_out = M;
    *N_out = N;
    *lda_out = lda;

    // note: rocsolver matrices must be stored in column major format,
    // i.e. entry (i,j) should be accessed by hA[i + j*lda]
    double *hA = malloc(sizeof *hA * (size_t)lda * (size_t)N);
    if (hA == NULL) {
        return NULL;
    }
    // signed indices match the signed rocblas_int bounds (no sign-compare)
    for (rocblas_int i = 0; i < M; ++i) {
        for (rocblas_int j = 0; j < N; ++j) {
            // copy A (2D array, row-major) into hA (1D array, column-major)
            hA[i + (size_t)j * lda] = A[i][j];
        }
    }
    return hA;
}
// We use rocsolver_dgeqrf to factor a real M-by-N matrix, A.
// See https://rocm.docs.amd.com/projects/rocSOLVER/en/latest/api/lapack.html#rocsolver-type-geqrf
int main() {
const rocblas_int ITER_COUNT = 10;
rocblas_int M; // rows
rocblas_int N; // cols
rocblas_int lda; // leading dimension
double *hA = create_example_matrix(&M, &N, &lda); // input matrix on CPU
// let's print the input matrix, just to see it
printf("A = [\n");
for (size_t i = 0; i < M; ++i) {
printf(" ");
for (size_t j = 0; j < N; ++j) {
printf("% .3f ", hA[i + j*lda]);
}
printf(";\n");
}
printf("]\n");
// initialization
rocblas_handle handle;
rocblas_create_handle(&handle);
// Some rocsolver functions may trigger rocblas to load its GEMM kernels.
// You can preload the kernels by explicitly invoking rocblas_initialize
// (e.g., to exclude one-time initialization overhead from benchmarking).
// preload rocBLAS GEMM kernels (optional)
// rocblas_initialize();
// calculate the sizes of our arrays
size_t size_A = lda * (size_t)N; // count of elements in matrix A
size_t size_piv = (M < N) ? M : N; // count of Householder scalars
// allocate memory on GPU
double *dA, *dIpiv;
hipMalloc((void**)&dA, sizeof(double)*size_A);
hipMalloc((void**)&dIpiv, sizeof(double)*size_piv);
// copy data to GPU
hipMemcpy(dA, hA, sizeof(double)*size_A, hipMemcpyHostToDevice);
// compute the QR factorization on the GPU
// create the stream object
hipStream_t stream;
hipStreamCreate(&stream);
rocblas_set_stream(handle, stream);
// create graph management objects
hipGraph_t graph;
rocblas_int graph_ready = 0;
hipGraphExec_t exec;
for (int i = 0; i < ITER_COUNT; i++) {
if (!graph_ready) {
hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal);
rocsolver_dgeqrf(handle, M, N, dA, lda, dIpiv); // returns immediately
hipStreamEndCapture(stream, &graph);
hipGraphInstantiate(&exec, graph, NULL, NULL, 0);
hipGraphDestroy(graph);
graph_ready = 1;
}
hipGraphLaunch(exec, stream);
}
// copy the results back to CPU
double *hIpiv = (double*)malloc(sizeof(double)*size_piv);
hipMemcpy(hA, dA, sizeof(double)*size_A, hipMemcpyDeviceToHost); // will block until the stream is completed
hipMemcpy(hIpiv, dIpiv, sizeof(double)*size_piv, hipMemcpyDeviceToHost);
// the results are now in hA and hIpiv
// we can print some of the results if we want to see them
printf("R = [\n");
for (size_t i = 0; i < M; ++i) {
printf(" ");
for (size_t j = 0; j < N; ++j) {
printf("% .3f ", (i <= j) ? hA[i + j*lda] : 0);
}
printf(";\n");
}
printf("]\n");
// clean up
free(hIpiv);
hipFree(dA);
hipFree(dIpiv);
free(hA);
hipGraphExecDestroy(exec);
rocblas_destroy_handle(handle);
// order matters: the handle must be destroyed before the stream
hipStreamDestroy(stream);
}
|