1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
|
// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
#include <blas.hh>
#include <vector>
#include <stdio.h>
#include "util.hh"
//------------------------------------------------------------------------------
// Example of calling blas::gemm on matrices stored in std::vector:
// computes C = -1.0*A*B + 1.0*C with column-major, non-transposed operands.
//
// T is the element type (main instantiates it with float, double,
// std::complex<float>, and std::complex<double>).
// m, n, k are the matrix dimensions: A is m-by-k, B is k-by-n, C is m-by-n.
template <typename T>
void test_gemm( int m, int n, int k )
{
    print_func();

    // Leading dimensions: each matrix is stored with no padding between
    // columns, so the leading dimension equals the row count.
    int lda = m;
    int ldb = k;
    int ldc = m;

    // Compute element counts in size_t. The int product (e.g., lda*k) can
    // overflow for large dimensions before being widened to the vector's
    // size type, and signed overflow is undefined behavior.
    size_t const size_A = static_cast<size_t>( lda ) * static_cast<size_t>( k );
    size_t const size_B = static_cast<size_t>( ldb ) * static_cast<size_t>( n );
    size_t const size_C = static_cast<size_t>( ldc ) * static_cast<size_t>( n );

    std::vector<T> A( size_A, 1.0 );  // m-by-k
    std::vector<T> B( size_B, 2.0 );  // k-by-n
    std::vector<T> C( size_C, 3.0 );  // m-by-n

    // ... fill in application data into A, B, C ...

    // C = -1.0*A*B + 1.0*C
    blas::gemm(
        blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::NoTrans,
        m, n, k,
        -1.0, A.data(), lda,
              B.data(), ldb,
         1.0, C.data(), ldc );
}
//------------------------------------------------------------------------------
// Example of calling blas::gemm with a blas::Queue on buffers obtained from
// blas::device_malloc: copies A, B, C into those buffers, computes
// C = -1.0*A*B + 1.0*C, and copies the result back into the std::vector C.
// Prints "no GPU devices" and does nothing else when
// blas::get_device_count() reports zero devices.
//
// T is the element type (main instantiates it with float, double,
// std::complex<float>, and std::complex<double>).
// m, n, k are the matrix dimensions: A is m-by-k, B is k-by-n, C is m-by-n.
template <typename T>
void test_device_gemm( int m, int n, int k )
{
    print_func();

    if (blas::get_device_count() == 0) {
        printf( "no GPU devices\n" );
        return;
    }

    // Leading dimensions: each matrix is stored with no padding between
    // columns, so the leading dimension equals the row count. The same
    // leading dimensions are used for the host and device copies.
    int lda = m;
    int ldb = k;
    int ldc = m;

    // Compute element counts in size_t. The int product (e.g., lda*k) can
    // overflow for large dimensions before being widened, and signed
    // overflow is undefined behavior. The counts are shared by the host
    // vectors and the device allocations so the two always agree.
    size_t const size_A = static_cast<size_t>( lda ) * static_cast<size_t>( k );
    size_t const size_B = static_cast<size_t>( ldb ) * static_cast<size_t>( n );
    size_t const size_C = static_cast<size_t>( ldc ) * static_cast<size_t>( n );

    std::vector<T> A( size_A, 1.0 );  // m-by-k
    std::vector<T> B( size_B, 2.0 );  // k-by-n
    std::vector<T> C( size_C, 3.0 );  // m-by-n

    // ... fill in application data into A, B, C ...

    // All device operations below are issued on this queue, on device 0.
    int device = 0;
    blas::Queue queue( device );

    T* dA = blas::device_malloc<T>( size_A, queue );  // m-by-k
    T* dB = blas::device_malloc<T>( size_B, queue );  // k-by-n
    T* dC = blas::device_malloc<T>( size_C, queue );  // m-by-n

    // Copy the input matrices from the std::vectors to the device buffers.
    blas::device_copy_matrix(
        m, k,
        A.data(), lda,      // src
        dA,       lda,      // dst
        queue );
    blas::device_copy_matrix(
        k, n,
        B.data(), ldb,      // src
        dB,       ldb,      // dst
        queue );
    blas::device_copy_matrix(
        m, n,
        C.data(), ldc,      // src
        dC,       ldc,      // dst
        queue );

    // C = -1.0*A*B + 1.0*C
    blas::gemm(
        blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::NoTrans,
        m, n, k,
        -1.0, dA, lda,
              dB, ldb,
         1.0, dC, ldc,
        queue );

    // Copy the result from the device buffer back into the std::vector C.
    blas::device_copy_matrix(
        m, n,
        dC,       ldc,      // src
        C.data(), ldc,      // dst
        queue );

    // Synchronize the queue before releasing the device buffers.
    // NOTE(review): presumably the operations above may still be in flight
    // until sync() returns -- confirm against the BLAS++ Queue documentation.
    queue.sync();

    // NOTE(review): if any call above throws, these frees are skipped and
    // the device buffers are not released; production code may want an
    // owning wrapper around device_malloc / device_free.
    blas::device_free( dA, queue );  dA = nullptr;
    blas::device_free( dB, queue );  dB = nullptr;
    blas::device_free( dC, queue );  dC = nullptr;
}
//------------------------------------------------------------------------------
// Program entry point: runs the vector-based and device-based gemm examples
// for each precision selected on the command line.
// Returns 0 on success, or 1 after printing the message of any
// std::exception that propagates out of the tests.
int main( int argc, char** argv )
{
    try {
        // Parse command line to set types for s, d, c, z precisions.
        bool types[ 4 ];
        parse_args( argc, argv, types );

        // Name the flags in s, d, c, z order for readability.
        bool const run_s = types[ 0 ];
        bool const run_d = types[ 1 ];
        bool const run_c = types[ 2 ];
        bool const run_z = types[ 3 ];

        // Fixed problem size used by every test.
        int m = 100;
        int n = 200;
        int k = 50;
        printf( "m %d, n %d, k %d\n", m, n, k );

        // Run the std::vector-based tests first, one per selected precision.
        if (run_s) {
            test_gemm< float >( m, n, k );
        }
        if (run_d) {
            test_gemm< double >( m, n, k );
        }
        if (run_c) {
            test_gemm< std::complex<float> >( m, n, k );
        }
        if (run_z) {
            test_gemm< std::complex<double> >( m, n, k );
        }

        // Then run the device tests in the same precision order.
        if (run_s) {
            test_device_gemm< float >( m, n, k );
        }
        if (run_d) {
            test_device_gemm< double >( m, n, k );
        }
        if (run_c) {
            test_device_gemm< std::complex<float> >( m, n, k );
        }
        if (run_z) {
            test_device_gemm< std::complex<double> >( m, n, k );
        }
    }
    catch (std::exception const& ex) {
        // Report the failure on stderr and signal it via the exit status.
        fprintf( stderr, "%s\n", ex.what() );
        return 1;
    }
    return 0;
}
|