1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
|
#include "../../util/stopwatch.h"
#include "svdmitigater.h"
#ifdef HAVE_GTKMM
#include "../../plot/plot2d.h"
#endif

#include <iostream>
#include <vector>
// Prototype of LAPACK's Fortran routine ZGESVD. It computes the singular value
// decomposition A = U * SIGMA * V^H of a complex double-precision M x N matrix.
// Every argument is passed by pointer, and the trailing underscore follows the
// usual Fortran symbol naming.
// The typedefs integer/doublereal/doublecomplex are not defined in this file.
// NOTE(review): presumably they come from an f2c/CLAPACK header reached via
// svdmitigater.h; confirm.
extern "C" {
int zgesvd_(char *jobu, char *jobvt, integer *m, integer *n, doublecomplex *a,
integer *lda, doublereal *s, doublecomplex *u, integer *ldu,
doublecomplex *vt, integer *ldvt, doublecomplex *work,
integer *lwork, doublereal *rwork, integer *info);
}
SVDMitigater::SVDMitigater()
: _background(0),
_singularValues(0),
_leftSingularVectors(0),
_rightSingularVectors(0),
_m(0),
_n(0),
_iteration(0),
_removeCount(10),
_verbose(false) {}
// Destroys the mitigater. All owned buffers are released through Clear().
SVDMitigater::~SVDMitigater() {
  Clear();
}
/**
 * Releases the decomposition buffers (when IsDecomposed() reports one) and the
 * composed background.
 *
 * Every owning pointer is reset to null after deletion. As a result, a
 * subsequent Clear(), Compose() or destructor call never frees the same
 * memory twice.
 */
void SVDMitigater::Clear() {
  if (IsDecomposed()) {
    delete[] _singularValues;
    delete[] _leftSingularVectors;
    delete[] _rightSingularVectors;
    _singularValues = nullptr;
    _leftSingularVectors = nullptr;
    _rightSingularVectors = nullptr;
  }
  // Decompose() calls Clear(), and afterwards Compose() and the destructor
  // delete _background again. It must therefore not be left dangling here.
  delete _background;  // deleting a null pointer is a no-op
  _background = nullptr;
}
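// In the LAPACK prototypes, "lda" is the leading dimension of matrix "a" (the distance between consecutive columns in column-major storage).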
/*int cgebrd_(integer *m, integer *n, complex *a, integer *lda,
real *d__, real *e, complex *tauq, complex *taup, complex *work,
integer *lwork, integer *info);*/
void SVDMitigater::Decompose() {
if (_verbose) std::cout << "Decomposing..." << std::endl;
Stopwatch watch;
watch.Start();
Clear();
// Remember that the axes have to be turned; in 'a', time is along the
// vertical axis.
_m = _data.ImageHeight();
_n = _data.ImageWidth();
int minmn = _m < _n ? _m : _n;
char rowsOfU = 'A'; // all rows of u
char rowsOfVT = 'A'; // all rows of VT
doublecomplex *a = new doublecomplex[_m * _n];
Image2DCPtr real = _data.GetRealPart(), imaginary = _data.GetImaginaryPart();
for (int t = 0; t < _n; ++t) {
for (int f = 0; f < _m; ++f) {
a[t * _m + f].r = real->Value(t, f);
a[t * _m + f].i = imaginary->Value(t, f);
}
}
long int lda = _m;
_singularValues = new double[minmn];
for (int i = 0; i < minmn; ++i) _singularValues[i] = 0.0;
_leftSingularVectors = new doublecomplex[_m * _m];
for (int i = 0; i < _m * _m; ++i) {
_leftSingularVectors[i].r = 0.0;
_leftSingularVectors[i].i = 0.0;
}
_rightSingularVectors = new doublecomplex[_n * _n];
for (int i = 0; i < _n * _n; ++i) {
_rightSingularVectors[i].r = 0.0;
_rightSingularVectors[i].i = 0.0;
}
long int info = 0;
doublecomplex complexWorkAreaSize;
long int workAreaSize = -1;
double *workArea2 = new double[5 * minmn];
// Determine optimal workareasize
zgesvd_(&rowsOfU, &rowsOfVT, &_m, &_n, a, &lda, _singularValues,
_leftSingularVectors, &_m, _rightSingularVectors, &_n,
&complexWorkAreaSize, &workAreaSize, workArea2, &info);
if (info == 0) {
if (_verbose) std::cout << "zgesvd_..." << std::endl;
workAreaSize = (int)complexWorkAreaSize.r;
doublecomplex *workArea1 = new doublecomplex[workAreaSize];
zgesvd_(&rowsOfU, &rowsOfVT, &_m, &_n, a, &lda, _singularValues,
_leftSingularVectors, &_m, _rightSingularVectors, &_n, workArea1,
&workAreaSize, workArea2, &info);
delete[] workArea1;
}
delete[] workArea2;
delete[] a;
if (_verbose) {
for (int i = 0; i < minmn; ++i) std::cout << _singularValues[i] << ",";
std::cout << std::endl;
std::cout << watch.ToString() << std::endl;
}
}
void SVDMitigater::Compose() {
if (_verbose) std::cout << "Composing..." << std::endl;
Stopwatch watch;
watch.Start();
Image2DPtr real =
Image2D::CreateUnsetImagePtr(_data.ImageWidth(), _data.ImageHeight());
Image2DPtr imaginary =
Image2D::CreateUnsetImagePtr(_data.ImageWidth(), _data.ImageHeight());
int minmn = _m < _n ? _m : _n;
for (int t = 0; t < _n; ++t) {
for (int f = 0; f < _m; ++f) {
double a_tf_r = 0.0;
double a_tf_i = 0.0;
// A = U S V^T , so:
// a_tf = \sum_{g=0}^{minmn} U_{gf} S_{gg} V^T_{tg}
// Note that _rightSingularVectors=V^T, thus is already stored rowwise
for (int g = 0; g < minmn; ++g) {
double u_r = _leftSingularVectors[g * _m + f].r;
double u_i = _leftSingularVectors[g * _m + f].i;
double s = _singularValues[g];
double v_r = _rightSingularVectors[t * _n + g].r;
double v_i = _rightSingularVectors[t * _n + g].i;
a_tf_r += s * (u_r * v_r - u_i * v_i);
a_tf_i += s * (u_r * v_i + u_i * v_r);
}
real->SetValue(t, f, a_tf_r);
imaginary->SetValue(t, f, a_tf_i);
}
}
if (_background != 0) delete _background;
_background =
new TimeFrequencyData(aocommon::Polarization::StokesI, real, imaginary);
if (_verbose) std::cout << watch.ToString() << std::endl;
}
#ifdef HAVE_GTKMM
/**
 * Plots the singular values of each polarization in 'data'.
 *
 * Each polarization is decomposed with a temporary SVDMitigater. One line per
 * polarization is added to 'plot', with points (index, singular value) for
 * index 0 .. min(m, n) - 1. The y-axis is logarithmic.
 */
void SVDMitigater::CreateSingularValueGraph(const TimeFrequencyData &data,
                                            Plot2D &plot) {
  const size_t polarizationCount = data.PolarizationCount();
  plot.SetTitle("Distribution of singular values");
  plot.SetLogarithmicYAxis(true);
  for (size_t polIndex = 0; polIndex < polarizationCount; ++polIndex) {
    TimeFrequencyData polarizationData(
        data.MakeFromPolarizationIndex(polIndex));
    SVDMitigater svd;
    svd.Initialize(polarizationData);
    svd.Decompose();
    const size_t minmn = svd._m < svd._n ? svd._m : svd._n;
    Plot2DPointSet &pointSet = plot.StartLine(polarizationData.Description());
    pointSet.SetXDesc("Singular value index");
    pointSet.SetYDesc("Singular value");
    // Distinct index name so the polarization loop variable is not shadowed.
    for (size_t valueIndex = 0; valueIndex < minmn; ++valueIndex)
      plot.PushDataPoint(valueIndex, svd.SingularValue(valueIndex));
  }
}
#else
// Without GTKMM, plotting support is compiled out and this is a no-op.
void SVDMitigater::CreateSingularValueGraph(const TimeFrequencyData &,
                                            Plot2D &) {}
#endif
|