File: util.py

package info (click to toggle)
glueviz 0.9.1%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 17,180 kB
  • ctags: 6,728
  • sloc: python: 37,111; makefile: 134; sh: 60
file content (368 lines) | stat: -rw-r--r-- 11,197 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
from __future__ import absolute_import, division, print_function

import logging
from itertools import count
from functools import partial


import numpy as np
import pandas as pd

from matplotlib.ticker import AutoLocator, MaxNLocator, LogLocator
from matplotlib.ticker import (LogFormatterMathtext, ScalarFormatter,
                               FuncFormatter)

# Public API of this module: the names exported by ``from <module> import *``.
# Every entry corresponds to a function defined below.
__all__ = ["relim", "split_component_view", "join_component_view",
           "facet_subsets", "colorize_subsets", "disambiguate",
           "row_lookup", 'small_view', 'small_view_array', 'visible_limits',
           'tick_linker', 'update_ticks']


def relim(lo, hi, log=False):
    """Pad a pair of limits so that they extend slightly past the inputs.

    :param lo: Lower limit
    :param hi: Upper limit
    :param log: If False (default), each side is pushed outwards by 2% of
                ``hi - lo``. If True, ``lo`` is scaled by 0.95 and ``hi``
                by 1.05; a negative ``lo`` is first replaced by 1e-5 and a
                negative ``hi`` by 1e5.

    :returns: Tuple of (padded lo, padded hi)
    """
    logging.getLogger(__name__).debug("Inputs to relim: %r %r", lo, hi)

    if not log:
        span = hi - lo
        return (lo - .02 * span, hi + .02 * span)

    # Negative limits are swapped for fixed positive fallbacks before the
    # multiplicative padding is applied.
    lower = 1e-5 if lo < 0 else lo
    upper = 1e5 if hi < 0 else hi
    return lower * .95, upper * 1.05


def split_component_view(arg):
    """Break the argument of data/subset ``__getitem__`` into two parts.

    :param arg: The input passed to data or subset.__getitem__.
                Either a scalar or a tuple

    :rtype: tuple

    :returns: ``(component, view)``. For a non-tuple ``arg`` this is
              ``(arg, None)``. For a length-2 tuple the view is the second
              element itself; for longer tuples the view is the tuple of all
              elements after the first.

    :raises TypeError: if ``arg`` is a tuple of length 1
    """
    if not isinstance(arg, tuple):
        return arg, None

    size = len(arg)
    if size == 1:
        raise TypeError("Expected a scalar or >length-1 tuple, "
                        "got length-1 tuple")

    component = arg[0]
    # A single trailing item is returned bare rather than as a 1-tuple
    view = arg[1] if size == 2 else arg[1:]
    return component, view


def join_component_view(component, view):
    """Combine a ComponentID and an optional view into one object

    The result can be passed to data.__getitem__ and related methods.

    :param component: ComponentID
    :param view: view into data, or None

    :returns: ``component`` itself when ``view`` is None. Otherwise a tuple
              that starts with ``component`` and is followed by the items of
              ``view`` if it can be iterated, or by ``view`` itself if not.
    """
    if view is None:
        return component

    try:
        pieces = tuple(view)
    except TypeError:  # view cannot be iterated, so treat it as a scalar
        pieces = (view,)

    return (component,) + pieces


