File: svdmitigater.cpp

package info (click to toggle)
aoflagger 3.1.0-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 4,868 kB
  • sloc: cpp: 52,164; python: 152; sh: 60; makefile: 17
file content (172 lines) | stat: -rw-r--r-- 5,659 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include <iostream>
#include <vector>

#include "../../util/stopwatch.h"

#include "svdmitigater.h"

#ifdef HAVE_GTKMM
#include "../../plot/plot2d.h"
#endif

// Prototype of LAPACK's ZGESVD, called through the Fortran ABI: every
// argument is passed by pointer and the symbol carries a trailing underscore.
// It computes the singular value decomposition A = U * diag(S) * V^H of a
// complex double-precision m x n matrix; lwork == -1 requests a workspace
// size query instead of the decomposition.
// NOTE(review): integer / doublereal / doublecomplex are presumably f2c-style
// typedefs made visible through svdmitigater.h — confirm.
extern "C" {
int zgesvd_(char *jobu, char *jobvt, integer *m, integer *n, doublecomplex *a,
            integer *lda, doublereal *s, doublecomplex *u, integer *ldu,
            doublecomplex *vt, integer *ldvt, doublecomplex *work,
            integer *lwork, doublereal *rwork, integer *info);
}

// Starts with no decomposition and no composed background: every owned
// buffer pointer is null and the matrix dimensions are zero.
SVDMitigater::SVDMitigater()
    : _background(nullptr),
      _singularValues(nullptr),
      _leftSingularVectors(nullptr),
      _rightSingularVectors(nullptr),
      _m(0),
      _n(0),
      _iteration(0),
      _removeCount(10),
      _verbose(false) {}

// Releases the decomposition buffers and any composed background.
SVDMitigater::~SVDMitigater() {
  Clear();
}

// Frees the singular values/vectors (when a decomposition exists) and the
// composed background, leaving every owned pointer null so that Clear() may
// safely be called any number of times.
void SVDMitigater::Clear() {
  if (IsDecomposed()) {
    delete[] _singularValues;
    delete[] _leftSingularVectors;
    delete[] _rightSingularVectors;
    _singularValues = nullptr;
    _leftSingularVectors = nullptr;
    _rightSingularVectors = nullptr;
  }
  // Reset after deleting: Decompose() calls Clear() and later the destructor
  // (or Compose()) deletes _background again, so a stale pointer here would
  // be freed twice. Deleting a null pointer is a no-op.
  delete _background;
  _background = nullptr;
}

// lda/ldu/ldvt = LAPACK leading dimensions, i.e. the stride between
// consecutive columns of a column-major array.

/*int cgebrd_(integer *m, integer *n, complex *a, integer *lda,
         real *d__, real *e, complex *tauq, complex *taup, complex *work,
        integer *lwork, integer *info);*/

void SVDMitigater::Decompose() {
  if (_verbose) std::cout << "Decomposing..." << std::endl;
  Stopwatch watch;
  watch.Start();
  Clear();

  // Remember that the axes have to be turned; in 'a', time is along the
  // vertical axis.
  _m = _data.ImageHeight();
  _n = _data.ImageWidth();
  int minmn = _m < _n ? _m : _n;
  char rowsOfU = 'A';   // all rows of u
  char rowsOfVT = 'A';  // all rows of VT
  doublecomplex *a = new doublecomplex[_m * _n];
  Image2DCPtr real = _data.GetRealPart(), imaginary = _data.GetImaginaryPart();
  for (int t = 0; t < _n; ++t) {
    for (int f = 0; f < _m; ++f) {
      a[t * _m + f].r = real->Value(t, f);
      a[t * _m + f].i = imaginary->Value(t, f);
    }
  }
  long int lda = _m;
  _singularValues = new double[minmn];
  for (int i = 0; i < minmn; ++i) _singularValues[i] = 0.0;
  _leftSingularVectors = new doublecomplex[_m * _m];
  for (int i = 0; i < _m * _m; ++i) {
    _leftSingularVectors[i].r = 0.0;
    _leftSingularVectors[i].i = 0.0;
  }
  _rightSingularVectors = new doublecomplex[_n * _n];
  for (int i = 0; i < _n * _n; ++i) {
    _rightSingularVectors[i].r = 0.0;
    _rightSingularVectors[i].i = 0.0;
  }
  long int info = 0;
  doublecomplex complexWorkAreaSize;
  long int workAreaSize = -1;
  double *workArea2 = new double[5 * minmn];

  // Determine optimal workareasize
  zgesvd_(&rowsOfU, &rowsOfVT, &_m, &_n, a, &lda, _singularValues,
          _leftSingularVectors, &_m, _rightSingularVectors, &_n,
          &complexWorkAreaSize, &workAreaSize, workArea2, &info);

  if (info == 0) {
    if (_verbose) std::cout << "zgesvd_..." << std::endl;
    workAreaSize = (int)complexWorkAreaSize.r;
    doublecomplex *workArea1 = new doublecomplex[workAreaSize];
    zgesvd_(&rowsOfU, &rowsOfVT, &_m, &_n, a, &lda, _singularValues,
            _leftSingularVectors, &_m, _rightSingularVectors, &_n, workArea1,
            &workAreaSize, workArea2, &info);

    delete[] workArea1;
  }
  delete[] workArea2;
  delete[] a;

  if (_verbose) {
    for (int i = 0; i < minmn; ++i) std::cout << _singularValues[i] << ",";
    std::cout << std::endl;
    std::cout << watch.ToString() << std::endl;
  }
}

