File: coord.py

package info (click to toggle)
gammapy 1.0-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 8,776 kB
  • sloc: python: 58,736; makefile: 215; ansic: 69
file content (301 lines) | stat: -rw-r--r-- 9,845 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
import copy
import numpy as np
from astropy import units as u
from astropy.coordinates import SkyCoord

# Public API of this module; the ``skycoord_to_lonlat`` helper below is not exported.
__all__ = ["MapCoord"]


def skycoord_to_lonlat(skycoord, frame=None):
    """Convert a SkyCoord to lon, lat (in degrees) and the name of its frame.

    Parameters
    ----------
    skycoord : `~astropy.coordinates.SkyCoord`
        Input sky coordinates.
    frame : str or frame object, optional
        Target frame. If given (truthy), ``skycoord`` is first transformed
        with ``skycoord.transform_to(frame)``; otherwise it is used as-is.

    Returns
    -------
    lon : `~numpy.ndarray`
        Longitude in degrees.
    lat : `~numpy.ndarray`
        Latitude in degrees.
    frame : str
        Name of the frame of the (possibly transformed) coordinates.
    """
    if frame:
        skycoord = skycoord.transform_to(frame)

    # lon/lat are read from the underlying representation (``.data``), and the
    # returned frame name always reflects the coordinates actually returned.
    return skycoord.data.lon.deg, skycoord.data.lat.deg, skycoord.frame.name


