File: test_masks.py

package info (click to toggle)
python-nexusformat 1.0.6-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 516 kB
  • sloc: python: 5,791; makefile: 5; sh: 1
file content (49 lines) | stat: -rw-r--r-- 1,671 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os

import numpy as np
import pytest
from nexusformat.nexus.tree import NXfield, NXgroup, NXroot, nxload


def test_field_masks(arr1D):
    """Check masking and unmasking of elements in a standalone NXfield.

    Elements 10..19 of the field are masked, the mask is inspected both
    through sliced fields and through the field's ``mask`` attribute, and
    then element 10 is unmasked again.

    Parameters
    ----------
    arr1D
        Array used to build the field; presumably a pytest fixture defined
        outside this file (must hold at least 20 elements for the slices
        used here).
    """
    window = slice(8, 12)
    masked_field = NXfield(arr1D)
    masked_field[10:20] = np.ma.masked

    # Within the 8..11 window only indices 10 and 11 were masked above.
    expected = np.array([False, False, True, True])
    assert isinstance(masked_field.nxvalue, np.ma.MaskedArray)
    assert np.all(masked_field[window].mask == expected)
    assert np.all(masked_field.mask[window] == expected)
    assert np.ma.is_masked(masked_field[window].nxvalue)
    assert np.ma.is_masked(masked_field.nxvalue[10])
    assert np.ma.is_masked(masked_field[10].nxvalue)
    assert masked_field[10].mask

    # Clear the mask on a single element through the mask attribute.
    masked_field.mask[10] = np.ma.nomask

    expected_after_unmask = np.array([False, False, False, True])
    assert np.all(masked_field.mask[window] == expected_after_unmask)
    assert not masked_field[10].mask


# The parameter values must be real booleans: non-empty strings such as
# "False" are truthy, which would send every case down the save branch.
@pytest.mark.parametrize("save", [False, True])
def test_group_masks(tmpdir, arr1D, save):
    """Check that masking a field inside a group creates a linked mask field.

    The test masks elements 10..19 of ``group['field']`` and verifies that a
    boolean ``field_mask`` entry exists in the group and is referenced by the
    field's ``mask`` attribute. It then unmasks element 10 and verifies that
    both the field's mask and the ``field_mask`` entry reflect the change.

    Parameters
    ----------
    tmpdir
        Directory in which the NeXus file is written when ``save`` is true.
    arr1D
        Array used to build the field; presumably a pytest fixture defined
        outside this file (must hold at least 20 elements).
    save : bool
        When true, the group is wrapped in an ``NXroot``, saved to disk,
        reloaded in ``"rw"`` mode, and the assertions run against the
        reloaded group. When false, they run against the in-memory group.
    """
    group = NXgroup(NXfield(arr1D, name='field'))
    group['field'][10:20] = np.ma.masked

    if save:
        # Round-trip through a file so the checks exercise the stored mask.
        root = NXroot(group)
        filename = os.path.join(tmpdir, "file1.nxs")
        root.save(filename, mode="w")
        root = nxload(filename, "rw")
        group = root['group']

    assert isinstance(group['field'].nxvalue, np.ma.masked_array)
    assert np.all(group['field'].mask[9:11] == np.array([False, True]))
    assert 'mask' in group['field'].attrs
    assert group['field'].attrs['mask'] == 'field_mask'
    assert 'field_mask' in group
    assert group['field_mask'].dtype == bool
    assert group['field'].mask == group['field_mask']

    # Unmask one element and confirm both views of the mask agree.
    group['field'].mask[10] = np.ma.nomask

    assert np.all(group['field'].mask[10:12] == np.array([False, True]))
    assert np.all(group['field_mask'][10:12] == np.array([False, True]))