File: geocollection.py

package info (click to toggle)
python-cartopy 0.25.0%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 13,152 kB
  • sloc: python: 16,526; makefile: 159; javascript: 66
file content (122 lines) | stat: -rw-r--r-- 4,708 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright Crown and Cartopy Contributors
#
# This file is part of Cartopy and is released under the BSD 3-clause license.
# See LICENSE in the root of the repository for full licensing details.
from matplotlib.collections import QuadMesh
import numpy as np
import numpy.ma as ma

from cartopy.mpl import _MPL_38


def _split_wrapped_mesh_data(C, mask):
    """
    Helper function for splitting GeoQuadMesh array values between the
    pcolormesh and pcolor objects when wrapping.  Apply a mask to the grid
    cells that should not be plotted with each method.

    """
    # The original data mask (regardless of wrapped cells)
    C_mask = getattr(C, 'mask', None)
    if C.ndim == 3:
        # RGB(A) array.
        if not _MPL_38:
            raise ValueError("GeoQuadMesh wrapping for RGB(A) requires "
                             "Matplotlib v3.8 or later")

        # mask will need an extra trailing dimension
        mask = np.broadcast_to(mask[..., np.newaxis], C.shape)

    # create the masked array to be used with pcolormesh
    full_mask = mask if C_mask is None else mask | C_mask
    pcolormesh_data = ma.array(C, mask=full_mask)

    # create the masked array to be used with pcolor
    full_mask = ~mask if C_mask is None else ~mask | C_mask
    pcolor_data = ma.array(C, mask=full_mask)

    return pcolormesh_data, pcolor_data, ~mask


class GeoQuadMesh(QuadMesh):
    """
    A QuadMesh designed to help handle the case when the mesh is wrapped.

    """
    # No __init__ method here - most of the time a GeoQuadMesh will
    # come from GeoAxes.pcolormesh. These methods morph a QuadMesh by
    # fiddling with instance.__class__.

    def get_array(self):
        # Retrieve the array - use copy to avoid any chance of overwrite
        A = super().get_array().copy()
        # If the input array has a mask, retrieve the associated data
        if hasattr(self, '_wrapped_mask'):
            pcolor_data = self._wrapped_collection_fix.get_array()
            mask = self._wrapped_mask
            if not _MPL_38:
                A[mask] = pcolor_data
            else:
                if A.ndim == 3:  # RGB(A) data.  Need to broadcast mask.
                    mask = mask[:, :, np.newaxis]
                # np.copyto is not implemented for masked arrays so handle the
                # mask explicitly
                np.copyto(A.mask, pcolor_data.mask, where=mask)
                np.copyto(A, pcolor_data, where=mask)

        return A

    def set_array(self, A):
        # Check the shape is appropriate up front.
        if not _MPL_38:
            # Need to figure out existing shape from the coordinates.
            height, width = self._coordinates.shape[0:-1]
            if self._shading == 'flat':
                h, w = height - 1, width - 1
            else:
                h, w = height, width
        else:
            h, w = super().get_array().shape[:2]

        ok_shapes = [(h, w, 3), (h, w, 4), (h, w), (h * w,)]
        if A.shape not in ok_shapes:
            ok_shape_str = ' or '.join(map(str, ok_shapes))
            raise ValueError(
                f"A should have shape {ok_shape_str}, not {A.shape}")

        if A.ndim == 1:
            # Always use array with at least two dimensions.  This is
            # inconsistent with QuadMesh which stores whatever you give it, but
            # for the wrapped case we need to match the 2D mask.  Storing the
            # 2D array also allows us to calculate ok_shapes on subsequent
            # calls without using the private QuadMesh._shading attribute.
            A = A.reshape((h, w))

        # Only use the mask attribute if it is there.
        if hasattr(self, '_wrapped_mask'):

            # Update the pcolor data with the wrapped masked data
            A, pcolor_data, _ = _split_wrapped_mesh_data(A, self._wrapped_mask)

            if not _MPL_38:
                self._wrapped_collection_fix.set_array(
                    pcolor_data[self._wrapped_mask].ravel())
            else:
                self._wrapped_collection_fix.set_array(pcolor_data)

        # Now that we have prepared the collection data, call on
        # through to the underlying implementation.
        super().set_array(A)

    def set_clim(self, vmin=None, vmax=None):
        # Update _wrapped_collection_fix color limits if it is there.
        if hasattr(self, '_wrapped_collection_fix'):
            self._wrapped_collection_fix.set_clim(vmin, vmax)

        # Update color limits for the rest of the cells.
        super().set_clim(vmin, vmax)

    def get_datalim(self, transData):
        # Return the corners that were calculated in
        # the pcolormesh routine.
        return self._corners