File: test_zspy.py

package info (click to toggle)
python-rosettasciio 0.7.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 144,644 kB
  • sloc: python: 36,638; xml: 2,582; makefile: 20; ansic: 4
file content (124 lines) | stat: -rw-r--r-- 4,420 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# -*- coding: utf-8 -*-
# Copyright 2007-2023 The HyperSpy developers
#
# This file is part of RosettaSciIO.
#
# RosettaSciIO is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RosettaSciIO is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with RosettaSciIO. If not, see <https://www.gnu.org/licenses/#GPL>.

import logging
import os

import numpy as np
import pytest

# Skip every test in this module when hyperspy cannot be imported
hs = pytest.importorskip("hyperspy.api", reason="hyperspy not installed")
# zarr (because of numcodecs) is only supported on x86_64 machines, so the
# module is skipped as well when zarr cannot be imported
zarr = pytest.importorskip("zarr", reason="zarr not installed")


class TestZspy:
    """Save/load round-trip tests for signals written in the ``.zspy`` format."""

    @pytest.fixture
    def signal(self):
        """Return a ``Signal1D`` wrapping a 10x10x10x10 array of ones."""
        data = np.ones((10, 10, 10, 10))
        return hs.signals.Signal1D(data)

    @pytest.mark.parametrize("store_class", [zarr.N5Store, zarr.ZipStore])
    def test_save_store(self, signal, tmp_path, store_class):
        """Save into a zarr store instance and load identical data back."""
        filename = tmp_path / "testmodels.zspy"
        store = store_class(path=filename)
        signal.save(store)

        # The ZipStore case must leave a single file on disk, the other
        # parametrized store a directory.
        if store_class is zarr.ZipStore:
            assert os.path.isfile(filename)
        else:
            assert os.path.isdir(filename)

        # Re-open through a fresh store object rather than reusing ``store``
        store2 = store_class(path=filename)
        signal2 = hs.load(store2)

        np.testing.assert_array_equal(signal2.data, signal.data)

    def test_save_ZipStore_close_file(self, signal, tmp_path):
        """Save into a ``ZipStore`` with ``close_file=False`` and read it back."""
        filename = tmp_path / "testmodels.zspy"
        store = zarr.ZipStore(path=filename)
        signal.save(store, close_file=False)

        assert os.path.isfile(filename)

        store2 = zarr.ZipStore(path=filename)
        s2 = hs.load(store2)

        np.testing.assert_array_equal(s2.data, signal.data)

    def test_save_wrong_store(self, signal, tmp_path, caplog):
        """Data written via ``N5Store`` must fail to load via another store."""
        filename = tmp_path / "testmodels.zspy"
        store = zarr.N5Store(path=filename)
        signal.save(store)

        # Sanity check: the matching store class reads the data back
        store2 = zarr.N5Store(path=filename)
        s2 = hs.load(store2)
        np.testing.assert_array_equal(s2.data, signal.data)

        # Same path, different store class: loading is expected to raise.
        # NOTE(review): the exception type is deliberately left broad here,
        # as in the original test -- confirm before narrowing it.
        store2 = zarr.NestedDirectoryStore(path=filename)
        with pytest.raises(Exception):
            with caplog.at_level(logging.ERROR):
                _ = hs.load(store2)

    @pytest.mark.parametrize("overwrite", [None, True, False])
    def test_overwrite(self, signal, overwrite, tmp_path):
        """Check which data is on disk after a second save with ``overwrite``."""
        filename = tmp_path / "testmodels.zspy"
        signal.save(filename=filename)
        signal2 = signal * 2
        signal2.save(filename=filename, overwrite=overwrite)

        # Only ``overwrite=True`` is expected to replace the stored data;
        # both ``None`` and ``False`` must leave the first save in place.
        expected = signal2 if overwrite else signal
        np.testing.assert_array_equal(expected.data, hs.load(filename).data)

    def test_compression_opts(self, tmp_path):
        """A compressor passed to ``save`` is the one set on the data array."""
        from numcodecs import Blosc

        # Keep the path local: nothing outside this test reads it
        filename = tmp_path / "testfile.zspy"

        comp = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE)
        hs.signals.BaseSignal([1, 2, 3]).save(filename, compressor=comp)
        f = zarr.open(str(filename), mode="r+")
        d = f["Experiments/__unnamed__/data"]
        assert d.compressor == comp

    @pytest.mark.parametrize("compressor", (None, "default", "blosc"))
    def test_compression(self, compressor, tmp_path):
        """Saving and loading runs without error for each compressor option."""
        if compressor == "blosc":
            from numcodecs import Blosc

            # Swap the "blosc" placeholder for an actual codec instance
            compressor = Blosc(cname="zstd", clevel=3, shuffle=Blosc.BITSHUFFLE)
        s = hs.signals.Signal1D(np.ones((3, 3)))
        s.save(
            tmp_path / "test_compression.zspy", overwrite=True, compressor=compressor
        )
        _ = hs.load(tmp_path / "test_compression.zspy")


def test_non_valid_zspy(tmp_path, caplog):
    """A bare zarr group holding one dataset makes ``hs.load`` raise ``IOError``."""
    filename = tmp_path / "testfile.zspy"

    # Write a plain zarr group containing a single 1D dataset
    root = zarr.group(filename)
    root.create_dataset("dataset", data=np.arange(10))

    with pytest.raises(IOError), caplog.at_level(logging.ERROR):
        _ = hs.load(filename)