1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404
|
# Licensed under a 3-clause BSD style license - see LICENSE.rst
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from ..extern import six
from ..extern.six.moves import zip
import platform
import warnings
import numpy as np
from .index import get_index
from ..utils.exceptions import AstropyUserWarning
__all__ = ['TableGroups', 'ColumnGroups']
def table_group_by(table, keys):
    """
    Group ``table`` on ``keys`` while the table is in ``'discard_on_copy'``
    index mode.

    This is a thin wrapper that delegates the actual work to
    ``_table_group_by`` and returns its result unchanged.

    Parameters
    ----------
    table : `Table`
        Table to group
    keys : str, list of str, `Table`, or Numpy array
        Grouping key specifier

    Returns
    -------
    grouped_table : Table object with groups attr set accordingly
    """
    # Copying indices is unnecessary during grouping and only slows
    # _table_group_by down, so discard them on copy for the duration.
    with table.index_mode('discard_on_copy'):
        grouped = _table_group_by(table, keys)
    return grouped
def _table_group_by(table, keys):
    """
    Get groups for ``table`` on specified ``keys``.

    Parameters
    ----------
    table : `Table`
        Table to group
    keys : str, list of str, `Table`, or Numpy array
        Grouping key specifier

    Returns
    -------
    grouped_table : Table object with groups attr set accordingly
    """
    from .table import Table

    # A single column name is shorthand for a one-element sequence of names.
    if isinstance(keys, six.string_types):
        keys = (keys,)

    # True when the grouping keys are columns of the table being grouped,
    # False when they were supplied externally as an array or Table.
    keys_are_colnames = isinstance(keys, (list, tuple))

    if keys_are_colnames:
        for colname in keys:
            if colname not in table.colnames:
                raise ValueError('Table does not have key column {0!r}'.format(colname))
            if table.masked and np.any(table[colname].mask):
                raise ValueError('Missing values in key column {0!r} are not allowed'.format(colname))
        keys = tuple(keys)
        table_keys = table[keys]
    elif isinstance(keys, (np.ndarray, Table)):
        n_keys, n_rows = len(keys), len(table)
        if n_keys != n_rows:
            raise ValueError('Input keys array length {0} does not match table length {1}'
                             .format(n_keys, n_rows))
        table_keys = keys
    else:
        raise TypeError('Keys input must be string, list, tuple or numpy array, but got {0}'
                        .format(type(keys)))

    try:
        # Reuse the ordering of an existing index on the key columns when there is one
        if isinstance(table_keys, Table):
            key_index = get_index(table, table_keys)
        else:
            key_index = None

        if key_index is None:
            order = table_keys.argsort(kind='mergesort')
        else:
            order = key_index.sorted_data()
        sort_is_stable = True
    except TypeError:
        # Some versions (likely 1.6 and earlier) of numpy don't support
        # 'mergesort' for all data types.  MacOSX (Darwin) doesn't have a stable
        # sort by default, nor does Windows, while Linux does (or appears to).
        order = table_keys.argsort()
        sort_is_stable = platform.system() not in ('Darwin', 'Windows')

    sorted_keys = table_keys[order]

    # A group starts wherever the sorted key differs from its predecessor; the
    # leading and trailing True mark the first group start and the end sentinel.
    key_changes = sorted_keys[1:] != sorted_keys[:-1]
    boundaries = np.flatnonzero(np.concatenate(([True], key_changes, [True])))

    # With an unstable sort, restore the original table order inside each
    # group by sorting the row indices of that group in place.
    if not sort_is_stable:
        for start, stop in zip(boundaries[:-1], boundaries[1:]):
            order[start:stop].sort()

    # Build the reordered table and attach the TableGroups object.  The group
    # keys are the sorted keys taken at the start of each group.
    grouped = table.__class__(table[order])
    group_keys = sorted_keys[boundaries[:-1]]
    if isinstance(group_keys, Table):
        group_keys.meta['grouped_by_table_cols'] = keys_are_colnames

    grouped._groups = TableGroups(grouped, indices=boundaries, keys=group_keys)
    return grouped