def facet_subsets(data_collection, cid, lo=None, hi=None, steps=5,
                  prefix='', log=False):
    """Partition the values of an attribute into bins, one subset group each

    `steps` new subset groups are created, added to the data collection,
    and returned as a list.

    :param data_collection: DataCollection object to use
    :type data_collection: :class:`~glue.core.data_collection.DataCollection`

    :param cid: ComponentID to facet on
    :type cid: :class:`~glue.core.component_id.ComponentID`

    :param lo: The lower bound for the faceting. Defaults to the minimum
               (ignoring NaN) of the first dataset in the collection that
               can provide `cid`
    :type lo: float

    :param hi: The upper bound for the faceting. Defaults to the maximum
               (ignoring NaN) of the first dataset in the collection that
               can provide `cid`
    :type hi: float

    :param steps: The number of subsets to create. Defaults to 5
    :type steps: int

    :param prefix: Text placed at the start of every new subset label
    :type prefix: str

    :param log: If True, space divisions logarithmically. Default=False
    :type log: bool

    :returns: List of the newly created subset groups, in bin order

    :raises ValueError: if a bound is missing and no dataset in the
                        collection can provide `cid`

    Example::

        facet_subsets(data, data.id['mass'], lo=0, hi=10, steps=2)

    creates 2 new subsets. The first represents the constraint 0 <=
    mass < 5. The second represents 5 <= mass <= 10::

        facet_subsets(data, data.id['mass'], lo=10, hi=0, steps=2)

    creates 2 new subsets. The first represents the constraint 10 >= mass > 5.
    The second represents 5 >= mass >= 0.

    Each label is `prefix` followed by the range expression of the bin,
    written as ``low<=cid<high`` (``low<cid<=high`` when ``lo > hi``); the
    final bin is written ``low<=cid<=high``.

    Note that the last range is inclusive on both sides. For example, if ``lo``
    is 0 and ``hi`` is 5, and ``steps`` is 5, then the intervals for the subsets
    are [0,1), [1,2), [2,3), [3,4), and [4,5].

    """
    from glue.core.exceptions import IncompatibleAttribute

    if lo is None or hi is None:
        # Use the first dataset that can supply this component
        for data in data_collection:
            try:
                vals = data[cid]
            except IncompatibleAttribute:
                continue
            break
        else:
            raise ValueError("Cannot infer data limits for ComponentID %s"
                             % cid)
        if lo is None:
            lo = np.nanmin(vals)
        if hi is None:
            hi = np.nanmax(vals)

    descending = lo > hi
    if log:
        edges = np.logspace(np.log10(lo), np.log10(hi), steps + 1)
    else:
        edges = np.linspace(lo, hi, steps + 1)

    # First build every (label, state) pair, then register them, so that no
    # subset group is created until all states exist.
    last = steps - 1
    faceted = []
    for i in range(steps):
        start, stop = edges[i], edges[i + 1]
        # Only the final bin includes both of its edges
        closed = i == last

        if descending:
            if closed:
                state = (cid <= start) & (cid >= stop)
                text = '{0}<={1}<={2}'.format(stop, cid, start)
            else:
                state = (cid <= start) & (cid > stop)
                text = '{0}<{1}<={2}'.format(stop, cid, start)
        elif closed:
            state = (cid >= start) & (cid <= stop)
            text = '{0}<={1}<={2}'.format(start, cid, stop)
        else:
            state = (cid >= start) & (cid < stop)
            text = '{0}<={1}<{2}'.format(start, cid, stop)

        faceted.append((prefix + text, state))

    return [data_collection.new_subset_group(label=label, subset_state=state)
            for label, state in faceted]


def colorize_subsets(subsets, cmap, lo=0, hi=1):
    """Assign colors drawn from a colormap to a list of subsets

    :param subsets: List of subsets
    :param cmap: Matplotlib colormap instance
    :param lo: Start location in colormap. 0-1. Defaults to 0
    :param hi: End location in colormap. 0-1. Defaults to 1

    `len(subsets)` evenly spaced positions between `lo` and `hi` are
    looked up in the colormap, and the color at the `ith` position is
    written to ``subsets[i].style.color`` as a ``#rrggbb`` hex string.
    The alpha channel of the colormap is ignored.
    """

    from matplotlib import cm

    mappable = cm.ScalarMappable(cmap=cmap)
    # Fix the normalization so that positions are interpreted on a 0-1 scale
    mappable.norm.vmin = 0
    mappable.norm.vmax = 1

    positions = np.linspace(lo, hi, len(subsets))

    for rgba, subset in zip(mappable.to_rgba(positions), subsets):
        red, green, blue, _ = rgba
        channels = tuple(int(255 * level) for level in (red, green, blue))
        subset.style.color = '#%2.2x%2.2x%2.2x' % channels


def disambiguate(label, taken):
    """Return a version of label that does not clash with any taken name

    :param label: desired label
    :param taken: set of taken names

    If label is not in taken it is returned unchanged. Otherwise the
    result is ``str(label)`` followed by ``_NN``, where NN is the smallest
    integer (starting at 1, zero-padded to two digits) for which the
    combined name is not in taken.
    """
    if label not in taken:
        return label

    base = str(label)
    candidates = (base + ("_%2.2i" % n) for n in count(1))
    return next(name for name in candidates if name not in taken)


def row_lookup(data, categories):
    """
    Find, for each data item, the position of the equal entry in categories

    :param data: array-like
    :param categories: array-like of unique values

    :returns: Float array with one entry per data item.
              If result[i] is finite, then data[i] = categories[result[i]]
              Otherwise (result[i] is NaN), data[i] is not in the
              categories list
    """

    # A pandas merge is used for the matching because
    # np.searchsorted doesn't work on mixed types in Python3

    n_items = len(data)
    n_categories = len(categories)

    left = pd.DataFrame({'data': data, 'row': np.arange(n_items)})
    right = pd.DataFrame({'categories': categories,
                          'cat_row': np.arange(n_categories)})
    matched = pd.merge(left, right, left_on='data', right_on='categories')

    # Items with no match keep their initial NaN
    lookup = np.full(n_items, np.nan, dtype=float)
    lookup[np.array(matched['row'])] = matched['cat_row']
    return lookup


