File: test_optomo.py

package info (click to toggle)
astra-toolbox 2.3.0-4
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid
  • size: 4,972 kB
  • sloc: cpp: 24,378; python: 5,048; sh: 3,514; ansic: 1,181; makefile: 518
file content (118 lines) | stat: -rw-r--r-- 4,048 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import astra
import pytest
import numpy as np
import scipy

# Detector pixel spacing, passed to ASTRA's 3D geometry as det_spacing_x / det_spacing_y.
DET_SPACING_X = DET_SPACING_Y = 1.0
# Detector size; the row count is only used by the 3D geometry and arrays.
DET_ROW_COUNT, DET_COL_COUNT = 20, 45
# N_ANGLES equally spaced projection angles over a full turn (2*pi itself excluded).
N_ANGLES = 180
ANGLES = np.linspace(0, 2 * np.pi, N_ANGLES, endpoint=False)
# Volume size; N_SLICES is only used by the 3D geometry and arrays.
N_ROWS, N_COLS, N_SLICES = 40, 30, 50


@pytest.fixture
def op_tomo(dimensionality):
    """Yield an ``astra.OpTomo`` built on a CUDA parallel-beam projector.

    The projector is created for ``dimensionality`` ('2d' or '3d') and is
    deleted again after the test that requested the fixture has finished.

    Raises:
        ValueError: if ``dimensionality`` is neither '2d' nor '3d'.
    """
    if dimensionality == '2d':
        vol_geom = astra.create_vol_geom(N_ROWS, N_COLS)
        # The 2D detector is a line of DET_COL_COUNT columns, so use the same
        # column spacing that the 3D geometry below passes as det_spacing_x.
        proj_geom = astra.create_proj_geom('parallel', DET_SPACING_X, DET_COL_COUNT, ANGLES)
        projector_id = astra.create_projector('cuda', proj_geom, vol_geom)
        delete_projector = astra.projector.delete
    elif dimensionality == '3d':
        vol_geom = astra.create_vol_geom(N_ROWS, N_COLS, N_SLICES)
        proj_geom = astra.create_proj_geom('parallel3d', DET_SPACING_X, DET_SPACING_Y,
                                           DET_ROW_COUNT, DET_COL_COUNT, ANGLES)
        projector_id = astra.create_projector('cuda3d', proj_geom, vol_geom)
        # 'cuda3d' projectors are registered with astra.projector3d, so they
        # must be released through it; astra.projector.delete would leak them.
        delete_projector = astra.projector3d.delete
    else:
        raise ValueError('unsupported dimensionality: {!r}'.format(dimensionality))
    yield astra.OpTomo(projector_id)
    delete_projector(projector_id)


def _vol_shape(dimensionality):
    """Return the volume array shape used by the tests for ``dimensionality``.

    Raises:
        ValueError: if ``dimensionality`` is neither '2d' nor '3d'.
    """
    if dimensionality == '2d':
        return (N_ROWS, N_COLS)
    if dimensionality == '3d':
        # 3D volumes are ordered (slice, row, column).
        return (N_SLICES, N_ROWS, N_COLS)
    raise ValueError('unsupported dimensionality: {!r}'.format(dimensionality))


def _proj_shape(dimensionality):
    """Return the projection (sinogram) array shape for ``dimensionality``.

    Raises:
        ValueError: if ``dimensionality`` is neither '2d' nor '3d'.
    """
    if dimensionality == '2d':
        return (N_ANGLES, DET_COL_COUNT)
    if dimensionality == '3d':
        # 3D projection data is ordered (detector row, angle, detector column).
        return (DET_ROW_COUNT, N_ANGLES, DET_COL_COUNT)
    raise ValueError('unsupported dimensionality: {!r}'.format(dimensionality))


@pytest.fixture
def vol_data(dimensionality):
    """All-ones float32 volume to be forward projected."""
    return np.ones(_vol_shape(dimensionality), dtype=np.float32)


@pytest.fixture
def vol_buffer(dimensionality):
    """Zero-filled float32 volume used as an ``out=`` buffer for backprojection."""
    return np.zeros(_vol_shape(dimensionality), dtype=np.float32)


