1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
|
/* Copyright 2021. Martin Uecker.
* All rights reserved. Use of this source code is governed by
* a BSD-style license which can be found in the LICENSE file.
*
* Authors:
* 2020-2021 Martin Uecker <martin.uecker@med.uni-goettingen.de>
*/
#include <complex.h>
#include "num/multind.h"
#include "num/flpmath.h"
#include "num/ops_p.h"
#include "num/ops.h"
#include "misc/misc.h"
#include "misc/debug.h"
#include "misc/types.h"
#include "nlops/nlop.h"
#include "nlops/tenmul.h"
#include "nlops/chain.h"
#include "iter/prox2.h"
#include "iter/thresh.h"
#include "utest.h"
/*
 * Checks the proximal operator produced by prox_nlgrad_create for the
 * scalar nonlinear function f(x) = x * x. The function is built from an
 * elementwise tensor product whose two inputs are tied together by nlop_dup.
 *
 * The operator is applied (first scalar argument 1.) to x0 = 1, and the
 * expected result is
 *   argmin_x 0.5 (x - 1)^2 + x^2 = 1.5 x^2 - x + 0.5   =>   x = 1/3.
 * NOTE(review): the remaining arguments (30, 0.1, 1.) are presumably
 * iteration count, step size and weight -- confirm against iter/prox2.h.
 */
static bool test_nlgrad(void)
{
	enum { N = 1 };
	long dims[N] = { 1 };

	complex float* in = md_alloc(N, dims, CFL_SIZE);
	complex float* out = md_alloc(N, dims, CFL_SIZE);

	// square(x) = x * x: both tensor-product inputs are bound to one argument
	auto prod = nlop_tenmul_create(N, dims, dims, dims);
	auto square = nlop_dup(prod, 0, 1);
	nlop_free(prod);

	auto prox = prox_nlgrad_create(square, 30, 0.1, 1.);
	nlop_free(square);

	md_zfill(N, dims, in, 1.);
	operator_p_apply(prox, 1., N, dims, out, N, dims, in);
	operator_p_free(prox);

	// the input is not needed anymore: reuse its storage for the reference 1/3
	complex float* expected = in;
	md_zfill(N, dims, expected, 1. / 3.);

	float err = md_znrmse(N, dims, out, expected);

	md_free(in);
	md_free(out);

	UT_ASSERT(err < 1.E-4);
}

UT_REGISTER_TEST(test_nlgrad);
/*
 * Checks op_p_auto_normalize wrapped around the proximal operator from
 * prox_thresh_create (threshold 0.5, flags 0u). Normalization is over
 * dimension 1 (MD_BIT(1)) using NORM_L2.
 *
 * A constant input of 3 on a 2x4x3 array, applied with scalar 0.5, is
 * expected to come back as 3 * 0.5 = 1.5 everywhere.
 * NOTE(review): presumably each length-4 fiber along dim 1 has L2 norm 6.
 * Under that reading, the normalized value 0.5 is soft-thresholded by
 * 0.5 * 0.5 down to 0.25 and rescaled by 6 -- confirm against num/ops_p.h.
 */
static bool test_auto_norm(void)
{
	enum { N = 3 };
	long dims[N] = { 2, 4, 3 };

	complex float* in = md_alloc(N, dims, CFL_SIZE);
	complex float* out = md_alloc(N, dims, CFL_SIZE);

	md_zfill(N, dims, in, 3.);

	auto thresh = prox_thresh_create(N, dims, 0.5, 0u);
	auto normalized = op_p_auto_normalize(thresh, MD_BIT(1), NORM_L2);
	operator_p_free(thresh);

	operator_p_apply(normalized, 0.5, N, dims, out, N, dims, in);
	operator_p_free(normalized);

	// the input is not needed anymore: reuse its storage for the reference
	complex float* expected = in;
	md_zfill(N, dims, expected, 3. * 0.5);

	float err = md_znrmse(N, dims, out, expected);

	md_free(in);
	md_free(out);

	// the accepted tolerance is chosen per compiler
#if defined(__clang__)
	UT_ASSERT(err < 1.E-6);
#elif __GNUC__ >= 10
	UT_ASSERT(err < 1.E-7);
#else
	UT_ASSERT(err < 1.E-10);
#endif
}

UT_REGISTER_TEST(test_auto_norm);
|