def small_view(data, attribute):
    """
    Fetch a strided, reduced-size view of one attribute of a dataset,
    for quick statistical summaries

    Along each dimension of ``data.shape`` the stride is the dimension
    length divided by 50 (truncated to an integer), and never less than 1.
    The result is ``data[attribute, view]`` where view is the tuple of
    strided slices.
    """
    strides = (np.intp(max(extent / 50, 1)) for extent in data.shape)
    view = tuple(slice(None, None, stride) for stride in strides)
    return data[attribute, view]


def small_view_array(data):
    """
    Fetch a strided, reduced-size view of a numpy array

    Uses the same striding as small_view: along each dimension the stride is
    the dimension length divided by 50 (truncated to an integer), and never
    less than 1.
    """
    strides = (np.intp(max(extent / 50, 1)) for extent in data.shape)
    view = tuple(slice(None, None, stride) for stride in strides)
    return np.asarray(data)[view]


def visible_limits(artists, axis):
    """
    Compute the range spanned by the data of all visible artists.

    Artists whose ``visible`` attribute is false are skipped, and only
    finite values are taken into account.

    Each artist must provide a get_data method which returns a tuple of x,y

    :param artists: An iterable collection of artists
    :param axis: Which axis to compute. 0=xaxis, 1=yaxis

    :returns: A tuple of min, max for the requested axis, or None if there
              is no finite data to measure
    """
    chunks = []
    for artist in artists:
        if artist.visible:
            xy = artist.get_data()
            assert isinstance(xy, tuple)
            values = xy[axis]
            if values.size > 0:
                chunks.append(values)

    if not chunks:
        return None

    stacked = np.hstack(chunks)
    if stacked.size == 0:
        return None

    # Discard NaN and infinite entries before measuring the range
    finite = stacked[np.isfinite(stacked)]
    if finite.size == 0:
        return None

    lower, upper = np.nanmin(finite), np.nanmax(finite)
    if np.isfinite(lower):
        return lower, upper
    return None


def tick_linker(all_categories, pos, *args):
    """
    Return the category label for a tick position.

    :param all_categories: Indexable sequence of category labels
    :param pos: Tick position; it is rounded to the nearest integer and used
                as an index into ``all_categories``
    :param args: Extra positional arguments, accepted and ignored

    :returns: ``all_categories[round(pos)]``, or ``''`` when the rounded
              position falls outside ``0 .. len(all_categories) - 1``
    """
    index = int(np.round(pos))
    # A negative index would silently count from the end of the sequence and
    # label a tick left of the first category with one of the last
    # categories, so it is treated as out of range.
    if index < 0:
        return ''
    try:
        return all_categories[index]
    except IndexError:
        return ''


def update_ticks(axes, coord, components, is_log):
    """
    Changes the axes to have the proper tick formatting based on the type of
    component.

    A log axis gets a log locator and mathtext log formatter. Otherwise, if
    every component is categorical, ticks are placed on integers and labelled
    with the sorted union of the components' categories. In all other cases
    the automatic locator and scalar formatter are used.

    :param axes: A matplotlib axes object to alter
    :param coord: 'x' or 'y'
    :param components: A list() of components that are plotted along this axis
    :param is_log: Boolean for log-scale. Takes precedence over categorical
                   formatting.
    :return: None or #categories if components is Categorical
    :raises TypeError: if coord is neither 'x' nor 'y'
    """

    if coord == 'x':
        axis = axes.xaxis
    elif coord == 'y':
        axis = axes.yaxis
    else:
        raise TypeError("coord must be one of x,y")

    is_cat = all(comp.categorical for comp in components)
    if is_log:
        axis.set_major_locator(LogLocator())
        axis.set_major_formatter(LogFormatterMathtext())
    elif is_cat:
        # The builtin ``object`` is used as the dtype: the ``np.object`` alias
        # was deprecated in NumPy 1.20 and removed in NumPy 1.24.
        all_categories = np.empty((0,), dtype=object)
        for comp in components:
            all_categories = np.union1d(comp.categories, all_categories)
        locator = MaxNLocator(10, integer=True)
        # NOTE(review): the return value is discarded; presumably this call
        # has no lasting effect on the locator - confirm before removing.
        locator.view_limits(0, all_categories.shape[0])
        # Tick positions are translated to category labels by tick_linker
        format_func = partial(tick_linker, all_categories)
        formatter = FuncFormatter(format_func)

        axis.set_major_locator(locator)
        axis.set_major_formatter(formatter)
        return all_categories.shape[0]
    else:
        axis.set_major_locator(AutoLocator())
        axis.set_major_formatter(ScalarFormatter())