1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
# Author: Gael Varoquaux <gael.varoquaux@normalesup.org>
# Copyright (c) 2010, Enthought, Inc.
# License: BSD Style.
# Standard library imports.
import unittest
import numpy as np
# Local imports.
from mayavi.core.null_engine import NullEngine
# Enthought library imports
from mayavi.filters.threshold import Threshold
from mayavi.filters.cut_plane import CutPlane
from mayavi.sources.array_source import ArraySource
class TestThresholdFilter(unittest.TestCase):
def make_src(self, nan=False):
data = np.empty((3, 3, 3))
if nan:
data[0] = np.nan
data.flat[:] = np.arange(data.size)
return ArraySource(scalar_data=data)
def setUp(self):
"""Initial setting up of test fixture, automatically called by
TestCase before any other test method is invoked"""
e = NullEngine()
# Uncomment to see visualization for debugging etc.
#e = Engine()
e.start()
s = e.new_scene()
self.e = e
self.s = s
self.scene = e.current_scene
return
def tearDown(self):
"""For necessary clean up, automatically called by TestCase
after the test methods have been invoked"""
self.e.stop()
return
def test_threshold_filter_nan(self):
src = self.make_src(nan=True)
self.e.add_source(src)
threshold = Threshold()
self.e.add_filter(threshold)
self.assertEqual(
np.nanmin(src.scalar_data),
np.nanmin(
threshold.get_output_dataset().point_data.scalars.to_array()
)
)
self.assertEqual(
np.nanmax(src.scalar_data),
np.nanmax(
threshold.get_output_dataset().point_data.scalars.to_array()
)
)
def test_threshold_filter_threhsold(self):
src = self.make_src()
self.e.add_source(src)
threshold = Threshold()
self.e.add_filter(threshold)
threshold.upper_threshold = 20.
self.assertTrue(
20 >= np.nanmax(
threshold.get_output_dataset().point_data.scalars.to_array()
)
)
return
def test_threshold_filter_data_range_changes(self):
# Regression test for GitHub issue #136.
src = self.make_src()
self.e.add_source(src)
threshold = Threshold()
self.e.add_filter(threshold)
# Move from one data range to another non-overlapping range,
# first downwards, then back up.
src.scalar_data = np.linspace(3.0, 5.0, 27).reshape((3, 3, 3))
self.assertAlmostEqual(threshold.lower_threshold, 3.0)
self.assertAlmostEqual(threshold.upper_threshold, 5.0)
src.scalar_data = np.linspace(-5.0, -3.0, 27).reshape((3, 3, 3))
self.assertAlmostEqual(threshold.lower_threshold, -5.0)
self.assertAlmostEqual(threshold.upper_threshold, -3.0)
src.scalar_data = np.linspace(3.0, 5.0, 27).reshape((3, 3, 3))
self.assertAlmostEqual(threshold.lower_threshold, 3.0)
self.assertAlmostEqual(threshold.upper_threshold, 5.0)
# Narrow and widen.
src.scalar_data = np.linspace(4.2, 4.6, 27).reshape((3, 3, 3))
self.assertAlmostEqual(threshold.lower_threshold, 4.2)
self.assertAlmostEqual(threshold.upper_threshold, 4.6)
src.scalar_data = np.linspace(-20.0, 20.0, 27).reshape((3, 3, 3))
self.assertAlmostEqual(threshold.lower_threshold, -20.0)
self.assertAlmostEqual(threshold.upper_threshold, 20.0)
# Shift to a range overlapping the previous one.
src.scalar_data = np.linspace(-10.0, -30.0, 27).reshape((3, 3, 3))
self.assertAlmostEqual(threshold.lower_threshold, -30.0)
self.assertAlmostEqual(threshold.upper_threshold, -10.0)
src.scalar_data = np.linspace(-20.0, 20.0, 27).reshape((3, 3, 3))
self.assertAlmostEqual(threshold.lower_threshold, -20.0)
self.assertAlmostEqual(threshold.upper_threshold, 20.0)
def test_threshold_with_other_filter_as_input(self):
# Given
x, y, z = np.mgrid[-1:1:10j, -1:1:10j, -1:1:10j]
s = x*x + y*y + z*z
src = ArraySource(scalar_data=s)
self.e.add_source(src)
scp = CutPlane()
self.e.add_filter(scp)
# When
threshold = Threshold()
self.e.add_filter(threshold)
threshold.trait_set(
lower_threshold=0.25, upper_threshold=0.75,
auto_reset_lower=False, auto_reset_upper=False
)
# Then
output = threshold.get_output_dataset()
self.assertTrue(output is not None)
self.assertTrue(output.is_a('vtkUnstructuredGrid'))
output_range = output.point_data.scalars.range
self.assertTrue(output_range[0] >= 0.25)
self.assertTrue(output_range[1] <= 0.75)
if __name__ == '__main__':
    # Executed as a script: discover and run the tests in this module.
    unittest.main(module='__main__')
|