File: image.py

package info (click to toggle)
mpl-animators 1.2.4-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 360 kB
  • sloc: python: 1,386; makefile: 18
file content (119 lines) | stat: -rw-r--r-- 5,056 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import matplotlib as mpl

from .base import ArrayAnimator

# Names exported by ``from <this module> import *``.
__all__ = ['ImageAnimator']


class ImageAnimator(ArrayAnimator):
    """
    Create a matplotlib backend independent data explorer for 2D images.

    The following keyboard shortcuts are defined in the viewer:

    * 'left': previous step on active slider.
    * 'right': next step on active slider.
    * 'up': change the active slider up one.
    * 'down': change the active slider down one.
    * 'p': play/pause active slider.

    This viewer can have user defined buttons added by specifying the labels
    and functions called when those buttons are clicked as keyword arguments.

    Parameters
    ----------
    data: `numpy.ndarray`
        The data to be visualized.
    image_axes: `list`, optional
        The two axes of ``data`` that make up the image, in numpy ``[y, x]``
        order: the first is drawn vertically, the second horizontally.
        Negative indices are allowed. Defaults to ``[-2, -1]``.
    axis_ranges: `list` of physical coordinates for the `numpy.ndarray`, optional
        Defaults to `None` and array indices will be used for all axes.
        The `list` should contain one element for each axis of the `numpy.ndarray`.
        For the image axes a ``[min, max]`` pair should be specified which will be
        passed to `matplotlib.pyplot.imshow` as an extent. If an image axis is
        instead given more than two values (one coordinate per pixel), the image
        is drawn with a `matplotlib.image.NonUniformImage`; see the parent class
        for the exact accepted format.
        For the slider axes a ``[min, max]`` pair can be specified or an array the
        same length as the axis which will provide all values for that slider.

    Notes
    -----
    Extra keywords are passed to the parent class, `~mpl_animators.ArrayAnimator`.
    """

    def __init__(self, data, image_axes=None, axis_ranges=None, **kwargs):
        # A ``None`` sentinel replaces the old ``[-2, -1]`` list default so one
        # list object is not shared by every call; the effective default is unchanged.
        if image_axes is None:
            image_axes = [-2, -1]
        # Check that number of axes is 2.
        if len(image_axes) != 2:
            raise ValueError("There can only be two spatial axes")
        # Define number of slider axes.
        self.naxis = data.ndim
        self.num_sliders = self.naxis - 2
        # True when an image axis is described by per-pixel coordinates rather than
        # a [min, max] pair. This selects NonUniformImage vs. imshow and hence how
        # the plot is drawn and updated. It is set before the parent __init__ runs,
        # presumably because the parent draws the initial image (plot_start_image).
        self._non_regular_plot_axis = False
        # Run init for parent class
        super().__init__(data, image_axes=image_axes, axis_ranges=axis_ranges, **kwargs)

    def _get_frame_data(self):
        """
        Return the current frame as a 2D array with rows along y and columns along x.

        Indexing ``data`` with ``frame_index`` (an integer on every slider axis,
        see `update_plot`) leaves the two image axes in increasing array order.
        Rows must follow ``image_axes[0]`` (y) and columns ``image_axes[1]`` (x),
        so the slice is transposed when the image axes were listed in
        decreasing order.
        """
        data = self.data[self.frame_index]
        if self.image_axes[0] > self.image_axes[1]:
            data = data.transpose()
        return data

    def _set_non_uniform_data(self, im):
        """
        Load the image-axis coordinates and the current frame into ``im``.

        `matplotlib.image.NonUniformImage.set_data` takes ``x`` and ``y`` pixel
        coordinates plus an array of shape ``(len(y), len(x))``. x therefore comes
        from ``image_axes[1]`` and y from ``image_axes[0]``. This matches the frame
        returned by `_get_frame_data` and the x-then-y extent used for ``imshow``.
        """
        im.set_data(self.axis_ranges[self.image_axes[1]],
                    self.axis_ranges[self.image_axes[0]],
                    self._get_frame_data())

    def plot_start_image(self, ax):
        """
        Sets up plot of initial image.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            The axes to draw the image into.

        Returns
        -------
        `matplotlib.image.AxesImage` or `matplotlib.image.NonUniformImage`
            The image artist that was added to ``ax``.
        """
        # An image axis described by more than a [min, max] pair carries one
        # coordinate per pixel, which imshow's linear extent cannot express.
        if any(len(self.axis_ranges[i]) > 2 for i in self.image_axes):
            self._non_regular_plot_axis = True

        imshow_args = {'interpolation': 'nearest',
                       'origin': 'lower'}
        imshow_args.update(self.imshow_kwargs)

        if self._non_regular_plot_axis:
            # Initialize a NonUniformImage with the relevant data and axis values and
            # add the image to the axes.
            im = mpl.image.NonUniformImage(ax, **imshow_args)
            self._set_non_uniform_data(im)
            ax.add_image(im)
            # Define the xlim and ylim from the pixel edges. ``self.extent`` is not
            # assigned in this class; presumably the parent stores it as
            # [x_min, x_max, y_min, y_max] — confirm in ArrayAnimator.
            ax.set_xlim(self.extent[0], self.extent[1])
            ax.set_ylim(self.extent[2], self.extent[3])
        else:
            # Else produce a more basic plot with regular axes. imshow's extent is
            # (left, right, bottom, top), i.e. x range then y range, so walk the
            # image axes in reverse of their numpy [y, x] order.
            extent = []
            for i in self.image_axes[::-1]:
                extent.append(self.axis_ranges[i][0])
                extent.append(self.axis_ranges[i][-1])
            # Applied after imshow_kwargs, so the computed extent takes precedence.
            imshow_args['extent'] = extent
            im = ax.imshow(self._get_frame_data(), **imshow_args)
        if self.if_colorbar:
            self._add_colorbar(im)

        return im

    def update_plot(self, val, im, slider):
        """
        Updates plot based on slider/array dimension being iterated.

        Parameters
        ----------
        val : `float`
            New slider value; truncated to an integer index along the slider's axis.
        im : `matplotlib.image.AxesImage` or `matplotlib.image.NonUniformImage`
            The image artist to refresh.
        slider
            The slider that moved; its ``slider_ind`` selects the array axis and
            its ``cval`` records the last value drawn.
        """
        ind = int(val)
        ax_ind = self.slider_axes[slider.slider_ind]
        self.frame_slice[ax_ind] = ind
        # Only redraw when the value actually changed since the last update.
        if val != slider.cval:
            if self._non_regular_plot_axis:
                self._set_non_uniform_data(im)
            else:
                im.set_array(self._get_frame_data())
            slider.cval = val
        # Update slider label to reflect real world values in axis_ranges.
        super().update_plot(val, im, slider)