1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
|
from __future__ import annotations
from types import ModuleType
from typing import Any
import numpy as np
from xarray.namedarray._typing import (
Default,
_arrayapi,
_Axes,
_Axis,
_default,
_Dim,
_DType,
_ScalarType,
_ShapeType,
_SupportsImag,
_SupportsReal,
)
from xarray.namedarray.core import NamedArray
def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType:
    """
    Return the namespace module used to operate on the data wrapped by ``x``.

    If the wrapped data implements the Array API protocol (``_arrayapi``), its
    own ``__array_namespace__()`` module is returned; otherwise ``numpy`` is
    used as the fallback namespace.
    """
    data = x._data
    return data.__array_namespace__() if isinstance(data, _arrayapi) else np
# %% Data Type Functions
def astype(
    x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True
) -> NamedArray[_ShapeType, _DType]:
    """
    Copies an array to a specified data type irrespective of Type Promotion Rules.

    Parameters
    ----------
    x : NamedArray
        Array to cast.
    dtype : _DType
        Desired data type.
    copy : bool, optional
        Specifies whether to copy an array when the specified dtype matches the data
        type of the input array x.
        If True, a newly allocated array must always be returned.
        If False and the specified dtype matches the data type of the input array,
        the input array must be returned; otherwise, a newly allocated array must be
        returned. Default: True.

    Returns
    -------
    out : NamedArray
        An array having the specified data type. The returned array must have the
        same shape as x.

    Examples
    --------
    >>> narr = NamedArray(("x",), np.asarray([1.5, 2.5]))
    >>> narr
    <xarray.NamedArray (x: 2)> Size: 16B
    array([1.5, 2.5])
    >>> astype(narr, np.dtype(np.int32))
    <xarray.NamedArray (x: 2)> Size: 8B
    array([1, 2], dtype=int32)
    """
    data = x._data
    if not isinstance(data, _arrayapi):
        # Non-Array-API duck arrays are cast through their own ``.astype`` method;
        # a module-level ``np.astype`` function is only available from NumPy 2.0.
        converted = data.astype(dtype, copy=copy)  # type: ignore[attr-defined]
    else:
        converted = data.__array_namespace__().astype(data, dtype, copy=copy)
    return x._new(data=converted)
# %% Elementwise Functions
def imag(
    x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]],  # type: ignore[type-var]
    /,
) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]:
    """
    Returns the imaginary component of a complex number for each element x_i of the
    input array x.

    Parameters
    ----------
    x : NamedArray
        Input array. Should have a complex floating-point data type.

    Returns
    -------
    out : NamedArray
        An array containing the element-wise results. The returned array must have a
        floating-point data type with the same floating-point precision as x
        (e.g., if x is complex64, the returned array must have the floating-point
        data type float32).

    Examples
    --------
    >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j]))
    >>> imag(narr)
    <xarray.NamedArray (x: 2)> Size: 16B
    array([2., 4.])
    """
    # Dispatch to the data's own namespace (NumPy when it has none) and keep dims.
    return x._new(data=_get_data_namespace(x).imag(x._data))
def real(
    x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]],  # type: ignore[type-var]
    /,
) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]:
    """
    Returns the real component of a complex number for each element x_i of the
    input array x.

    Parameters
    ----------
    x : NamedArray
        Input array. Should have a complex floating-point data type.

    Returns
    -------
    out : NamedArray
        An array containing the element-wise results. The returned array must have a
        floating-point data type with the same floating-point precision as x
        (e.g., if x is complex64, the returned array must have the floating-point
        data type float32).

    Examples
    --------
    >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j]))
    >>> real(narr)
    <xarray.NamedArray (x: 2)> Size: 16B
    array([1., 2.])
    """
    namespace = _get_data_namespace(x)
    real_part = namespace.real(x._data)
    return x._new(data=real_part)
# %% Manipulation Functions
def expand_dims(
    x: NamedArray[Any, _DType],
    /,
    *,
    dim: _Dim | Default = _default,
    axis: _Axis = 0,
) -> NamedArray[Any, _DType]:
    """
    Expands the shape of an array by inserting a new dimension of size one at the
    position specified by ``axis``.

    Parameters
    ----------
    x :
        Array to expand.
    dim :
        Dimension name. New dimension will be stored in the axis position.
        When omitted, the name ``f"dim_{len(x.dims)}"`` is used.
    axis :
        (Not recommended) Axis position (zero-based). Default is 0.
        A negative value counts from the end of the *expanded* array, so
        ``axis=-1`` places the new dimension last.

    Returns
    -------
    out :
        An expanded output array having the same data type as x.

    Examples
    --------
    >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]]))
    >>> expand_dims(x)
    <xarray.NamedArray (dim_2: 1, x: 2, y: 2)> Size: 32B
    array([[[1., 2.],
            [3., 4.]]])
    >>> expand_dims(x, dim="z")
    <xarray.NamedArray (z: 1, x: 2, y: 2)> Size: 32B
    array([[[1., 2.],
            [3., 4.]]])
    >>> v = NamedArray(("x",), np.asarray([1.0, 2.0]))
    >>> expand_dims(v, dim="z", axis=-1)
    <xarray.NamedArray (x: 2, z: 1)> Size: 16B
    array([[1.],
           [2.]])
    """
    xp = _get_data_namespace(x)
    dims = x.dims
    if dim is _default:
        dim = f"dim_{len(dims)}"
    # ``list.insert`` and ``expand_dims`` disagree on negative positions:
    # ``list.insert(-1, d)`` puts ``d`` *before* the last entry, whereas
    # ``expand_dims(axis=-1)`` appends the new axis at the end. Normalise against
    # the output rank (ndim + 1) so the names stay aligned with the data axes.
    position = axis + len(dims) + 1 if axis < 0 else axis
    new_dims = list(dims)
    new_dims.insert(position, dim)
    # Pass the caller's original ``axis`` so the backend still performs its own
    # range validation for out-of-bounds values.
    out = x._new(dims=tuple(new_dims), data=xp.expand_dims(x._data, axis=axis))
    return out
def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]:
    """
    Permutes the dimensions of an array.

    Parameters
    ----------
    x :
        Array to permute.
    axes :
        Permutation of the dimensions of x.

    Returns
    -------
    out :
        An array with permuted dimensions. The returned array must have the same
        data type as x.

    Examples
    --------
    >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]]))
    >>> permute_dims(x, (1, 0))
    <xarray.NamedArray (y: 2, x: 2)> Size: 32B
    array([[1., 3.],
           [2., 4.]])
    """
    source_dims = x.dims
    # Reorder the dimension names first, using the same axis order as the data.
    permuted_dims = tuple(source_dims[position] for position in axes)
    data = x._data
    if isinstance(data, _arrayapi):
        permuted = data.__array_namespace__().permute_dims(data, axes)
    else:
        # Fallback for duck arrays exposing a NumPy-style ``transpose`` method.
        permuted = data.transpose(axes)  # type: ignore[attr-defined]
    return x._new(dims=permuted_dims, data=permuted)
|