@pytest.fixture
def proj_data(dimensionality):
    """All-ones float32 projection data to be backprojected or reconstructed."""
    return np.ones(_proj_shape(dimensionality), dtype=np.float32)


@pytest.fixture
def proj_buffer(dimensionality):
    """Zero-filled float32 projection data used as an ``out=`` buffer for FP."""
    return np.zeros(_proj_shape(dimensionality), dtype=np.float32)


@pytest.fixture
def algorithm(dimensionality):
    """Return the name of the CUDA SIRT algorithm matching ``dimensionality``.

    Raises:
        ValueError: if ``dimensionality`` is neither '2d' nor '3d'.
    """
    if dimensionality == '2d':
        return 'SIRT_CUDA'
    if dimensionality == '3d':
        return 'SIRT3D_CUDA'
    raise ValueError('unsupported dimensionality: {!r}'.format(dimensionality))


@pytest.mark.parametrize('dimensionality', ['2d', '3d'])
class TestAll:
    """Smoke tests for ``astra.OpTomo``, run once for '2d' and once for '3d'.

    Each test checks only that the operation produced a result that is not
    identically zero.
    """

    def test_fp(self, dimensionality, op_tomo, vol_data):
        """Forward-project a volume with ``FP``."""
        fp = op_tomo.FP(vol_data)
        assert not np.allclose(fp, 0.0)

    def test_fp_flattened(self, dimensionality, op_tomo, vol_data):
        """Forward-project a volume passed as a flat 1D array."""
        fp = op_tomo.FP(vol_data.flatten())
        assert not np.allclose(fp, 0.0)

    def test_fp_out_arg(self, dimensionality, op_tomo, vol_data, proj_buffer):
        """Forward-project into a caller-supplied ``out`` buffer."""
        op_tomo.FP(vol_data, out=proj_buffer)
        assert not np.allclose(proj_buffer, 0.0)

    def test_bp(self, dimensionality, op_tomo, proj_data):
        """Backproject projection data with ``BP``."""
        bp = op_tomo.BP(proj_data)
        assert not np.allclose(bp, 0.0)

    def test_bp_flattened(self, dimensionality, op_tomo, proj_data):
        """Backproject projection data passed as a flat 1D array."""
        bp = op_tomo.BP(proj_data.flatten())
        assert not np.allclose(bp, 0.0)

    def test_bp_out_arg(self, dimensionality, op_tomo, proj_data, vol_buffer):
        """Backproject into a caller-supplied ``out`` buffer."""
        op_tomo.BP(proj_data, out=vol_buffer)
        assert not np.allclose(vol_buffer, 0.0)

    def test_matvec(self, dimensionality, op_tomo, vol_data):
        """Forward-project by calling the operator directly."""
        fp = op_tomo(vol_data)
        assert not np.allclose(fp, 0.0)

    def test_rmatvec(self, dimensionality, op_tomo, proj_data):
        """Backproject by calling the transposed operator ``op_tomo.T``."""
        bp = op_tomo.T(proj_data)
        assert not np.allclose(bp, 0.0)

    def test_mul(self, dimensionality, op_tomo, vol_data):
        """Forward-project with the ``*`` operator."""
        fp = op_tomo * vol_data
        assert not np.allclose(fp, 0.0)

    def test_reconstruct(self, dimensionality, op_tomo, proj_data, algorithm):
        """Run two iterations of the SIRT algorithm through ``reconstruct``."""
        rec = op_tomo.reconstruct(algorithm, proj_data, iterations=2,
                                  extraOptions={'MinConstraint': 0.0})
        assert not np.allclose(rec, 0.0)

    def test_scipy_solver(self, dimensionality, op_tomo, proj_data):
        """Run two LSQR iterations from SciPy with ``op_tomo`` as the operator."""
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.sparse.linalg`` has been loaded.
        from scipy.sparse.linalg import lsqr

        result = lsqr(op_tomo, proj_data.flatten(), iter_lim=2)
        rec = result[0]  # first element of lsqr's result tuple is the solution x
        assert not np.allclose(rec, 0.0)