1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
|
from typing import Callable, Iterator, Optional, TypeVar
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
from torch.utils.data.datapipes.utils.common import (
_check_unpickable_fn,
_deprecation_warning,
StreamWrapper,
validate_input_col
)
# Names exported by ``from <module> import *``.
__all__ = ["FilterIterDataPipe"]

# Covariant type variable for the elements yielded by FilterIterDataPipe.
T_co = TypeVar("T_co", covariant=True)
@functional_datapipe('filter')
class FilterIterDataPipe(IterDataPipe[T_co]):
    r"""
    Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``).

    Args:
        datapipe: Iterable DataPipe being filtered
        filter_fn: Customized function mapping an element to a boolean. For DataFrame
            elements it may instead return a column mask, in which case only the rows
            whose mask value is truthy are kept (the element is dropped if none are).
        drop_empty_batches (Deprecated): When ``True`` (the default), an element that
            passes ``filter_fn`` but is an empty ``list`` is dropped instead of being
            yielded. Passing this argument explicitly emits a deprecation warning.
        input_col: Index or indices of data which ``filter_fn`` is applied, such as:

            - ``None`` as default to apply ``filter_fn`` to the data directly.
            - Integer(s) for list/tuple elements.
            - Key(s) for dict elements.

            When several indices/keys are given, the selected values are passed to
            ``filter_fn`` as separate positional arguments.

    Elements that are dropped are passed to ``StreamWrapper.close_streams`` so that
    streams they hold are closed.

    Raises:
        ValueError: While iterating, if ``filter_fn`` returns something that is neither
            a ``bool`` nor a DataFrame column mask.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> def is_even(n):
        ...     return n % 2 == 0
        >>> dp = IterableWrapper(range(5))
        >>> filter_dp = dp.filter(filter_fn=is_even)
        >>> list(filter_dp)
        [0, 2, 4]
    """
    datapipe: IterDataPipe
    filter_fn: Callable
    drop_empty_batches: bool

    def __init__(
        self,
        datapipe: IterDataPipe,
        filter_fn: Callable,
        drop_empty_batches: Optional[bool] = None,
        input_col=None,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe

        _check_unpickable_fn(filter_fn)
        self.filter_fn = filter_fn  # type: ignore[assignment]

        # ``None`` means the caller did not pass the argument: use the default
        # silently; an explicit value triggers the deprecation warning.
        if drop_empty_batches is None:
            drop_empty_batches = True
        else:
            _deprecation_warning(
                type(self).__name__,
                deprecation_version="1.12",
                removal_version="1.14",
                old_argument_name="drop_empty_batches",
            )
        self.drop_empty_batches = drop_empty_batches

        self.input_col = input_col
        validate_input_col(filter_fn, input_col)

    def _apply_filter_fn(self, data) -> bool:
        """Call ``filter_fn`` on ``data`` or on the part(s) selected by ``input_col``."""
        if self.input_col is None:
            return self.filter_fn(data)
        elif isinstance(self.input_col, (list, tuple)):
            args = tuple(data[col] for col in self.input_col)
            return self.filter_fn(*args)
        else:
            return self.filter_fn(data[self.input_col])

    def __iter__(self) -> Iterator[T_co]:
        for data in self.datapipe:
            keep, filtered = self._filter_element(data)
            if keep and not self._is_dropped_empty_batch(filtered):
                yield filtered
            else:
                # The element will not be yielded, so close its streams here.
                StreamWrapper.close_streams(data)

    def _filter_element(self, data) -> tuple:
        """Apply ``filter_fn`` to ``data`` and return ``(keep, value)``.

        ``value`` is ``data`` itself for boolean predicates, or the concatenation of
        the selected rows when ``filter_fn`` returns a DataFrame column mask. An
        explicit ``keep`` flag (instead of a ``None`` sentinel) lets elements that
        are legitimately ``None`` pass the filter.

        Raises:
            ValueError: if ``filter_fn`` returns neither a ``bool`` nor a column mask.
        """
        condition = self._apply_filter_fn(data)

        if df_wrapper.is_column(condition):
            # We are operating on DataFrames filter here: keep the masked rows.
            rows = [
                df_wrapper.get_item(data, idx)
                for idx, mask in enumerate(df_wrapper.iterate(condition))
                if mask
            ]
            if rows:
                return True, df_wrapper.concat(rows)
            return False, None

        if not isinstance(condition, bool):
            raise ValueError(
                "Boolean output is required for `filter_fn` of FilterIterDataPipe, "
                f"got {type(condition)}"
            )
        return condition, data

    def _is_dropped_empty_batch(self, data) -> bool:
        """Return ``True`` if ``data`` is an empty ``list`` discarded by ``drop_empty_batches``."""
        return (
            isinstance(data, list)
            and len(data) == 0
            and bool(self.drop_empty_batches)
            and not df_wrapper.is_dataframe(data)
        )

    def _returnIfTrue(self, data):
        """Backward-compatible helper: return the kept value, or ``None`` if rejected.

        A kept element equal to ``None`` is indistinguishable from a rejected one
        through this helper; ``__iter__`` relies on ``_filter_element`` instead.
        """
        keep, filtered = self._filter_element(data)
        return filtered if keep else None

    def _isNonEmpty(self, data):
        """Backward-compatible helper: ``False`` for ``None`` or a dropped empty batch."""
        return data is not None and not self._is_dropped_empty_batch(data)
|