File: svdmitigater.cpp

package info (click to toggle)
aoflagger 3.4.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 8,960 kB
  • sloc: cpp: 83,076; python: 10,187; sh: 260; makefile: 178
file content (176 lines) | stat: -rw-r--r-- 5,785 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#include "../util/stopwatch.h"

#include "svdmitigater.h"

#ifdef HAVE_GTKMM
#include "../plot/xyplot.h"
#endif

// Declaration of the externally provided zgesvd_ routine (complex
// double-precision singular value decomposition), using the f2c-style
// 'integer', 'doublereal' and 'doublecomplex' types. As used in this file:
// 'a' holds the m-by-n input with leading dimension 'lda', 's' receives the
// singular values, 'u' and 'vt' receive the singular vectors, and calling with
// *lwork == -1 only queries the workspace size, which is read back from
// work[0].r. 'info' is zero on success.
extern "C" int zgesvd_(char* jobu, char* jobvt, integer* m, integer* n,
                       doublecomplex* a, integer* lda, doublereal* s,
                       doublecomplex* u, integer* ldu, doublecomplex* vt,
                       integer* ldvt, doublecomplex* work, integer* lwork,
                       doublereal* rwork, integer* info);

namespace algorithms {

SVDMitigater::SVDMitigater()
    : _background(nullptr),
      _singularValues(nullptr),
      _leftSingularVectors(nullptr),
      _rightSingularVectors(nullptr),
      _m(0),
      _n(0),
      _iteration(0),
      _removeCount(10),
      _verbose(false) {}

// All owned memory is released by Clear().
SVDMitigater::~SVDMitigater() {
  Clear();
}

/**
 * Releases everything this object owns: the three decomposition arrays (only
 * when IsDecomposed() reports that a decomposition is present) and the
 * composed background.
 *
 * Every owning pointer is reset to nullptr after it has been freed, so the
 * method is safe to call repeatedly. That matters because Clear() is invoked
 * both by Decompose() and by the destructor, and Compose() also deletes a
 * non-null _background: without the reset, a stale _background pointer would
 * be deleted a second time.
 */
void SVDMitigater::Clear() {
  if (IsDecomposed()) {
    delete[] _singularValues;
    delete[] _leftSingularVectors;
    delete[] _rightSingularVectors;
    _singularValues = nullptr;
    _leftSingularVectors = nullptr;
    _rightSingularVectors = nullptr;
  }
  // Deleting a null pointer is a no-op, so no explicit check is required.
  delete _background;
  _background = nullptr;
}

// In the LAPACK signatures, 'lda' is the leading dimension of matrix 'a'.

/*int cgebrd_(integer *m, integer *n, complex *a, integer *lda,
         real *d__, real *e, complex *tauq, complex *taup, complex *work,
        integer *lwork, integer *info);*/

void SVDMitigater::Decompose() {
  if (_verbose) std::cout << "Decomposing..." << std::endl;
  Stopwatch watch;
  watch.Start();
  Clear();

  // Remember that the axes have to be turned; in 'a', time is along the
  // vertical axis.
  _m = _data.ImageHeight();
  _n = _data.ImageWidth();
  const int minmn = _m < _n ? _m : _n;
  char rowsOfU = 'A';   // all rows of u
  char rowsOfVT = 'A';  // all rows of VT
  doublecomplex* a = new doublecomplex[_m * _n];
  Image2DCPtr real = _data.GetRealPart(), imaginary = _data.GetImaginaryPart();
  for (int t = 0; t < _n; ++t) {
    for (int f = 0; f < _m; ++f) {
      a[t * _m + f].r = real->Value(t, f);
      a[t * _m + f].i = imaginary->Value(t, f);
    }
  }
  long int lda = _m;
  _singularValues = new double[minmn];
  for (int i = 0; i < minmn; ++i) _singularValues[i] = 0.0;
  _leftSingularVectors = new doublecomplex[_m * _m];
  for (int i = 0; i < _m * _m; ++i) {
    _leftSingularVectors[i].r = 0.0;
    _leftSingularVectors[i].i = 0.0;
  }
  _rightSingularVectors = new doublecomplex[_n * _n];
  for (int i = 0; i < _n * _n; ++i) {
    _rightSingularVectors[i].r = 0.0;
    _rightSingularVectors[i].i = 0.0;
  }
  long int info = 0;
  doublecomplex complexWorkAreaSize;
  long int workAreaSize = -1;
  double* workArea2 = new double[5 * minmn];

  // Determine optimal workareasize
  zgesvd_(&rowsOfU, &rowsOfVT, &_m, &_n, a, &lda, _singularValues,
          _leftSingularVectors, &_m, _rightSingularVectors, &_n,
          &complexWorkAreaSize, &workAreaSize, workArea2, &info);

  if (info == 0) {
    if (_verbose) std::cout << "zgesvd_..." << std::endl;
    workAreaSize = (int)complexWorkAreaSize.r;
    doublecomplex* workArea1 = new doublecomplex[workAreaSize];
    zgesvd_(&rowsOfU, &rowsOfVT, &_m, &_n, a, &lda, _singularValues,
            _leftSingularVectors, &_m, _rightSingularVectors, &_n, workArea1,
            &workAreaSize, workArea2, &info);

    delete[] workArea1;
  }
  delete[] workArea2;
  delete[] a;

  if (_verbose) {
    for (int i = 0; i < minmn; ++i) std::cout << _singularValues[i] << ",";
    std::cout << std::endl;
    std::cout << watch.ToString() << std::endl;
  }
}

void SVDMitigater::Compose() {
  if (_verbose) std::cout << "Composing..." << std::endl;
  Stopwatch watch;
  watch.Start();
  const Image2DPtr real =
      Image2D::CreateUnsetImagePtr(_data.ImageWidth(), _data.ImageHeight());
  const Image2DPtr imaginary =
      Image2D::CreateUnsetImagePtr(_data.ImageWidth(), _data.ImageHeight());
  const int minmn = _m < _n ? _m : _n;
  for (int t = 0; t < _n; ++t) {
    for (int f = 0; f < _m; ++f) {
      double a_tf_r = 0.0;
      double a_tf_i = 0.0;
      // A = U S V^T , so:
      // a_tf = \sum_{g=0}^{minmn} U_{gf} S_{gg} V^T_{tg}
      // Note that _rightSingularVectors=V^T, thus is already stored rowwise
      for (int g = 0; g < minmn; ++g) {
        const double u_r = _leftSingularVectors[g * _m + f].r;
        const double u_i = _leftSingularVectors[g * _m + f].i;
        const double s = _singularValues[g];
        const double v_r = _rightSingularVectors[t * _n + g].r;
        const double v_i = _rightSingularVectors[t * _n + g].i;
        a_tf_r += s * (u_r * v_r - u_i * v_i);
        a_tf_i += s * (u_r * v_i + u_i * v_r);
      }
      real->SetValue(t, f, a_tf_r);
      imaginary->SetValue(t, f, a_tf_i);
    }
  }
  if (_background != nullptr) delete _background;
  _background =
      new TimeFrequencyData(aocommon::Polarization::StokesI, real, imaginary);
  if (_verbose) std::cout << watch.ToString() << std::endl;
}

#ifdef HAVE_GTKMM

// Plots the singular values of 'data' on a logarithmic y-axis, one line per
// polarization. Each polarization is decomposed with its own temporary
// SVDMitigater, and min(_m, _n) points (index, singular value) are pushed to
// 'plot'.
void SVDMitigater::CreateSingularValueGraph(const TimeFrequencyData& data,
                                            XYPlot& plot) {
  const size_t nPolarizations = data.PolarizationCount();
  plot.SetTitle("Distribution of singular values");
  plot.YAxis().SetLogarithmic(true);

  for (size_t polIndex = 0; polIndex < nPolarizations; ++polIndex) {
    const TimeFrequencyData singlePolarization(
        data.MakeFromPolarizationIndex(polIndex));

    SVDMitigater decomposer;
    decomposer.Initialize(singlePolarization);
    decomposer.Decompose();

    XYPointSet& series = plot.StartLine(singlePolarization.Description());
    series.SetXDesc("Singular value index");
    series.SetYDesc("Singular value");

    // The decomposition yields as many singular values as the smaller of the
    // two matrix dimensions.
    const size_t valueCount =
        decomposer._m < decomposer._n ? decomposer._m : decomposer._n;
    for (size_t valueIndex = 0; valueIndex < valueCount; ++valueIndex) {
      plot.PushDataPoint(valueIndex, decomposer.SingularValue(valueIndex));
    }
  }
}

#else
// Built without HAVE_GTKMM: plotting is unavailable, so this does nothing.
void SVDMitigater::CreateSingularValueGraph(const TimeFrequencyData& /*data*/,
                                            XYPlot& /*plot*/) {}
#endif

}  // namespace algorithms