class MapCoord:
    """Represents a sequence of n-dimensional map coordinates.

    Contains coordinates for 2 spatial dimensions and an arbitrary
    number of additional non-spatial dimensions.

    For further information see :ref:`mapcoord`.

    Parameters
    ----------
    data : `dict` of `~numpy.ndarray`
        Dictionary of coordinate arrays. Must contain the keys ``"lon"`` and
        ``"lat"``; every value is passed through `numpy.atleast_1d`.
    frame : {"icrs", "galactic", None}
        Spatial coordinate system.  If None then the coordinate system
        will be set to the native coordinate system of the geometry.
    match_by_name : bool
        Match coordinates to axes by name?
        If false coordinates will be matched by index.

    Raises
    ------
    ValueError
        If ``data`` does not contain both ``"lon"`` and ``"lat"``.
    """

    def __init__(self, data, frame=None, match_by_name=True):
        if "lon" not in data or "lat" not in data:
            raise ValueError("data dictionary must contain axes named 'lon' and 'lat'.")

        # Promote scalars to 1-d arrays so every coordinate can be indexed/masked.
        self._data = {k: np.atleast_1d(v) for k, v in data.items()}
        self._frame = frame
        self._match_by_name = match_by_name

    def __getitem__(self, key):
        """Return a coordinate array by axis name (str) or by position (int/slice)."""
        if isinstance(key, str):
            return self._data[key]
        # Positional access follows the insertion order of the data dict.
        return list(self._data.values())[key]

    def __setitem__(self, key, value):
        """Set or replace the coordinate array stored under ``key``."""
        # TODO: check for broadcastability?
        # NOTE(review): unlike __init__, ``value`` is stored as-is (no np.atleast_1d).
        self._data[key] = value

    def __iter__(self):
        """Iterate over the coordinate arrays in axis order."""
        return iter(self._data.values())

    @property
    def ndim(self):
        """Number of dimensions."""
        return len(self._data)

    @property
    def shape(self):
        """Coordinate array shape (broadcast shape of all coordinate arrays)."""
        return np.broadcast(*self._data.values()).shape

    @property
    def size(self):
        """Total number of coordinate points (product of `shape`)."""
        return np.prod(self.shape)

    @property
    def lon(self):
        """Longitude coordinate in degrees."""
        return self._data["lon"]

    @property
    def lat(self):
        """Latitude coordinate in degrees."""
        return self._data["lat"]

    @property
    def theta(self):
        """Theta co-latitude angle in radians."""
        # No ``copy=False``: with NumPy >= 2 that means "never copy" and raises
        # for inputs (e.g. integer arrays) that need a dtype conversion.
        theta = u.Quantity(self.lat, unit="deg").to_value("rad")
        return np.pi / 2.0 - theta

    @property
    def phi(self):
        """Phi longitude angle in radians."""
        # See ``theta`` for why ``copy=False`` is not passed.
        return u.Quantity(self.lon, unit="deg").to_value("rad")

    @property
    def frame(self):
        """Coordinate system (str)."""
        return self._frame

    @property
    def match_by_name(self):
        """Boolean flag: axis lookup by name (True) or index (False)."""
        return self._match_by_name

    @property
    def skycoord(self):
        """Spatial coordinates as a `~astropy.coordinates.SkyCoord` (lon/lat in deg, in `frame`)."""
        return SkyCoord(self.lon, self.lat, unit="deg", frame=self.frame)

    @classmethod
    def _from_lonlat(cls, coords, frame=None, axis_names=None):
        """Create a `~MapCoord` from a tuple of coordinate vectors.

        The first two elements of the tuple should be longitude and latitude in degrees.

        Parameters
        ----------
        coords : tuple or list
            Sequence of `~numpy.ndarray`: lon, lat, then any non-spatial axes.
        frame : str, optional
            Coordinate frame stored on the result.
        axis_names : list of str, optional
            Names of the non-spatial axes. Defaults to ``axis0``, ``axis1``, ...
            Non-spatial values without a matching name are dropped (``zip``
            stops at the shorter sequence).

        Returns
        -------
        coord : `~MapCoord`
            A coordinates object with ``match_by_name=False``.

        Raises
        ------
        ValueError
            If ``coords`` is not a list or tuple.
        """
        if axis_names is None:
            axis_names = [f"axis{idx}" for idx in range(len(coords) - 2)]

        if isinstance(coords, (list, tuple)):
            coords_dict = {"lon": coords[0], "lat": coords[1]}
            for name, c in zip(axis_names, coords[2:]):
                coords_dict[name] = c
        else:
            raise ValueError("Unrecognized input type.")

        return cls(coords_dict, frame=frame, match_by_name=False)

    @classmethod
    def _from_tuple(cls, coords, frame=None, axis_names=None):
        """Create from a tuple (or list) of coordinate vectors.

        The first element is either the longitude (array, list or scalar),
        followed by latitude and any non-spatial axes, or a
        `~astropy.coordinates.SkyCoord`, followed by non-spatial axes only.

        Raises
        ------
        TypeError
            If the first element is of an unsupported type.
        """
        first = coords[0]
        if isinstance(first, (list, np.ndarray)) or np.isscalar(first):
            return cls._from_lonlat(coords, frame=frame, axis_names=axis_names)
        elif isinstance(first, SkyCoord):
            lon, lat, frame = skycoord_to_lonlat(first, frame=frame)
            # ``tuple(...)`` so list input works too: ``tuple + list`` raises TypeError.
            coords = (lon, lat) + tuple(coords[1:])
            return cls._from_lonlat(coords, frame=frame, axis_names=axis_names)
        else:
            raise TypeError(f"Type not supported: {type(first)!r}")

    @classmethod
    def _from_dict(cls, coords, frame=None):
        """Create from a dictionary of coordinate vectors.

        The dict must contain either ``"lon"`` and ``"lat"`` or a
        ``"skycoord"`` entry; all other keys are kept as non-spatial axes.

        Raises
        ------
        ValueError
            If neither ``"lon"``/``"lat"`` nor ``"skycoord"`` is present.
        """
        if "lon" in coords and "lat" in coords:
            return cls(coords, frame=frame)
        elif "skycoord" in coords:
            lon, lat, frame = skycoord_to_lonlat(coords["skycoord"], frame=frame)
            coords_dict = {"lon": lon, "lat": lat}
            for k, v in coords.items():
                if k == "skycoord":
                    continue
                coords_dict[k] = v
            return cls(coords_dict, frame=frame)
        else:
            raise ValueError("coords dict must contain 'lon'/'lat' or 'skycoord'.")

    @classmethod
    def create(cls, data, frame=None, axis_names=None):
        """Create a new `~MapCoord` object.

        This method can be used to create either unnamed (with tuple input)
        or named (via dict input) axes.

        Parameters
        ----------
        data : tuple, dict, `~gammapy.maps.MapCoord` or `~astropy.coordinates.SkyCoord`
            Object containing coordinate arrays.
        frame : {"icrs", "galactic", None}, optional
            Set the coordinate system for longitude and latitude. If
            None longitude and latitude will be assumed to be in
            the coordinate system native to a given map geometry.
        axis_names : list of str
            Axis names used if a tuple or list is provided.

        Raises
        ------
        TypeError
            If ``data`` is of an unsupported type.

        Examples
        --------
        >>> from astropy.coordinates import SkyCoord
        >>> from gammapy.maps import MapCoord

        >>> lon, lat = [1, 2], [2, 3]
        >>> skycoord = SkyCoord(lon, lat, unit='deg')
        >>> energy = [1000]
        >>> c = MapCoord.create((lon,lat))
        >>> c = MapCoord.create((skycoord,))
        >>> c = MapCoord.create((lon,lat,energy))
        >>> c = MapCoord.create(dict(lon=lon,lat=lat))
        >>> c = MapCoord.create(dict(lon=lon,lat=lat,energy=energy))
        >>> c = MapCoord.create(dict(skycoord=skycoord,energy=energy))
        """
        if isinstance(data, cls):
            # An existing MapCoord is returned unchanged when no conversion is needed.
            if data.frame is None or frame == data.frame:
                return data
            else:
                return data.to_frame(frame)
        elif isinstance(data, dict):
            return cls._from_dict(data, frame=frame)
        elif isinstance(data, (list, tuple)):
            return cls._from_tuple(data, frame=frame, axis_names=axis_names)
        elif isinstance(data, SkyCoord):
            return cls._from_tuple((data,), frame=frame, axis_names=axis_names)
        else:
            raise TypeError(f"Unsupported input type: {type(data)!r}")

    def to_frame(self, frame):
        """Convert to a different coordinate frame.

        Parameters
        ----------
        frame : {"icrs", "galactic"}
            Coordinate system, either Galactic ("galactic") or Equatorial ("icrs").

        Returns
        -------
        coords : `~MapCoord`
            A coordinates object (always a new object, never ``self``).
        """
        if frame == self.frame:
            return copy.deepcopy(self)

        lon, lat, frame = skycoord_to_lonlat(self.skycoord, frame=frame)
        data = copy.deepcopy(self._data)

        # Keep lon/lat as Quantity if they were Quantity in the original object;
        # each axis is checked against its own original array.
        if isinstance(self.lon, u.Quantity):
            lon = u.Quantity(lon, unit="deg")

        if isinstance(self.lat, u.Quantity):
            lat = u.Quantity(lat, unit="deg")

        data["lon"] = lon
        data["lat"] = lat
        return self.__class__(data, frame, self._match_by_name)

    def apply_mask(self, mask):
        """Return a masked copy of this coordinate object.

        Parameters
        ----------
        mask : `~numpy.ndarray`
            Boolean mask.

        Returns
        -------
        coords : `~MapCoord`
            A coordinates object.
        """
        try:
            data = {k: v[mask] for k, v in self._data.items()}
        except IndexError:
            # Fallback when ``mask`` cannot index the arrays directly: lon/lat are
            # squeezed and then masked, while the other axes only have their
            # trailing length-1 axis removed and are NOT masked.
            # NOTE(review): presumably for non-spatial axes carried with a trailing
            # singleton dimension — confirm against callers.
            data = {}

            for name, coord in self._data.items():
                if name in ["lon", "lat"]:
                    data[name] = np.squeeze(coord)[mask]
                else:
                    data[name] = np.squeeze(coord, axis=-1)

        return self.__class__(data, self.frame, self._match_by_name)

    @property
    def flat(self):
        """Return flattened, valid coordinates.

        All arrays are broadcast to a common shape and then masked with
        ``isfinite`` of the first coordinate array (``lon`` when it is the
        first key of the data dict).
        """
        coords = self.broadcasted
        is_finite = np.isfinite(coords[0])
        return coords.apply_mask(is_finite)

    @property
    def broadcasted(self):
        """Return coords with all arrays broadcast to a common shape."""
        # subok=True keeps ndarray subclasses (such as Quantity) intact.
        vals = np.broadcast_arrays(*self._data.values(), subok=True)
        data = dict(zip(self._data.keys(), vals))
        return self.__class__(
            data=data, frame=self.frame, match_by_name=self._match_by_name
        )

    def copy(self):
        """Copy `MapCoord` object (deep copy, including coordinate arrays)."""
        return copy.deepcopy(self)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}\n\n"
            f"\taxes     : {list(self._data.keys())}\n"
            f"\tshape    : {self.shape[::-1]}\n"
            f"\tndim     : {self.ndim}\n"
            f"\tframe : {self.frame}\n"
        )