void SVDMitigater::Compose() {
  if (_verbose) std::cout << "Composing..." << std::endl;
  Stopwatch watch;
  watch.Start();
  Image2DPtr real =
      Image2D::CreateUnsetImagePtr(_data.ImageWidth(), _data.ImageHeight());
  Image2DPtr imaginary =
      Image2D::CreateUnsetImagePtr(_data.ImageWidth(), _data.ImageHeight());
  int minmn = _m < _n ? _m : _n;
  for (int t = 0; t < _n; ++t) {
    for (int f = 0; f < _m; ++f) {
      double a_tf_r = 0.0;
      double a_tf_i = 0.0;
      // A = U S V^T , so:
      // a_tf = \sum_{g=0}^{minmn} U_{gf} S_{gg} V^T_{tg}
      // Note that _rightSingularVectors=V^T, thus is already stored rowwise
      for (int g = 0; g < minmn; ++g) {
        double u_r = _leftSingularVectors[g * _m + f].r;
        double u_i = _leftSingularVectors[g * _m + f].i;
        double s = _singularValues[g];
        double v_r = _rightSingularVectors[t * _n + g].r;
        double v_i = _rightSingularVectors[t * _n + g].i;
        a_tf_r += s * (u_r * v_r - u_i * v_i);
        a_tf_i += s * (u_r * v_i + u_i * v_r);
      }
      real->SetValue(t, f, a_tf_r);
      imaginary->SetValue(t, f, a_tf_i);
    }
  }
  if (_background != 0) delete _background;
  _background =
      new TimeFrequencyData(aocommon::Polarization::StokesI, real, imaginary);
  if (_verbose) std::cout << watch.ToString() << std::endl;
}

#ifdef HAVE_GTKMM

/**
 * Adds one line per polarization of @p data to @p plot, showing that
 * polarization's singular values (indices 0 .. min(height, width) - 1) on a
 * logarithmic y axis. Each polarization is decomposed by its own temporary
 * SVDMitigater.
 */
void SVDMitigater::CreateSingularValueGraph(const TimeFrequencyData &data,
                                            Plot2D &plot) {
  const size_t polarizationCount = data.PolarizationCount();
  plot.SetTitle("Distribution of singular values");
  plot.SetLogarithmicYAxis(true);
  for (size_t polIndex = 0; polIndex < polarizationCount; ++polIndex) {
    TimeFrequencyData polarizationData(
        data.MakeFromPolarizationIndex(polIndex));
    SVDMitigater svd;
    svd.Initialize(polarizationData);
    svd.Decompose();
    const size_t minmn = svd._m < svd._n ? svd._m : svd._n;

    Plot2DPointSet &pointSet = plot.StartLine(polarizationData.Description());
    pointSet.SetXDesc("Singular value index");
    pointSet.SetYDesc("Singular value");

    // Separate index name: the original reused 'i', shadowing the
    // polarization loop counter.
    for (size_t valueIndex = 0; valueIndex < minmn; ++valueIndex)
      plot.PushDataPoint(valueIndex, svd.SingularValue(valueIndex));
  }
}

#else
// Without GTKMM there is no plotting support; this is a no-op.
void SVDMitigater::CreateSingularValueGraph(const TimeFrequencyData &,
                                            Plot2D &) {}
#endif