1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
|
#include <iostream>
#include <vector>
#include "CombBLAS/CombBLAS.h"
using namespace std;
using namespace combblas;
typedef int64_t IT;
typedef double NT;
int main(int argc, char* argv[])
{
int nprocs, myrank;
MPI_Init(&argc, &argv);
MPI_Comm_size(MPI_COMM_WORLD,&nprocs);
MPI_Comm_rank(MPI_COMM_WORLD,&myrank);
if(argc < 4)
{
if(myrank == 0)
cout << "Usage: ./BlockedSpGEMM <MatrixA> <MatrixB> <br> <bc>" << endl;
MPI_Finalize();
return -1;
}
{
string Aname(argv[1]);
string Bname(argv[2]);
int br = atoi(argv[3]);
int bc = atoi(argv[4]);
MPI_Barrier(MPI_COMM_WORLD);
typedef PlusTimesSRing<NT, NT> SR_PT;
typedef SpDCCols<IT, NT> DER;
shared_ptr<CommGrid> fullWorld;
fullWorld.reset(new CommGrid(MPI_COMM_WORLD, 0, 0));
SpParMat<IT, NT, DER> A(fullWorld);
A.ParallelReadMM(Aname, true, maximum<NT>());
IT nr = A.getnrow(), nc = A.getncol(), nnz = A.getnnz();
if (myrank == 0)
cout << "A " << nr << " " << nc << " " << nnz << std::endl;
SpParMat<IT, NT, DER> B(fullWorld);
B.ParallelReadMM(Bname, true, maximum<NT>());
nr = B.getnrow(), nc = B.getncol(), nnz = B.getnnz();
if (myrank == 0)
cout << "B " << nr << " " << nc << " " << nnz << std::endl;
// auto blocks = A.BlockSplit(br, bc);
BlockSpGEMM<IT, NT, DER, NT, DER> bspgemm(A, B, br, bc);
IT roffset, coffset;
while (bspgemm.hasNext())
{
auto C = bspgemm.getNextBlock<SR_PT, NT, DER>(roffset, coffset);
nr = C.getnrow(), nc = C.getncol(), nnz = C.getnnz();
if (myrank == 0)
cout << "block size " << nr << " " << nc << " " << nnz
<< " offsets " << roffset << " " << coffset
<< std::endl;
}
// auto C = bspgemm.getBlockId<SR_PT, NT, DER>(0, 1, roffset, coffset);
// nr = C.getnrow(), nc = C.getncol(), nnz = C.getnnz();
// if (myrank == 0)
// cout << "block size " << nr << " " << nc << " " << nnz
// << " offsets " << roffset << " " << coffset
// << std::endl;
}
MPI_Finalize();
return 0;
}
|