from typing import cast

from ..types import ArrayXd
from ..util import get_torch_default_device, tensorflow2xp
from ..compat import torch, cupy, tensorflow


def cupy_tensorflow_allocator(size_in_bytes: int):
    """Function that can be passed into cupy.cuda.set_allocator, to have cupy
    allocate memory via TensorFlow. This is important when using the two libraries
    together, as otherwise OOM errors can occur when there's available memory
    sitting in the other library's pool.

    size_in_bytes (int): Number of bytes requested. Requests smaller than
        1024 bytes are padded up to 1024 bytes.
    RETURNS: A cupy.cuda.memory.MemoryPointer (offset 0) to memory that is
        backed by a TensorFlow float32 tensor.
    """
    size_in_bytes = max(1024, size_in_bytes)
    # The backing tensor holds 4-byte float32 elements. Round the element
    # count *up*, so that the tensor is never smaller than the size we report
    # to cupy when the request isn't a multiple of four bytes.
    n_floats = (size_in_bytes + 3) // 4
    tensor = tensorflow.zeros((n_floats,), dtype=tensorflow.dtypes.float32)
    # We convert to cupy via dlpack, so that we can get a memory pointer.
    cupy_array = cast(ArrayXd, tensorflow2xp(tensor))
    address = int(cupy_array.data)
    # cupy has a neat class to help us here. Otherwise it will try to free.
    # The converted array is passed as the owner, so the memory object holds
    # a reference to it.
    memory = cupy.cuda.memory.UnownedMemory(address, size_in_bytes, cupy_array)
    # Now return a new memory pointer.
    return cupy.cuda.memory.MemoryPointer(memory, 0)


def cupy_pytorch_allocator(size_in_bytes: int):
    """Function that can be passed into cupy.cuda.set_allocator, to have cupy
    allocate memory via PyTorch. This is important when using the two libraries
    together, as otherwise OOM errors can occur when there's available memory
    sitting in the other library's pool.

    size_in_bytes (int): Number of bytes requested. Requests smaller than
        1024 bytes are padded up to 1024 bytes.
    RETURNS: A cupy.cuda.memory.MemoryPointer (offset 0) to memory that is
        backed by a PyTorch float32 tensor on the default torch device.
    """
    device = get_torch_default_device()
    # Cupy was having trouble with very small allocations?
    size_in_bytes = max(1024, size_in_bytes)
    # The backing tensor holds 4-byte float32 elements. Round the element
    # count *up*, so that the tensor is never smaller than the size we report
    # to cupy when the request isn't a multiple of four bytes.
    n_floats = (size_in_bytes + 3) // 4
    # We allocate a plain tensor rather than a FloatStorage, as that turned
    # out to be faster. The dtype is pinned to float32 so that the byte
    # arithmetic above doesn't depend on torch's global default dtype.
    torch_tensor = torch.zeros(
        (n_floats,), dtype=torch.float32, requires_grad=False, device=device
    )
    # cupy has a neat class to help us here. Otherwise it will try to free.
    # The tensor is passed as the owner, so the memory object holds a
    # reference to it.
    address = torch_tensor.data_ptr()  # type: ignore
    memory = cupy.cuda.memory.UnownedMemory(address, size_in_bytes, torch_tensor)
    # Now return a new memory pointer.
    return cupy.cuda.memory.MemoryPointer(memory, 0)
