File: ParDenseMatrix.h

#ifndef PARDENSEMATRIX_H
#define PARDENSEMATRIX_H

#include "MatrixDef.h"
#include "DenseMatrix.h"
#include "DenseVector.h"
#include "MPI_Wrappers.h"

#include "ATC_Error.h"
using ATC::ATC_Error;

#include <algorithm>
#include <sstream>

namespace ATC_matrix {

  /**
   *  @class  ParDenseMatrix
   *  @brief  Parallelized version of the DenseMatrix class; the double
   *          specialization distributes matrix-vector products over MPI.
   */
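
  // Usage sketch (illustrative; the sizes and names below are hypothetical,
  // not taken from the original source). The double specialization computes
  // the product collectively, so every rank in the communicator holds the
  // full matrix and vector and makes the same call:
  //
  //   ParDenseMatrix<double> A(MPI_COMM_WORLD, nRows, nCols);
  //   DenseVector<double> x(nCols);
  //   // ... fill A and x identically on all ranks ...
  //   DenseVector<double> y = A * x;  // full result available on every rank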

  template <typename T>
  class ParDenseMatrix : public DenseMatrix<T> {
  public:
    MPI_Comm _comm;

    ParDenseMatrix(MPI_Comm comm, INDEX rows=0, INDEX cols=0, bool z=1)
      : DenseMatrix<T>(rows, cols, z), _comm(comm) {}
    ParDenseMatrix(MPI_Comm comm, const DenseMatrix<T>& c)
      : DenseMatrix<T>(c), _comm(comm) {}
    ParDenseMatrix(MPI_Comm comm, const SparseMatrix<T>& c)
      : DenseMatrix<T>(c), _comm(comm) {}
    ParDenseMatrix(MPI_Comm comm, const Matrix<T>& c)
      : DenseMatrix<T>(c), _comm(comm) {}

    //////////////////////////////////////////////////////////////////////////////
    //* performs a matrix-vector multiply, c = b*c + a*A*v (a*A^T*v if At)
    void ParMultMv(const Vector<T> &v,
             DenseVector<T> &c, const bool At, T a, T b) const
    {
      // The generic case cannot be parallelized because MPI calls need a
      // concrete data type, so fall back to the serial multiply on every rank
      MultMv(*this, v, c, At, a, b);
    }

  };

  template<>
  class ParDenseMatrix<double> : public DenseMatrix<double> {
  public:
    MPI_Comm _comm;

    ParDenseMatrix(MPI_Comm comm, INDEX rows=0, INDEX cols=0, bool z=1)
      : DenseMatrix<double>(rows, cols, z), _comm(comm) {}
    ParDenseMatrix(MPI_Comm comm, const DenseMatrix<double>& c)
      : DenseMatrix<double>(c), _comm(comm) {}
    ParDenseMatrix(MPI_Comm comm, const SparseMatrix<double>& c)
      : DenseMatrix<double>(c), _comm(comm) {}
    ParDenseMatrix(MPI_Comm comm, const Matrix<double>& c)
      : DenseMatrix<double>(c), _comm(comm) {}


    void ParMultMv(const Vector<double> &v, DenseVector<double> &c,
        const bool At, double a, double b) const
    {
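      // Strategy: split the matrix along its major (storage) dimension across
      // the ranks of _comm, multiply each block locally, and combine the
      // partial results so that every rank ends up with the full product c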
      // The transposed product (At == true) is not parallelized yet; fall
      // back to the serial multiply on every rank
      if (At) {
        MultMv(*this, v, c, At, a, b);
        return;
      }

      const INDEX nRows = this->nRows();
      const INDEX nCols = this->nCols();

      if (c.size() != nRows) {
        c.resize(nRows);  // set c to the output size
        c.zero();         // nothing to accumulate into
      } else c *= b;      // pre-scale the existing c by b

      // Determine how many rows (or columns, for column-major storage) each
      // processor will handle
      int nProcs = MPI_Wrappers::size(_comm);
      int myRank = MPI_Wrappers::rank(_comm);

      int *majorCounts = new int[nProcs];
      int *offsets = new int[nProcs];

#ifdef COL_STORAGE // Column-major storage
      int nMajor = nCols;
      int nMinor = nRows;
      int ParDenseMatrix::*majorField = &ParDenseMatrix::_nCols;
      int ParDenseMatrix::*minorField = &ParDenseMatrix::_nRows;
#else // Row-major storage
      int nMajor = nRows;
      int nMinor = nCols;
      int ParDenseMatrix::*majorField = &ParDenseMatrix::_nRows;
      int ParDenseMatrix::*minorField = &ParDenseMatrix::_nCols;
#endif
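      // majorField/minorField are pointers-to-member selecting which dimension
      // (_nRows or _nCols) is split across processors, so the block below can
      // set A_local's dimensions without further #ifdefs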

      for (int i = 0; i < nProcs; i++) {
        // If the major dimension does not divide evenly by the number of
        // processors, or there are fewer rows/columns than processors, some
        // processors receive fewer entries (possibly none)
        offsets[i] = (i * nMajor) / nProcs;
        majorCounts[i] = (((i + 1) * nMajor) / nProcs) - offsets[i];
      }
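      // For example (hypothetical sizes, not from the original source): with
      // nMajor = 10 and nProcs = 4 this gives offsets = {0, 2, 5, 7} and
      // majorCounts = {2, 3, 2, 3}, covering all 10 rows/columns exactly once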

      int myNMajor = majorCounts[myRank];
      int myMajorOffset = offsets[myRank];

      // Make A_local a shallow view of this matrix, offset to the block owned
      // by this processor (no data is copied)
      ParDenseMatrix<double> A_local(_comm);
      A_local._data = this->_data + myMajorOffset * nMinor;
      A_local.*majorField = myNMajor;
      A_local.*minorField = nMinor;
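      // Entries along the major (storage) direction are contiguous, so
      // advancing the data pointer by myMajorOffset * nMinor values lands at
      // this processor's first owned column (or row, for row-major storage)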

#ifdef COL_STORAGE // Column-major storage

      // When splitting by columns, each processor multiplies its block of
      // columns by the matching slice of v; the partial products (each of
      // length nRows) are then summed across processors.

      DenseVector<double> v_local(myNMajor);
      for (int i = 0; i < myNMajor; i++)
        v_local(i) = v(myMajorOffset + i);

      // Store this processor's partial product in a local vector and apply
      // the a scale factor
      DenseVector<double> c_local = A_local * v_local;
      c_local *= a;

      // Sum the partial products from every processor into a temporary, then
      // accumulate into the already b-scaled c
      DenseVector<double> c_sum(nRows);
      MPI_Wrappers::allsum(_comm, c_local.ptr(), c_sum.ptr(), c_local.size());
      for (INDEX i = 0; i < nRows; i++) c(i) += c_sum(i);

#else // Row-major storage

      // When splitting by rows, each processor uses the whole vector v and
      // the per-processor results are concatenated.

      // Seed a small local vector with this processor's (already b-scaled)
      // entries of c
      DenseVector<double> c_local(myNMajor);
      for (int i = 0; i < myNMajor; i++)
        c_local(i) = c(myMajorOffset + i);

      // c was already scaled by b above, so pass 1.0 here to avoid applying
      // b a second time
      MultMv(A_local, v, c_local, At, a, 1.0);

      // Gather the results onto each processor
      allgatherv(_comm, c_local.ptr(), c_local.size(), c.ptr(),
                 majorCounts, offsets);
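      // Each processor contributes majorCounts[rank] entries starting at
      // offsets[rank], so the gathered c matches the row decomposition above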

#endif

      // A_local aliases this matrix's data; null its pointer so its
      // destructor does not free memory it does not own
      A_local._data = nullptr;

      delete [] majorCounts;
      delete [] offsets;

    }

  };

  // Operator for parallel dense matrix - dense vector product, c = A*b.
  // For the double specialization this is a collective operation: every rank
  // in A's communicator must make the call.
  template<typename T>
  DenseVector<T> operator*(const ParDenseMatrix<T> &A, const Vector<T> &b)
  {
    DenseVector<T> c;
    A.ParMultMv(b, c, false, 1.0, 0.0);
    return c;
  }


} // end namespace ATC_matrix
#endif // PARDENSEMATRIX_H