def column_group_by(column, keys):
    """
    Get groups for ``column`` on specified ``keys``

    Parameters
    ----------
    column : Column object
        Column to group
    keys : Table or Numpy array of same length as col
        Grouping key specifier

    Returns
    -------
    grouped_column : Column object with groups attr set accordingly

    Raises
    ------
    TypeError
        If ``keys`` is neither a Table nor a numpy array.
    ValueError
        If the length of ``keys`` does not match the length of ``column``.
    """
    from .table import Table

    # A Table of keys is grouped on via its underlying array.
    if isinstance(keys, Table):
        keys = keys.as_array()

    if not isinstance(keys, np.ndarray):
        raise TypeError('Keys input must be numpy array, but got {0}'
                        .format(type(keys)))

    if len(keys) != len(column):
        raise ValueError('Input keys array length {0} does not match column length {1}'
                         .format(len(keys), len(column)))

    # ``keys`` is always a numpy array here (a Table was converted above), so
    # the only index that can be reused is one attached to the key array itself.
    if hasattr(keys, 'indices') and keys.indices:
        idx_sort = keys.indices[0].sorted_data()
    else:
        idx_sort = keys.argsort()
    keys = keys[idx_sort]

    # A group starts wherever the sorted key differs from its predecessor; the
    # leading and trailing True mark the first group start and the end sentinel.
    diffs = np.concatenate(([True], keys[1:] != keys[:-1], [True]))
    indices = np.flatnonzero(diffs)

    # Make a new column and set the _groups to the appropriate ColumnGroups object.
    # Take the subset of the original keys at the indices values (group boundaries).
    out = column.__class__(column[idx_sort])
    out._groups = ColumnGroups(out, indices=indices, keys=keys[indices[:-1]])

    return out
class BaseGroups(object):
    """
    A class to represent groups within a table of heterogeneous data.

      - ``keys``: key values corresponding to each group
      - ``indices``: index values in parent table or column corresponding to group boundaries
      - ``aggregate()``: method to create new table by aggregating within groups

    Subclasses are expected to provide ``indices`` and ``keys`` along with the
    ``parent_column`` / ``parent_table`` attribute read by ``parent``.
    """
    @property
    def parent(self):
        """The grouped object: the parent column for ColumnGroups, else the parent table."""
        return self.parent_column if isinstance(self, ColumnGroups) else self.parent_table

    def __iter__(self):
        # Iteration state lives on the instance, so iterating restarts from
        # the first group every time ``iter()`` is called.
        self._iter_index = 0
        return self

    def next(self):
        """Return the slice of the parent for the next group, or raise StopIteration."""
        ii = self._iter_index
        if ii < len(self.indices) - 1:
            i0, i1 = self.indices[ii], self.indices[ii + 1]
            self._iter_index += 1
            return self.parent[i0:i1]
        else:
            raise StopIteration
    __next__ = next

    def __getitem__(self, item):
        """
        Select one or more groups from the parent.

        Parameters
        ----------
        item : int, slice, numpy mask or int array
            An integer (negative values count from the last group) selects a
            single group; anything else is used to index the arrays of group
            start/stop positions and selects several groups.

        Returns
        -------
        out : same kind of object as ``parent``
            Rows of the parent belonging to the selected group(s), with
            ``out.groups`` keys (and, for multiple groups, indices) set.

        Raises
        ------
        IndexError
            If an integer ``item`` is outside the range of available groups.
        TypeError
            If a non-integer ``item`` cannot index the group boundaries.
        """
        parent = self.parent

        if isinstance(item, (int, np.integer)):
            # Normalise negative indices first: ``indices`` has one more entry
            # than there are groups, so using a negative ``item`` directly
            # would pair the wrong start/stop boundaries.
            n_groups = len(self)
            group = item + n_groups if item < 0 else item
            if not 0 <= group < n_groups:
                raise IndexError('Group index {0} is out of range for {1} groups'
                                 .format(item, n_groups))
            i0, i1 = self.indices[group], self.indices[group + 1]
            out = parent[i0:i1]
            out.groups._keys = parent.groups.keys[group]
        else:
            indices0, indices1 = self.indices[:-1], self.indices[1:]
            try:
                i0s, i1s = indices0[item], indices1[item]
            except Exception:
                raise TypeError('Index item for groups attribute must be a slice, '
                                'numpy mask or int array')
            # Builtin ``bool`` is used as the dtype because the ``np.bool``
            # alias is not available in every numpy release.
            mask = np.zeros(len(parent), dtype=bool)
            # Is there a way to vectorize this in numpy?
            for i0, i1 in zip(i0s, i1s):
                mask[i0:i1] = True
            out = parent[mask]
            out.groups._keys = parent.groups.keys[item]
            # Group boundaries in the output are the running total of the
            # selected group lengths.
            out.groups._indices = np.concatenate([[0], np.cumsum(i1s - i0s)])

        return out

    def __repr__(self):
        return '<{0} indices={1}>'.format(self.__class__.__name__, self.indices)

    def __len__(self):
        # ``indices`` holds group boundaries, i.e. one more entry than groups.
        return len(self.indices) - 1
