1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
|
# -*- coding: utf-8 -*-
import itertools
from functools import reduce
import numpy as np
# Number of points taken from each end of a streamline to build its hash key
# (see get_streamline_key). Streamlines shorter than this are hashed whole.
MIN_NB_POINTS = 5
# Indices of the first and last MIN_NB_POINTS points of a streamline. Derived
# from MIN_NB_POINTS (rather than hard-coded) so that the
# `len(streamline) < MIN_NB_POINTS` guard in get_streamline_key always keeps
# these indices in range.
KEY_INDEX = np.concatenate((np.arange(MIN_NB_POINTS),
                            np.arange(-1, -MIN_NB_POINTS - 1, -1)))
def intersection(left, right):
    """Keep the entries of ``left`` whose key also appears in ``right``.

    Both arguments are streamlines dicts (see hash_streamlines). Values are
    taken from ``left`` and the result follows ``left``'s key order.
    """
    shared_keys = (key for key in left if key in right)
    return {key: left[key] for key in shared_keys}
def difference(left, right):
    """Keep the entries of ``left`` whose key is absent from ``right``.

    Both arguments are streamlines dicts (see hash_streamlines). Values are
    taken from ``left`` and the result follows ``left``'s key order.
    """
    return dict((key, left[key]) for key in left if key not in right)
def union(left, right):
    """Merge two streamlines dicts (see hash_streamlines).

    Every key of either dict is present in the result. When a key is in
    both, the value from ``left`` wins. Neither input dict is modified.
    """
    return {**right, **left}
def get_streamline_key(streamline, precision=None):
    """Produces a hashable key from a few points of a streamline.

    Only the points selected by KEY_INDEX (the first and last few points)
    are used, so the cost does not grow with the streamline length. If the
    streamline has fewer than MIN_NB_POINTS points, all of its points are
    used. Two streamlines sharing those (optionally rounded) points get the
    same key.

    Parameters
    ----------
    streamline: ndarray
        A single streamline (N,3). Anything accepted by np.asarray works.
    precision: int, optional
        The number of decimals to keep when hashing the points of the
        streamline. Allows a soft comparison of streamlines. If None, no
        rounding is performed.

    Returns
    -------
    bytes
        The raw bytes of the selected (rounded) points, usable as a dict
        key. The bytes depend on the array dtype, so only streamlines of the
        same dtype can produce equal keys.
    """
    # Use just a few data points as hash key. Using all the data of the
    # streamline would make the cost grow with the number of points.
    streamline = np.asarray(streamline)
    if len(streamline) < MIN_NB_POINTS:
        key = streamline.copy()
    else:
        # Fancy indexing already returns a new array.
        key = streamline[KEY_INDEX]

    if precision is not None:
        key = np.round(key, precision)

    # -0.0 == 0.0 numerically, but their bytes differ. Rounding small negative
    # values (e.g. -0.001 with precision=0) yields -0.0. Without this step,
    # streamlines that are equal up to `precision` could get different keys.
    # `key` is always a fresh array here, so modifying it in place is safe.
    key[key == 0] = 0
    return key.tobytes()
def hash_streamlines(streamlines, start_index=0, precision=None):
    """Builds a streamlines dict mapping hash keys to streamline indices.

    Each streamline is turned into a key with get_streamline_key. The value
    stored for it is the streamline's position in ``streamlines`` plus
    ``start_index``. When several streamlines produce the same key, the
    index of the last one is kept.

    Parameters
    ----------
    streamlines: list of ndarray
        The list of streamlines used to produce the dict.
    start_index: int, optional
        The index given to the first streamline. 0 by default.
    precision: int, optional
        The number of decimals to keep when hashing the points of the
        streamlines. Allows a soft comparison of streamlines. If None, no
        rounding is performed.

    Returns
    -------
    dict
        Keys are the bytes produced by get_streamline_key. Values are
        indices starting at start_index.
    """
    return {
        get_streamline_key(streamline, precision): position
        for position, streamline in enumerate(streamlines, start_index)
    }
def perform_streamlines_operation(operation, streamlines, precision=0):
    """Performs an operation on a list of lists of streamlines.

    Given a list of lists of streamlines, this function applies the operation
    to the first two lists of streamlines. The result is then combined, in
    order, with the third, fourth, etc. lists of streamlines (a left fold).

    A valid operation is any function that takes two streamlines dicts as
    input and produces a new streamlines dict (see hash_streamlines). Union,
    difference, and intersection are valid examples of operations.

    Parameters
    ----------
    operation: callable
        A callable that takes two streamlines dicts as inputs and produces a
        new streamlines dict.
    streamlines: list of list of streamlines
        The streamlines used in the operation. Must contain at least one
        list; an empty list makes functools.reduce raise a TypeError.
    precision: int, optional
        The number of decimals to keep when hashing the points of the
        streamlines. Allows a soft comparison of streamlines. If None, no
        rounding is performed. Note that the default here is 0 (coordinates
        rounded to integers), unlike hash_streamlines whose default is None.

    Returns
    -------
    streamlines: list
        The streamlines obtained after performing the operation on all the
        input streamlines, ordered by their global index.
    indices: np.ndarray of np.uint32
        The indices, in the concatenation of all the input lists, of the
        streamlines kept in the output, sorted in increasing order.
    """
    # Global index of the first streamline of each input list, so that every
    # streamline gets an index that is unique across all the lists.
    offsets = np.cumsum([0] + [len(s) for s in streamlines[:-1]])
    hashes = [hash_streamlines(s, offset, precision)
              for s, offset in zip(streamlines, offsets)]

    # Fold the operation over the hashes, from the first list to the last.
    to_keep = reduce(operation, hashes)

    all_streamlines = list(itertools.chain.from_iterable(streamlines))
    indices = np.array(sorted(to_keep.values()), dtype=np.uint32)
    kept_streamlines = [all_streamlines[i] for i in indices]
    return kept_streamlines, indices
|