1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
|
#include <catch2/catch_all.hpp>
#include <complex>
#include <iomanip>
#include "sopt/chained_operators.h"
#include "sopt/linear_transform.h"
TEST_CASE("Linear Transforms", "[ops]") {
  using namespace sopt;
  using SCALAR = int;
  using t_Vector = Vector<SCALAR>;
  auto constexpr N = 5;

  // Random integer test vector with small coefficients. For integer scalars Eigen's Random()
  // draws from the whole range of the type, so scaling it (e.g. `Random(n) * 5`) and then
  // applying the affine steps below overflows `int`, which is undefined behaviour. Reducing
  // each coefficient modulo 5 keeps them in [-4, 4], so even the longest chain stays tiny.
  auto const small_random = [](int size) -> t_Vector {
    return t_Vector::Random(size).unaryExpr([](SCALAR value) { return value % 5; });
  };

  // Affine maps shared by the composition sections. Their different slopes make the order in
  // which a chain applies them observable in the result.
  OperatorFunction<t_Vector> const func0 = [](t_Vector &out, t_Vector const &input) {
    out = input.array() * 2 - 1;
  };
  OperatorFunction<t_Vector> const func1 = [](t_Vector &out, t_Vector const &input) {
    out = input.array() * 4 - 1;
  };

  // In every section, `expected` is the reference result computed step by step and `actual`
  // is the result produced by the object under test.
  SECTION("1 Functions") {
    t_Vector const x = small_random(2 * N);
    auto chain = chained_operators(func0);
    t_Vector expected;
    func0(expected, x);
    t_Vector actual;
    chain(actual, x);
    CHECK(actual == expected);
  }
  SECTION("2 Functions") {
    // The chain is expected to apply its arguments right to left: func1 first, then func0.
    t_Vector const x = small_random(2 * N);
    auto chain = chained_operators(func0, func1);
    t_Vector step;
    t_Vector expected;
    func1(step, x);
    func0(expected, step);
    t_Vector actual;
    chain(actual, x);
    CHECK(actual == expected);
  }
  SECTION("3 Functions") {
    t_Vector const x = small_random(2 * N);
    auto chain = chained_operators(func0, func1, func0);
    t_Vector step0;
    t_Vector step1;
    t_Vector expected;
    func0(step0, x);
    func1(step1, step0);
    func0(expected, step1);
    t_Vector actual;
    chain(actual, x);
    CHECK(actual == expected);
  }
  SECTION("4 Functions") {
    // Right-to-left application: func0, func0, func1, then func0.
    t_Vector const x = small_random(2 * N);
    auto chain = chained_operators(func0, func1, func0, func0);
    t_Vector step0;
    t_Vector step1;
    t_Vector step2;
    t_Vector expected;
    func0(step0, x);
    func0(step1, step0);
    func1(step2, step1);
    func0(expected, step2);
    t_Vector actual;
    chain(actual, x);
    CHECK(actual == expected);
  }
  SECTION("linear transform") {
    // Forward operators: keep the first N - 1 coefficients and scale them.
    OperatorFunction<t_Vector> const head0 = [](t_Vector &out, t_Vector const &input) {
      out = input.head(N - 1).array() * 2;
    };
    OperatorFunction<t_Vector> const head1 = [](t_Vector &out, t_Vector const &input) {
      out = input.head(N - 1).array() * 4;
    };
    // Adjoint stand-ins: scale the first N - 1 coefficients and zero-pad to length N.
    // NOTE(review): the forward chain reads a length-2N input, so these are not its exact
    // adjoint; the test only needs op.adjoint() to dispatch to chain_adjoint.
    OperatorFunction<t_Vector> const padded0 = [](t_Vector &out, t_Vector const &input) {
      out = t_Vector::Zero(N);
      out.head(N - 1) = input.head(N - 1).array() * 2;
    };
    OperatorFunction<t_Vector> const padded1 = [](t_Vector &out, t_Vector const &input) {
      out = t_Vector::Zero(N);
      out.head(N - 1) = input.head(N - 1).array() * 4;
    };
    t_Vector const x = small_random(2 * N);
    auto chain = chained_operators(head0, head1, head0, head0);
    // The adjoint of a composition lists the factors' adjoints in reverse order.
    auto chain_adjoint = chained_operators(padded0, padded0, padded1, padded0);
    auto op = LinearTransform<t_Vector>{chain, chain_adjoint};

    // Forward: the scale factors multiply to 2 * 2 * 4 * 2 = 32.
    t_Vector expected;
    chain(expected, x);
    t_Vector actual;
    actual = op * x;
    CHECK(actual == x.head(N - 1) * 32);
    CHECK(actual == expected);

    // Adjoint: the same factor of 32 on the first N - 1 entries, trailing entry zero.
    t_Vector const y = x.head(N - 1);
    t_Vector expected_adjoint;
    chain_adjoint(expected_adjoint, y);
    t_Vector actual_adjoint;
    actual_adjoint = op.adjoint() * y;
    CHECK(actual_adjoint == expected_adjoint);
    t_Vector scaled = t_Vector::Zero(N);
    scaled.head(N - 1) = y * 32;
    CHECK(actual_adjoint == scaled);
  }
}
|