class ColumnGroups(BaseGroups):
    """
    Groups within a single column.

    When the parent column belongs to a table, the group ``indices`` and
    ``keys`` are taken from that table's groups; otherwise the values passed
    to the constructor are used.
    """
    def __init__(self, parent_column, indices=None, keys=None):
        self.parent_column = parent_column  # parent Column
        self.parent_table = parent_column.parent_table
        self._indices = indices
        self._keys = keys

    @property
    def indices(self):
        # If the parent column is in a table then use group indices from table
        if self.parent_table:
            return self.parent_table.groups.indices
        else:
            if self._indices is None:
                # No grouping set: the whole column is a single group
                return np.array([0, len(self.parent_column)])
            else:
                return self._indices

    @property
    def keys(self):
        # If the parent column is in a table then use group keys from table
        if self.parent_table:
            return self.parent_table.groups.keys
        else:
            return self._keys

    def aggregate(self, func):
        """
        Aggregate each group in the Column into a single value by applying
        the reduction function ``func`` to the values of the group.

        Parameters
        ----------
        func : function
            Function that reduces an array of values to a single value

        Returns
        -------
        out : Column
            New column of the same class as the parent column holding one
            value per group, with the parent's name, description, unit,
            format and meta.

        Raises
        ------
        TypeError
            If applying ``func`` to the column raises any exception.
        """
        from .column import MaskedColumn

        i0s, i1s = self.indices[:-1], self.indices[1:]
        par_col = self.parent_column
        masked = isinstance(par_col, MaskedColumn)
        reduceat = hasattr(func, 'reduceat')
        sum_case = func is np.sum
        mean_case = func is np.mean
        try:
            if not masked and (reduceat or sum_case or mean_case):
                # Fast path: reduce all groups at once with ``reduceat``
                if mean_case:
                    # Mean is the per-group sum divided by the group lengths
                    vals = np.add.reduceat(par_col, i0s) / np.diff(self.indices)
                else:
                    if sum_case:
                        func = np.add
                    vals = func.reduceat(par_col, i0s)
            else:
                # General path: call ``func`` on each group slice in turn
                vals = np.array([func(par_col[i0: i1]) for i0, i1 in zip(i0s, i1s)])
        except Exception:
            raise TypeError("Cannot aggregate column '{0}' with type '{1}'"
                            .format(par_col.info.name,
                                    par_col.info.dtype))

        out = par_col.__class__(data=vals,
                                name=par_col.info.name,
                                description=par_col.info.description,
                                unit=par_col.info.unit,
                                format=par_col.info.format,
                                meta=par_col.info.meta)
        return out

    def filter(self, func):
        """
        Filter groups in the Column based on evaluating function ``func`` on each
        group sub-column.

        The function which is passed to this method must accept one argument:

        - ``column`` : `Column` object

        It must then return either `True` or `False`.  As an example, the following
        will select all column groups with only positive values::

          def all_positive(column):
              if np.any(column < 0):
                  return False
              return True

        Parameters
        ----------
        func : function
            Filter function

        Returns
        -------
        out : Column
            New column with the rows of the groups for which ``func``
            returned `True`.
        """
        # Builtin ``bool`` is used as the dtype because the ``np.bool`` alias
        # is not available in every numpy release.
        mask = np.empty(len(self), dtype=bool)
        for i, group_column in enumerate(self):
            mask[i] = func(group_column)

        return self[mask]
