1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
|
/*------------------------------------------------------------------------------------------------*
* Copyright (C) by the DBCSR developers group - All rights reserved *
* This file is part of the DBCSR library. *
* *
* For information on the license, see the LICENSE file. *
* For further information please visit https://dbcsr.cp2k.org *
* SPDX-License-Identifier: GPL-2.0+ *
*------------------------------------------------------------------------------------------------*/
#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <array>
#include <vector>
#include "libsmm_acc_benchmark.h"
#include "libsmm_acc.h"
/****************************************************************************\
\brief Checks correctness of randomly selected libsmm_acc multiplication kernels
\****************************************************************************/
int main(int argc, char** argv) {
DBCSR_MARK_USED(argc);
DBCSR_MARK_USED(argv);
KernelLauncher launcher_mm = libsmm_acc_process_d;
char buffer[1000];
char * kernel_descr[1] = {buffer};
// Get all blocksizes available in libsmm_acc
std::vector<Triplet> libsmm_acc_triplets = {
[[UNITTEST_KERNELS_HERE]]
};
int n_triplets = libsmm_acc_triplets.size();
printf("# libsmm_acc has %d blocksizes for multiplication\n", n_triplets);
int max_m=0, max_n=0, max_k=0;
for (int i=0; i<n_triplets; i++) {
max_m = std::max(max_n, libsmm_acc_triplets[i][0]);
max_n = std::max(max_m, libsmm_acc_triplets[i][1]);
max_k = std::max(max_k, libsmm_acc_triplets[i][2]);
}
libsmm_acc_benchmark_t* handle;
libsmm_acc_benchmark_init(&handle, test, max_m, max_n, max_k);
int errors = 0;
for (int i=0; i<n_triplets; i++) {
int m = libsmm_acc_triplets[i][0];
int n = libsmm_acc_triplets[i][1];
int k = libsmm_acc_triplets[i][2];
sprintf(buffer, "%d x %d x %d", m, n, k);
errors += libsmm_acc_benchmark(handle, m, n, k, 1, &launcher_mm, kernel_descr);
}
libsmm_acc_benchmark_finalize(handle);
printf("# Done, found %d matrix-matrix multiplication errors.\n", errors);
return errors;
}
|