1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
|
# mypy: allow-untyped-defs
"""This module converts objects into numpy array."""
import numpy as np
import torch
def make_np(x):
    """
    Convert an object into a numpy array.

    Args:
        x: The object to convert. A numpy array is returned unchanged
            (same object); anything ``np.isscalar`` accepts (e.g. a Python
            int/float/str or a numpy scalar) is wrapped into a one-element
            array; a torch tensor is converted via ``_prepare_pytorch``.

    Returns:
        numpy.ndarray: The converted array.

    Raises:
        NotImplementedError: If ``x`` is none of the supported kinds.
    """
    # The ndarray check must come first so existing arrays pass through
    # untouched (no copy, no wrapping).
    if isinstance(x, np.ndarray):
        converted = x
    elif np.isscalar(x):
        # Scalars become 1-element arrays rather than 0-d arrays.
        converted = np.array([x])
    elif isinstance(x, torch.Tensor):
        converted = _prepare_pytorch(x)
    else:
        raise NotImplementedError(
            f"Got {type(x)}, but numpy array or torch tensor are expected."
        )
    return converted
def _prepare_pytorch(x):
if x.dtype == torch.bfloat16:
x = x.to(torch.float16)
x = x.detach().cpu().numpy()
return x
|