1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315
|
# Licensed under a 3-clause BSD style license - see LICENSE.rst
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import numpy as np
from ..extern.six.moves import range, zip
def _searchsorted(array, val, side='left'):
'''
Call np.searchsorted or use a custom binary
search if necessary.
'''
if hasattr(array, 'searchsorted'):
return array.searchsorted(val, side=side)
# Python binary search
begin = 0
end = len(array)
while begin < end:
mid = (begin + end) // 2
if val > array[mid]:
begin = mid + 1
elif val < array[mid]:
end = mid
elif side == 'right':
begin = mid + 1
else:
end = mid
return begin
class SortedArray(object):
    '''
    Implements a sorted array container using
    a list of numpy arrays.

    ``data`` holds the key columns in sorted order and ``row_index``
    holds, for each sorted position, the associated row number.

    Parameters
    ----------
    data : Table
        Sorted columns of the original table
    row_index : Column object
        Row numbers corresponding to data columns
    unique : bool (defaults to False)
        Whether the values of the index must be unique
    '''

    def __init__(self, data, row_index, unique=False):
        self.data = data
        self.row_index = row_index
        # An object without a ``colnames`` attribute is treated as having
        # no key columns.
        self.num_cols = len(getattr(data, 'colnames', []))
        self.unique = unique

    @property
    def cols(self):
        '''
        The values of ``data.columns`` as a list that can be indexed by
        position.
        '''
        # ``values()`` may return a view (Python 3 mappings), which does
        # not support ``cols[i]``; materialize it as a list.
        return list(self.data.columns.values())

    def add(self, key, row):
        '''
        Add a new entry to the sorted array.

        Parameters
        ----------
        key : tuple
            Column values at the given row
        row : int
            Row number

        Raises
        ------
        ValueError
            If the index is unique and ``key`` is already present
        '''
        pos = self.find_pos(key, row)  # first >= key

        if self.unique and 0 <= pos < len(self.row_index):
            existing = self.data[pos]
            if all(existing[i] == value for i, value in enumerate(key)):
                # already exists
                raise ValueError('Cannot add duplicate value "{0}" in a '
                                 'unique index'.format(key))

        self.data.insert_row(pos, key)
        insert = getattr(self.row_index, 'insert', None)
        if insert is not None:
            self.row_index = insert(pos, row)
        else:
            # ``sort`` and ``replace_rows`` rebind ``row_index`` to a plain
            # numpy array, which has no ``insert`` method.
            self.row_index = np.insert(self.row_index, pos, row)

    def _get_key_slice(self, i, begin, end):
        '''
        Retrieve the ith slice of the sorted array
        from begin to end.

        Indices below ``num_cols`` select a key column; any other index
        selects the row numbers.
        '''
        if i < self.num_cols:
            return self.cols[i][begin:end]
        else:
            return self.row_index[begin:end]

    def find_pos(self, key, data, exact=False):
        '''
        Return the position of the first entry that is greater than or
        equal to the given key, data pair. If every entry is smaller,
        the number of entries is returned.

        For a non-unique index the row number takes part in the
        comparison as an additional, last key component.

        Parameters
        ----------
        key : tuple
            Column key
        data : int
            Row number
        exact : bool
            If True, return the index of the given key in data
            or -1 if the key is not present.
        '''
        begin = 0
        end = len(self.row_index)
        num_cols = self.num_cols
        if not self.unique:
            # consider the row value as well
            key = key + (data,)
            num_cols += 1

        # search through keys in lexicographic order, narrowing the
        # window [begin, end) to the entries matching the key so far
        for i in range(num_cols):
            key_slice = self._get_key_slice(i, begin, end)
            t = _searchsorted(key_slice, key[i])
            # t is the smallest index >= key[i]
            if exact and (t == len(key_slice) or key_slice[t] != key[i]):
                # no match
                return -1
            elif t == len(key_slice) or (t == 0 and len(key_slice) > 0 and
                                         key[i] < key_slice[0]):
                # too small or too large
                return begin + t
            end = begin + _searchsorted(key_slice, key[i], side='right')
            begin += t
            if begin >= len(self.row_index):  # greater than all keys
                return begin

        return begin

    def find(self, key):
        '''
        Find all rows matching the given key.

        Parameters
        ----------
        key : tuple
            Column values

        Returns
        -------
        matching_rows : list
            List of rows matching the input key (a slice of the row
            index, or an empty list if nothing matches)
        '''
        begin = 0
        end = len(self.row_index)

        # search through keys in lexicographic order
        for i in range(self.num_cols):
            key_slice = self._get_key_slice(i, begin, end)
            t = _searchsorted(key_slice, key[i])
            # t is the smallest index >= key[i]; anything other than an
            # exact hit at t (including a key below the first element)
            # means there is no match
            if t == len(key_slice) or key_slice[t] != key[i]:
                return []
            end = begin + _searchsorted(key_slice, key[i], side='right')
            begin += t
            if begin >= len(self.row_index):  # greater than all keys
                return []

        return self.row_index[begin:end]

    def range(self, lower, upper, bounds):
        '''
        Find values in the given range.

        Parameters
        ----------
        lower : tuple
            Lower search bound
        upper : tuple
            Upper search bound
        bounds : tuple (x, y) of bools
            Indicates whether the search should be inclusive or
            exclusive with respect to the endpoints. The first
            argument x corresponds to an inclusive lower bound,
            and the second argument y to an inclusive upper bound.
        '''
        lower_pos = self.find_pos(lower, 0)
        upper_pos = self.find_pos(upper, 0)
        if lower_pos == len(self.row_index):
            return []

        lower_bound = tuple([col[lower_pos] for col in self.cols])
        if not bounds[0] and lower_bound == lower:
            lower_pos += 1  # data[lower_pos] > lower

        # data[lower_pos] >= lower
        # data[upper_pos] >= upper
        # NOTE(review): on a non-unique index with several entries equal
        # to ``upper`` (or ``lower``), the positions above refer to the
        # first such entry and are adjusted by a single step only, so the
        # remaining duplicates appear to be handled incorrectly - confirm.
        if upper_pos < len(self.row_index):
            upper_bound = tuple([col[upper_pos] for col in self.cols])
            if not bounds[1] and upper_bound == upper:
                upper_pos -= 1  # data[upper_pos] < upper
            elif upper_bound > upper:
                upper_pos -= 1  # data[upper_pos] <= upper
        return self.row_index[lower_pos:upper_pos + 1]

    def remove(self, key, data):
        '''
        Remove the given entry from the sorted array.

        Parameters
        ----------
        key : tuple
            Column values
        data : int
            Row number

        Returns
        -------
        successful : bool
            Whether the entry was successfully removed
        '''
        pos = self.find_pos(key, data, exact=True)
        if pos == -1:  # key not found
            return False

        self.data.remove_row(pos)
        # Builtin ``bool`` is used because the ``np.bool`` alias is not
        # available in recent numpy releases.
        keep_mask = np.ones(len(self.row_index), dtype=bool)
        keep_mask[pos] = False
        self.row_index = self.row_index[keep_mask]
        return True

    def shift_left(self, row):
        '''
        Decrement all row numbers greater than the input row.

        Parameters
        ----------
        row : int
            Input row number
        '''
        self.row_index[self.row_index > row] -= 1

    def shift_right(self, row):
        '''
        Increment all row numbers greater than or equal to the input row.

        Parameters
        ----------
        row : int
            Input row number
        '''
        self.row_index[self.row_index >= row] += 1

    def replace_rows(self, row_map):
        '''
        Replace all rows with the values they map to in the
        given dictionary. Any rows not present as keys in
        the dictionary will have their entries deleted.

        Parameters
        ----------
        row_map : dict
            Mapping of row numbers to new row numbers
        '''
        num_rows = len(row_map)
        # Builtin ``bool`` is used because the ``np.bool`` alias is not
        # available in recent numpy releases.
        keep_rows = np.zeros(len(self.row_index), dtype=bool)
        tagged = 0
        for i, row in enumerate(self.row_index):
            if row in row_map:
                keep_rows[i] = True
                tagged += 1
                if tagged == num_rows:
                    # as many entries kept as there are keys in the
                    # mapping; stop scanning
                    break

        self.data = self.data[keep_rows]
        self.row_index = np.array(
            [row_map[x] for x in self.row_index[keep_rows]])

    def items(self):
        '''
        Retrieve all array items as a list of pairs of the form
        [(key, [row 1, row 2, ...]), ...]
        '''
        array = []
        last_key = None
        for i, key in enumerate(zip(*self.data.columns.values())):
            row = self.row_index[i]
            if key == last_key:
                # consecutive entries with the same key share one pair
                array[-1][1].append(row)
            else:
                last_key = key
                array.append((key, [row]))
        return array

    def sort(self):
        '''
        Make row order align with key order.
        '''
        self.row_index = np.arange(len(self.row_index))

    def sorted_data(self):
        '''
        Return rows in sorted order.
        '''
        return self.row_index

    def __getitem__(self, item):
        '''
        Return a sliced reference to this sorted array.

        Parameters
        ----------
        item : slice
            Slice to use for referencing
        '''
        return SortedArray(self.data[item], self.row_index[item])

    def __repr__(self):
        # Show the key columns next to their row numbers, working on a
        # copy so that ``data`` itself does not gain a 'rows' column.
        t = self.data.copy()
        t['rows'] = self.row_index
        return str(t)

    def __str__(self):
        return repr(self)
|