class TableGroups(BaseGroups):
    """
    Groups within a table, described by the group boundary ``indices`` and
    the per-group ``keys`` passed to the constructor.
    """
    def __init__(self, parent_table, indices=None, keys=None):
        self.parent_table = parent_table  # parent Table
        self._indices = indices
        self._keys = keys

    @property
    def key_colnames(self):
        """
        Return the names of columns in the parent table that were used for grouping.
        """
        # If the table was grouped by key columns *in* the table then treat those columns
        # differently in aggregation.  In this case keys will be a Table with
        # keys.meta['grouped_by_table_cols'] == True.  Keys might not be a Table so we
        # need to handle this.
        grouped_by_table_cols = getattr(self.keys, 'meta', {}).get('grouped_by_table_cols', False)
        return self.keys.colnames if grouped_by_table_cols else ()

    @property
    def indices(self):
        if self._indices is None:
            # No grouping set: the whole table is a single group
            return np.array([0, len(self.parent_table)])
        else:
            return self._indices

    def aggregate(self, func):
        """
        Aggregate each group in the Table into a single row by applying the reduction
        function ``func`` to group values in each column.

        Columns whose aggregation raises `TypeError` are left out of the
        result and an ``AstropyUserWarning`` is issued for each of them.

        Parameters
        ----------
        func : function
            Function that reduces an array of values to a single value

        Returns
        -------
        out : Table
            New table with the aggregated rows.
        """
        i0s = self.indices[:-1]
        out_cols = []
        parent_table = self.parent_table
        # The property does the same lookup on every access, so evaluate it
        # once instead of once per column.
        key_colnames = self.key_colnames

        for col in six.itervalues(parent_table.columns):
            # For key columns just pick off first in each group since they are identical
            if col.info.name in key_colnames:
                new_col = col.take(i0s)
            else:
                try:
                    new_col = col.groups.aggregate(func)
                except TypeError as err:
                    warnings.warn(six.text_type(err), AstropyUserWarning)
                    continue

            out_cols.append(new_col)

        return parent_table.__class__(out_cols, meta=parent_table.meta)

    def filter(self, func):
        """
        Filter groups in the Table based on evaluating function ``func`` on each
        group sub-table.

        The function which is passed to this method must accept two arguments:

        - ``table`` : `Table` object
        - ``key_colnames`` : tuple of column names in ``table`` used as keys for grouping

        It must then return either `True` or `False`.  As an example, the following
        will select all table groups with only positive values in the non-key columns::

          def all_positive(table, key_colnames):
              colnames = [name for name in table.colnames if name not in key_colnames]
              for colname in colnames:
                  if np.any(table[colname] < 0):
                      return False
              return True

        Parameters
        ----------
        func : function
            Filter function

        Returns
        -------
        out : Table
            New table with the rows of the groups for which ``func``
            returned `True`.
        """
        # Builtin ``bool`` is used as the dtype because the ``np.bool`` alias
        # is not available in every numpy release.
        mask = np.empty(len(self), dtype=bool)
        key_colnames = self.key_colnames
        for i, group_table in enumerate(self):
            mask[i] = func(group_table, key_colnames)

        return self[mask]

    @property
    def keys(self):
        return self._keys
|