1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
|
from __future__ import annotations
import torch
from typing import Optional, Union, Tuple
from torch.linalg import * # noqa: F403
# torch.linalg doesn't define __all__
# from torch.linalg import __all__ as linalg_all
from torch import linalg as torch_linalg
# Public (non-underscore) names of torch.linalg, in dir() order.
linalg_all = [name for name in dir(torch_linalg) if name[:1] != '_']
# outer is implemented in torch but aren't in the linalg namespace
from torch import outer
from ._aliases import _fix_promotion, sum
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
from ._typing import Array, DType
from ..common._typing import JustInt, JustFloat
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
# torch.cross also does not support broadcasting when it would add new
# dimensions https://github.com/pytorch/pytorch/issues/39656
def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
    """Return the cross product of 3-element vectors in ``x1`` and ``x2``.

    ``axis`` selects the dimension (of size 3) holding the vectors and
    defaults to the last one. The inputs are type-promoted together and
    broadcast explicitly before calling ``torch.linalg.cross``.

    Raises
    ------
    ValueError
        If ``axis`` is not a valid index into both inputs, or if either input
        does not have size 3 along ``axis``.
    """
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    # ``axis`` must index an existing dimension of *both* inputs, so bound it
    # by the smaller rank on both sides. With the larger rank as the upper
    # bound, a non-negative axis present only in the higher-rank input would
    # slip through and fail with an IndexError on ``shape[axis]`` below.
    # NOTE(review): when the ranks differ, a non-negative ``axis`` is checked
    # against each input's own dimensions but applied to the broadcast shape;
    # negative axes are unambiguous.
    ndim = min(x1.ndim, x2.ndim)
    if not (-ndim <= axis < ndim):
        raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
    if not (x1.shape[axis] == x2.shape[axis] == 3):
        raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
    # torch.cross does not broadcast when that would add new dimensions.
    x1, x2 = torch.broadcast_tensors(x1, x2)
    return torch_linalg.cross(x1, x2, dim=axis)
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array:
    """Contract ``x1`` and ``x2`` along ``axis``.

    Non-integral inputs are forwarded to ``torch.linalg.vecdot`` together with
    any extra ``kwargs``. Integral inputs take a manual path: ``axis`` is moved
    to the end, the operands are broadcast, and a batched matmul of row and
    column vectors is reduced to its single entry.

    Raises
    ------
    ValueError
        If ``x1`` and ``x2`` differ in size along ``axis``.
    RuntimeError
        If extra ``kwargs`` are given for integral inputs.
    """
    from ._aliases import isdtype

    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

    # torch.linalg.vecdot would silently broadcast the contracted dimension;
    # the Array API requires the sizes to match exactly.
    if x1.shape[axis] != x2.shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

    is_integral = isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral')
    if not is_integral:
        return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)

    # torch.linalg.vecdot has no integer kernels, so contract by hand.
    if kwargs:
        raise RuntimeError("vecdot kwargs not supported for integral dtypes")
    lhs, rhs = torch.broadcast_tensors(
        torch.moveaxis(x1, axis, -1),
        torch.moveaxis(x2, axis, -1),
    )
    # (..., 1, N) @ (..., N, 1) -> (..., 1, 1); drop the two unit dimensions.
    product = lhs.unsqueeze(-2) @ rhs.unsqueeze(-1)
    return product[..., 0, 0]
def solve(x1: Array, x2: Array, /, **kwargs) -> Array:
    """Solve the linear system ``x1 @ result = x2`` via ``torch.linalg.solve``.

    ``x2`` is treated as a vector only when it is 1-D; otherwise it is a
    (possibly stacked) matrix. Extra ``kwargs`` are forwarded to torch.
    """
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

    # PyTorch keeps NumPy 1's solve behaviour: it switches to a batched 1-D
    # solve whenever x2 has exactly one dimension fewer than x1 and
    # x2.shape == x1.shape[:-1] (see linalg_solve_is_vector_rhs in
    # aten/src/ATen/native/LinearAlgebraUtils.h and
    # TORCH_META_FUNC(_linalg_solve_ex) in
    # aten/src/ATen/native/BatchLinearAlgebra.cpp).
    #
    # In that situation x2 is one dimension short of x1, so giving it a
    # leading size-1 dimension forces torch to read it as a matrix.
    #
    # See https://github.com/pytorch/pytorch/issues/52915
    torch_would_treat_as_vectors = (
        x2.ndim != 1
        and x2.ndim == x1.ndim - 1
        and x2.shape == x1.shape[:-1]
    )
    if torch_would_treat_as_vectors:
        x2 = x2.unsqueeze(0)
    return torch.linalg.solve(x1, x2, **kwargs)
# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array:
    """Sum the ``offset``-th diagonal of the last two axes of ``x``.

    Any leading axes are kept, so a stack of matrices yields a stack of traces.
    """
    diagonals = torch.diagonal(x, offset=offset, dim1=-2, dim2=-1)
    # The module's wrapped ``sum`` (not torch.sum) handles the dtype upcasting.
    return sum(diagonals, axis=-1, dtype=dtype)
def vector_norm(
    x: Array,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdims: bool = False,
    # JustFloat stands for inf | -inf, which are not valid for Literal
    ord: JustInt | JustFloat = 2,
    **kwargs,
) -> Array:
    """Compute the vector norm of ``x`` over ``axis``.

    Wraps ``torch.linalg.vector_norm`` and fixes its handling of ``axis=()``.

    Parameters
    ----------
    x : Array
        Input tensor.
    axis : int, tuple of int, or None
        Axes to reduce over. ``()`` reduces over no axes, so the result has the
        same shape as ``x``.
    keepdims : bool
        Forwarded to torch as ``keepdim``. Unused when ``axis=()``.
    ord : int or float
        Order of the norm.
    **kwargs
        Forwarded to ``torch.linalg.vector_norm``. On the ``axis=()`` path
        only ``out`` is read; any other keys are ignored.

    Returns
    -------
    Array
        The norm. For ``axis=()``, this is elementwise ``abs(x)`` (or
        ``x != 0`` for ``ord=0``) in a real dtype.
    """
    # torch.linalg.vector_norm incorrectly treats axis=() the same as axis=None.
    if axis == ():
        out = kwargs.get('out')
        if out is None:
            # A norm is real-valued: complex inputs map to their real dtype.
            dtype = None
            if x.dtype == torch.complex64:
                dtype = torch.float32
            elif x.dtype == torch.complex128:
                dtype = torch.float64
            out = torch.zeros_like(x, dtype=dtype)
        # The norm of a single scalar works out to abs(x) in every case except
        # for ord=0, which is x != 0.
        # Assign through ``[...]`` rather than ``[:]``: slicing a 0-d tensor
        # raises IndexError, whereas an Ellipsis index works for any rank.
        if ord == 0:
            out[...] = (x != 0)
        else:
            out[...] = torch.abs(x)
        return out
    return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
# Everything re-exported from torch.linalg, plus the names this module adds
# or overrides.
__all__ = [
    *linalg_all,
    'outer', 'matmul', 'matrix_transpose', 'tensordot',
    'cross', 'vecdot', 'solve', 'trace', 'vector_norm',
]
# Module-level helper names deliberately left out of __all__.
# NOTE(review): presumably consumed by an export-consistency check elsewhere.
_all_ignore = ['torch_linalg', 'sum']
del linalg_all


def __dir__() -> list[str]:
    """Limit ``dir()`` of this module to the names listed in ``__all__``."""
    return __all__
|