1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
|
"""
pint.facets.dask
~~~~~~~~~~~~~~~~
Adds to pint the capability to interoperate with Dask.
:copyright: 2022 by Pint Authors, see AUTHORS for more details.
:license: BSD, see LICENSE for more details.
"""
from __future__ import annotations
import functools
from typing import Generic
from ...compat import TypeAlias, compute, dask_array, persist, visualize
from ..plain import (
GenericPlainRegistry,
MagnitudeT,
PlainQuantity,
PlainUnit,
QuantityT,
UnitT,
)
def check_dask_array(f):
    """Restrict a method to objects whose magnitude is a Dask array.

    The returned wrapper forwards the call to ``f`` unchanged when
    ``self._magnitude`` is an instance of ``dask_array.Array``; otherwise it
    raises ``AttributeError`` naming the method and the offending type.
    """

    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        magnitude = self._magnitude
        # Guard clause: reject anything that is not backed by a Dask array.
        if not isinstance(magnitude, dask_array.Array):
            raise AttributeError(
                "Method {} only implemented for objects of {}, not {}".format(
                    f.__name__, dask_array.Array, magnitude.__class__
                )
            )
        return f(self, *args, **kwargs)

    return wrapper
class DaskQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
    """Quantity exposing the ``__dask_*__`` hooks of its wrapped magnitude.

    The hooks delegate to ``self._magnitude``; ``compute``, ``persist`` and
    ``visualize`` are guarded by ``check_dask_array`` and therefore only work
    when the magnitude is a ``dask_array.Array``.
    """

    # Dask.array.Array ducking
    def __dask_graph__(self):
        """Return the magnitude's Dask graph, or None for non-Dask magnitudes."""
        if isinstance(self._magnitude, dask_array.Array):
            return self._magnitude.__dask_graph__()

        return None

    def __dask_keys__(self):
        """Delegate to the magnitude's ``__dask_keys__``."""
        return self._magnitude.__dask_keys__()

    def __dask_tokenize__(self):
        """Return a token built from the class, the magnitude's token and the units."""
        # Imported lazily so the module can be loaded without dask installed.
        from dask.base import tokenize

        return (type(self), tokenize(self._magnitude), self.units)

    @property
    def __dask_optimize__(self):
        """Reuse the optimizer of ``dask_array.Array``."""
        return dask_array.Array.__dask_optimize__

    @property
    def __dask_scheduler__(self):
        """Reuse the default scheduler of ``dask_array.Array``."""
        return dask_array.Array.__dask_scheduler__

    def __dask_postcompute__(self):
        """Return the finalizer (and its extra arguments) applied after compute."""
        func, args = self._magnitude.__dask_postcompute__()
        # Wrap the magnitude's own finalizer so the result keeps its units.
        return self._dask_finalize, (func, args, self.units)

    def __dask_postpersist__(self):
        """Return the rebuilder (and its extra arguments) applied after persist."""
        func, args = self._magnitude.__dask_postpersist__()
        # Wrap the magnitude's own rebuilder so the result keeps its units.
        return self._dask_finalize, (func, args, self.units)

    def _dask_finalize(self, results, func, args, units):
        """Apply ``func(results, *args)`` and re-wrap the values with ``units``."""
        values = func(results, *args)
        return type(self)(values, units)

    @check_dask_array
    def compute(self, **kwargs):
        """Compute the Dask array wrapped by pint.PlainQuantity.

        Parameters
        ----------
        **kwargs : dict
            Any keyword arguments to pass to ``dask.compute``.

        Returns
        -------
        pint.PlainQuantity
            A pint.PlainQuantity wrapped numpy array.

        Raises
        ------
        AttributeError
            If the magnitude is not a Dask array.
        """
        (result,) = compute(self, **kwargs)
        return result

    @check_dask_array
    def persist(self, **kwargs):
        """Persist the Dask Array wrapped by pint.PlainQuantity.

        Parameters
        ----------
        **kwargs : dict
            Any keyword arguments to pass to ``dask.persist``.

        Returns
        -------
        pint.PlainQuantity
            A pint.PlainQuantity wrapped Dask array.

        Raises
        ------
        AttributeError
            If the magnitude is not a Dask array.
        """
        (result,) = persist(self, **kwargs)
        return result

    @check_dask_array
    def visualize(self, **kwargs):
        """Produce a visual representation of the Dask graph.

        The graphviz library is required.

        Parameters
        ----------
        **kwargs : dict
            Any keyword arguments to pass to ``dask.visualize``.

        Returns
        -------
        object
            Whatever the underlying ``visualize`` call returns (per the Dask
            documentation, an IPython display object or ``None``).

        Raises
        ------
        AttributeError
            If the magnitude is not a Dask array.
        """
        # Propagate the result so notebooks can render the produced graph;
        # previously it was discarded and the method always returned None.
        return visualize(self, **kwargs)
class DaskUnit(PlainUnit):
    """Unit class of the Dask facet; defines nothing beyond ``PlainUnit``."""
class GenericDaskRegistry(
    Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
):
    """Generic registry of the Dask facet, parameterized by quantity and unit types.

    Defines nothing beyond ``GenericPlainRegistry``.
    """
class DaskRegistry(GenericDaskRegistry[DaskQuantity, DaskUnit]):
    """Concrete Dask registry bound to ``DaskQuantity`` and ``DaskUnit``."""

    # Classes exposed as this registry's Quantity and Unit types.
    Quantity: TypeAlias = DaskQuantity
    Unit: TypeAlias = DaskUnit
|