import os
import re
import sys
import tempfile
from io import BytesIO

import numpy as np
import pytest
import torch
import torchvision.transforms.functional as F
import torchvision.utils as utils
from common_utils import assert_equal, cpu_and_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageColor
from torchvision.transforms.v2.functional import to_dtype


# Pillow version as a tuple of ints for comparisons like `PILLOW_VERSION >= (10, 1)`.
# Only the leading numeric release segment is parsed, so pre-release / dev builds
# such as "11.0.0.dev0" or "10.1.0rc1" do not make int() raise at import time.
PILLOW_VERSION = tuple(int(part) for part in re.match(r"\d+(?:\.\d+)*", PILLOW_VERSION).group().split("."))

# Shared box fixture: 4 boxes with 4 coordinates each (the second one is all zeros).
# NOTE(review): presumably (xmin, ymin, xmax, ymax) order, as draw_bounding_boxes expects.
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

# Shared keypoint fixture of shape (2, 3, 2): 2 instances, 3 points each, 2 coordinates per point.
keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float)


def test_make_grid_not_inplace():
    """make_grid must leave its input tensor untouched for every normalization mode."""
    source = torch.rand(5, 3, 10, 10)
    snapshot = source.clone()

    option_sets = (
        {"normalize": False},
        {"normalize": True, "scale_each": False},
        {"normalize": True, "scale_each": True},
    )
    for options in option_sets:
        utils.make_grid(source, **options)
        assert_equal(source, snapshot, msg="make_grid modified tensor in-place")


def test_normalize_in_make_grid():
    """With normalize=True, the grid's max and min round to 1.0 and 0.0 (one decimal)."""
    values = torch.rand(5, 3, 10, 10) * 255
    grid = utils.make_grid(values, normalize=True)

    scale = 10**1  # round to one decimal place before comparing

    def _round_one_decimal(x):
        return torch.round(x * scale) / scale

    assert_equal(torch.tensor(1.0), _round_one_decimal(torch.max(grid)), msg="Normalized max is not equal to 1")
    assert_equal(torch.tensor(0.0), _round_one_decimal(torch.min(grid)), msg="Normalized min is not equal to 0")


@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image():
    """save_image writes a batch of images to the given path."""
    with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
        batch = torch.rand(2, 3, 64, 64)
        utils.save_image(batch, tmp.name)
        saved = os.path.exists(tmp.name)
        assert saved, "The image is not present after save"


@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel():
    """save_image handles a 1x1 image and writes it to the given path."""
    with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
        pixel = torch.rand(1, 3, 1, 1)
        utils.save_image(pixel, tmp.name)
        saved = os.path.exists(tmp.name)
        assert saved, "The pixel image is not present after save"


@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_file_object():
    """Saving to an in-memory file object yields the same pixels as saving to a path."""
    with tempfile.NamedTemporaryFile(suffix=".png") as f:
        t = torch.rand(2, 3, 64, 64)
        utils.save_image(t, f.name)
        fp = BytesIO()
        utils.save_image(t, fp, format="png")
        # Open both images in context managers so Pillow releases their file
        # handles deterministically, even if the comparison below fails.
        with Image.open(f.name) as img_orig, Image.open(fp) as img_bytes:
            assert_equal(F.pil_to_tensor(img_orig), F.pil_to_tensor(img_bytes), msg="Image not stored in file object")


@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel_file_object():
    """A 1x1 image saved to a file object matches the same image saved to a path."""
    with tempfile.NamedTemporaryFile(suffix=".png") as f:
        t = torch.rand(1, 3, 1, 1)
        utils.save_image(t, f.name)
        fp = BytesIO()
        utils.save_image(t, fp, format="png")
        # Open both images in context managers so Pillow releases their file
        # handles deterministically, even if the comparison below fails.
        with Image.open(f.name) as img_orig, Image.open(fp) as img_bytes:
            assert_equal(F.pil_to_tensor(img_orig), F.pil_to_tensor(img_bytes), msg="Image not stored in file object")


def test_draw_boxes():
    """Filled, labelled boxes match the stored reference and leave the inputs untouched."""
    canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    canvas_before = canvas.clone()
    boxes_before = boxes.clone()

    drawn = utils.draw_bounding_boxes(
        canvas,
        boxes,
        labels=["a", "b", "c", "d"],
        colors=["green", "#FF00FF", (0, 255, 0), "red"],
        fill=True,
    )

    ref_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png")
    # Bootstrap the reference image on the first run if it is missing.
    if not os.path.exists(ref_path):
        Image.fromarray(drawn.permute(1, 2, 0).contiguous().numpy()).save(ref_path)

    # The stored reference image is only valid for Pillow >= 10.1.
    if PILLOW_VERSION >= (10, 1):
        reference = torch.as_tensor(np.array(Image.open(ref_path))).permute(2, 0, 1)
        assert_equal(drawn, reference)

    # Neither input may be modified in place.
    assert_equal(boxes, boxes_before)
    assert_equal(canvas, canvas_before)


@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
def test_draw_boxes_with_coloured_labels():
    """Labels drawn in their own colors (separate from the box colors) match the reference."""
    canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    box_colors = ["green", "#FF00FF", (0, 255, 0), "red"]
    text_colors = ["green", "red", (0, 255, 0), "#FF00FF"]

    drawn = utils.draw_bounding_boxes(
        canvas,
        boxes,
        labels=["a", "b", "c", "d"],
        colors=box_colors,
        fill=True,
        label_colors=text_colors,
    )

    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    reference_file = os.path.join(assets_dir, "draw_boxes_different_label_colors.png")
    reference = torch.as_tensor(np.array(Image.open(reference_file))).permute(2, 0, 1)
    assert_equal(drawn, reference)


@pytest.mark.parametrize("fill", [True, False])
def test_draw_boxes_dtypes(fill):
    """uint8 and float inputs produce equivalent drawings, each in its own dtype."""
    src_u8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    res_u8 = utils.draw_bounding_boxes(src_u8, boxes, fill=fill)
    assert res_u8 is not src_u8
    assert res_u8.dtype == torch.uint8

    src_f = to_dtype(src_u8, torch.float, scale=True)
    res_f = utils.draw_bounding_boxes(src_f, boxes, fill=fill)
    assert res_f is not src_f
    assert res_f.is_floating_point()

    # Converting the float result back to uint8 must agree with the uint8 path to within 1.
    res_f_as_u8 = to_dtype(res_f, torch.uint8, scale=True)
    torch.testing.assert_close(res_u8, res_f_as_u8, rtol=0, atol=1)


@pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)])
def test_draw_boxes_colors(colors):
    """Every supported `colors` form is accepted; an empty color list is rejected."""
    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)

    # Must not raise for any of the parametrized color specifications.
    utils.draw_bounding_boxes(canvas, boxes, fill=False, width=7, colors=colors)

    no_colors = []
    with pytest.raises(ValueError, match="Number of colors must be equal or larger than the number of objects"):
        utils.draw_bounding_boxes(image=canvas, boxes=boxes, colors=no_colors)


def test_draw_boxes_vanilla():
    """Unfilled white boxes of width 7 on a black image match the stored reference."""
    black = torch.zeros((3, 100, 100), dtype=torch.uint8)
    black_before = black.clone()
    boxes_before = boxes.clone()

    drawn = utils.draw_bounding_boxes(black, boxes, fill=False, width=7, colors="white")

    ref_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
    # Bootstrap the reference image on the first run if it is missing.
    if not os.path.exists(ref_path):
        Image.fromarray(drawn.permute(1, 2, 0).contiguous().numpy()).save(ref_path)

    reference = torch.as_tensor(np.array(Image.open(ref_path))).permute(2, 0, 1)
    assert_equal(drawn, reference)

    # Neither input may be modified in place.
    assert_equal(boxes, boxes_before)
    assert_equal(black, black_before)


def test_draw_boxes_grayscale():
    """Drawing on a single-channel image returns a 3-channel image."""
    gray = torch.full((1, 4, 4), fill_value=255, dtype=torch.uint8)
    single_box = torch.tensor([[0, 0, 3, 3]], dtype=torch.int64)
    drawn = utils.draw_bounding_boxes(image=gray, boxes=single_box, colors=["#1BBC9B"])
    n_channels = drawn.size(0)
    assert n_channels == 3


def test_draw_invalid_boxes():
    """draw_bounding_boxes rejects bad image types/shapes, mismatched labels/colors and malformed boxes."""
    img_tuple = ((1, 1, 1), (1, 2, 3))
    img_batched = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
    img_two_channels = img_batched[0][:2]
    img_ok = torch.zeros((3, 10, 10), dtype=torch.uint8)
    four_boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
    # In each box the last two coordinates are smaller than the first two.
    malformed_boxes = torch.tensor([[10, 10, 4, 5], [30, 20, 10, 5]], dtype=torch.float)
    two_labels = ["one", "two"]
    two_colors = ["pink", "blue"]

    # Each entry: (expected exception, message pattern, positional args, keyword args).
    expectations = [
        (TypeError, "Tensor expected", (img_tuple, four_boxes), {}),
        (ValueError, "Pass individual images, not batches", (img_batched, four_boxes), {}),
        (ValueError, "Only grayscale and RGB images are supported", (img_two_channels, four_boxes), {}),
        (ValueError, "Number of boxes", (img_ok, four_boxes, two_labels), {}),
        (ValueError, "Number of colors", (img_ok, four_boxes), {"colors": two_colors}),
        (ValueError, "Boxes need to be in", (img_ok, malformed_boxes), {}),
    ]
    for exc_type, pattern, args, kwargs in expectations:
        with pytest.raises(exc_type, match=pattern):
            utils.draw_bounding_boxes(*args, **kwargs)


def test_draw_boxes_warning():
    """Passing font_size without a font triggers a UserWarning."""
    canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    expected_message = re.escape("Argument 'font_size' will be ignored since 'font' is not set.")
    with pytest.warns(UserWarning, match=expected_message):
        utils.draw_bounding_boxes(canvas, boxes, font_size=11)


def test_draw_no_boxes():
    """An empty boxes tensor emits a warning and returns an image equal to the input."""
    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)
    empty_boxes = torch.zeros((0, 4), dtype=torch.float)
    expected_message = re.escape("boxes doesn't contain any box. No box was drawn")
    with pytest.warns(UserWarning, match=expected_message):
        drawn = utils.draw_bounding_boxes(canvas, empty_boxes)
        # Nothing should have been drawn.
        assert drawn.eq(canvas).all()


@pytest.mark.parametrize(
    "colors",
    [
        None,
        "blue",
        "#FF00FF",
        (1, 34, 122),
        ["red", "blue"],
        ["#FF00FF", (1, 34, 122)],
    ],
)
@pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1))
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks(colors, alpha, device):
    """Each mask is blended with its own color; pixels outside every mask are unchanged."""
    num_masks, h, w = 2, 100, 100
    dtype = torch.uint8
    img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device)
    masks = torch.zeros((num_masks, h, w), dtype=torch.bool, device=device)
    masks[0, 10:20, 10:20] = True
    masks[1, 15:25, 15:25] = True

    overlap = masks[0] & masks[1]
    covered = masks.any(dim=0)

    out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
    assert out.dtype == dtype
    assert out is not img

    # Pixels not covered by any mask must be copied through untouched.
    assert_equal(img[:, ~covered], out[:, ~covered])

    if colors is None:
        palette = utils._generate_color_palette(num_masks)
    elif isinstance(colors, (str, tuple)):
        # A single color spec: only the first mask is checked against it below.
        palette = [colors]
    else:
        palette = colors

    # Outside the overlap, each mask's pixels are a blend of the image and that mask's color.
    for mask, spec in zip(masks, palette):
        rgb = ImageColor.getrgb(spec) if isinstance(spec, str) else spec
        column = torch.tensor(rgb, dtype=dtype, device=device)[:, None]
        own = mask & ~overlap

        if alpha == 1:
            assert (out[:, own] == column).all()
        elif alpha == 0:
            assert (out[:, own] == img[:, own]).all()

        blended = (img[:, own] * (1 - alpha) + column * alpha).to(dtype)
        torch.testing.assert_close(out[:, own], blended, rtol=0.0, atol=1.0)

    expected_overlap = (img[:, overlap] * (1 - alpha)).to(dtype)
    torch.testing.assert_close(out[:, overlap], expected_overlap, rtol=0.0, atol=1.0)


def test_draw_segmentation_masks_dtypes():
    """uint8 and float images give equivalent mask overlays, each in its own dtype."""
    num_masks, h, w = 2, 100, 100
    random_masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)

    src_u8 = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
    res_u8 = utils.draw_segmentation_masks(src_u8, random_masks)
    assert src_u8 is not res_u8
    assert res_u8.dtype == torch.uint8

    src_f = to_dtype(src_u8, torch.float, scale=True)
    res_f = utils.draw_segmentation_masks(src_f, random_masks)
    assert src_f is not res_f
    assert res_f.is_floating_point()

    # The float result converted back to uint8 must agree with the uint8 path to within 1.
    torch.testing.assert_close(res_u8, to_dtype(res_f, torch.uint8, scale=True), rtol=0, atol=1)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks_errors(device):
    """draw_segmentation_masks rejects malformed images, masks and color specifications."""
    h, w = 10, 10

    masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool, device=device)
    img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8, device=device)

    # NOTE(review): the malformed tensors below are created on CPU even when `device`
    # is CUDA; confirm that mixing devices here is intended.
    # Each entry: (expected exception, message pattern, keyword arguments).
    invalid_calls = [
        (TypeError, "The image must be a tensor", dict(image="Not A Tensor Image", masks=masks)),
        (
            ValueError,
            "The image dtype must be",
            dict(image=torch.randint(0, 256, size=(3, h, w), dtype=torch.int64), masks=masks),
        ),
        (
            ValueError,
            "Pass individual images, not batches",
            dict(image=torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8), masks=masks),
        ),
        (
            ValueError,
            "Pass an RGB image",
            dict(image=torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8), masks=masks),
        ),
        (
            ValueError,
            "The masks must be of dtype bool",
            dict(image=img, masks=torch.randint(0, 2, size=(h, w), dtype=torch.float)),
        ),
        (
            ValueError,
            "masks must be of shape",
            dict(image=img, masks=torch.randint(0, 2, size=(3, 2, h, w), dtype=torch.bool)),
        ),
        (
            ValueError,
            "must have the same height and width",
            dict(image=img, masks=torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool)),
        ),
        (
            ValueError,
            "Number of colors must be equal or larger than the number of objects",
            dict(image=img, masks=masks, colors=[]),
        ),
        (
            ValueError,
            "`colors` must be a tuple or a string, or a list thereof",
            # A numpy array instead of a list.
            dict(image=img, masks=masks, colors=np.array(["red", "blue"])),
        ),
        (
            ValueError,
            "If passed as tuple, colors should be an RGB triplet",
            # A tuple of color names instead of a list.
            dict(image=img, masks=masks, colors=("red", "blue")),
        ),
    ]
    for exc_type, pattern, kwargs in invalid_calls:
        with pytest.raises(exc_type, match=pattern):
            utils.draw_segmentation_masks(**kwargs)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_no_segmention_mask(device):
    """An empty masks tensor emits a warning and returns an image equal to the input."""
    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8, device=device)
    no_masks = torch.zeros((0, 100, 100), dtype=torch.bool, device=device)
    expected_message = re.escape("masks doesn't contain any mask. No mask was drawn")
    with pytest.warns(UserWarning, match=expected_message):
        drawn = utils.draw_segmentation_masks(canvas, no_masks)
        # Nothing should have been drawn.
        assert drawn.eq(canvas).all()


def test_draw_keypoints_vanilla():
    """Red keypoints with a 0->1 edge match the stored reference and leave the inputs untouched."""
    # `keypoints` is the module-level fixture.
    keypoints_before = keypoints.clone()

    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)
    canvas_before = canvas.clone()

    edges = [(0, 1)]
    drawn = utils.draw_keypoints(canvas, keypoints, colors="red", connectivity=edges)

    ref_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
    # Bootstrap the reference image on the first run if it is missing.
    if not os.path.exists(ref_path):
        Image.fromarray(drawn.permute(1, 2, 0).contiguous().numpy()).save(ref_path)

    reference = torch.as_tensor(np.array(Image.open(ref_path))).permute(2, 0, 1)
    assert_equal(drawn, reference)

    # Neither the keypoints nor the image may be modified in place.
    assert_equal(keypoints, keypoints_before)
    assert_equal(canvas, canvas_before)


def test_draw_keypoins_K_equals_one():
    """draw_keypoints must not fail when each instance has a single keypoint (K == 1)."""
    # Non-regression test for https://github.com/pytorch/vision/pull/8439
    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)
    single_point = torch.tensor([[[10, 10]]], dtype=torch.float)
    utils.draw_keypoints(canvas, single_point)


@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
def test_draw_keypoints_colored(colors):
    """Keypoints can be drawn with named, hex or RGB-tuple colors without mutating the inputs."""
    # `keypoints` is the module-level fixture.
    keypoints_before = keypoints.clone()

    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)
    canvas_before = canvas.clone()

    drawn = utils.draw_keypoints(canvas, keypoints, colors=colors, connectivity=[(0, 1)])

    assert drawn.size(0) == 3
    assert_equal(keypoints, keypoints_before)
    assert_equal(canvas, canvas_before)


@pytest.mark.parametrize("connectivity", [[(0, 1)], [(0, 1), (1, 2)]])
@pytest.mark.parametrize(
    "vis",
    [
        torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.bool),
        torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.float).unsqueeze_(-1),
    ],
)
def test_draw_keypoints_visibility(connectivity, vis):
    """Hidden keypoints (visibility 0) and their edges are not drawn; inputs stay unmodified."""
    # `keypoints` is the module-level fixture.
    keypoints_before = keypoints.clone()

    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)
    canvas_before = canvas.clone()

    vis_before = None if vis is None else vis.clone()

    drawn = utils.draw_keypoints(
        image=canvas,
        keypoints=keypoints,
        connectivity=connectivity,
        colors="red",
        visibility=vis,
    )
    assert drawn.size(0) == 3
    assert_equal(keypoints, keypoints_before)
    assert_equal(canvas, canvas_before)

    # Both connectivity variants must match the same reference: keypoints 0 and 1 are
    # connected for both instances, and keypoint 2 (hidden) is not shown.
    ref_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoints_visibility.png"
    )
    # Bootstrap the reference image on the first run if it is missing.
    if not os.path.exists(ref_path):
        Image.fromarray(drawn.permute(1, 2, 0).contiguous().numpy()).save(ref_path)

    reference = torch.as_tensor(np.array(Image.open(ref_path))).permute(2, 0, 1)
    assert_equal(drawn, reference)

    # The visibility tensor must be neither modified nor converted in place.
    if vis_before is not None:
        assert_equal(vis, vis_before)
        assert vis.dtype == vis_before.dtype
    else:
        assert vis is None


def test_draw_keypoints_visibility_default():
    """visibility=None draws every keypoint, matching the vanilla reference image."""
    # `keypoints` is the module-level fixture.
    keypoints_before = keypoints.clone()

    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)
    canvas_before = canvas.clone()

    drawn = utils.draw_keypoints(
        image=canvas,
        keypoints=keypoints,
        connectivity=[(0, 1)],
        colors="red",
        visibility=None,
    )
    assert drawn.size(0) == 3
    assert_equal(keypoints, keypoints_before)
    assert_equal(canvas, canvas_before)

    # The vanilla reference connects keypoints 0 -> 1 for both instances.
    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    reference_img = Image.open(os.path.join(assets_dir, "draw_keypoint_vanilla.png"))
    reference = torch.as_tensor(np.array(reference_img)).permute(2, 0, 1)
    assert_equal(drawn, reference)


def test_draw_keypoints_dtypes():
    """uint8 and float images give equivalent keypoint drawings, each in its own dtype."""
    src_u8 = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8)
    src_f = to_dtype(src_u8, torch.float, scale=True)

    res_u8, res_f = (utils.draw_keypoints(src, keypoints) for src in (src_u8, src_f))

    assert res_u8.dtype == torch.uint8
    assert res_u8 is not src_u8

    assert res_f.is_floating_point()
    assert res_f is not src_f

    # The float result converted back to uint8 must agree with the uint8 path to within 1.
    torch.testing.assert_close(res_u8, to_dtype(res_f, torch.uint8, scale=True), rtol=0, atol=1)


def test_draw_keypoints_errors():
    """draw_keypoints rejects malformed images, keypoints and visibility tensors."""
    h, w = 10, 10
    img = torch.zeros((3, 100, 100), dtype=torch.uint8)
    visibility_shape_msg = re.escape("visibility must be of shape (num_instances, K)")
    same_dim_msg = "keypoints and visibility must have the same dimensionality"

    # Each entry: (expected exception, message pattern, keyword arguments).
    invalid_calls = [
        (TypeError, "The image must be a tensor", dict(image="Not A Tensor Image", keypoints=keypoints)),
        (
            ValueError,
            "The image dtype must be",
            dict(image=torch.zeros((3, h, w), dtype=torch.int64), keypoints=keypoints),
        ),
        (
            ValueError,
            "Pass individual images, not batches",
            dict(image=torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8), keypoints=keypoints),
        ),
        (
            ValueError,
            "Pass an RGB image",
            dict(image=torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8), keypoints=keypoints),
        ),
        (
            ValueError,
            "keypoints must be of shape",
            dict(image=img, keypoints=torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)),
        ),
        (
            ValueError,
            visibility_shape_msg,
            dict(image=img, keypoints=keypoints, visibility=torch.tensor([True, True, True], dtype=torch.bool)),
        ),
        (
            ValueError,
            visibility_shape_msg,
            dict(image=img, keypoints=keypoints, visibility=torch.ones((2, 3, 4), dtype=torch.bool)),
        ),
        # Wrong number of instances.
        (ValueError, same_dim_msg, dict(image=img, keypoints=keypoints, visibility=torch.ones((3, 3), dtype=torch.bool))),
        # Wrong number of keypoints per instance.
        (ValueError, same_dim_msg, dict(image=img, keypoints=keypoints, visibility=torch.ones((2, 4), dtype=torch.bool))),
    ]
    for exc_type, pattern, kwargs in invalid_calls:
        with pytest.raises(exc_type, match=pattern):
            utils.draw_keypoints(**kwargs)


@pytest.mark.parametrize("batch", (True, False))
def test_flow_to_image(batch):
    """flow_to_image matches a stored reference, for both a single flow and a batch of two.

    The input is a synthetic flow field: channel 0 holds each pixel's column index and
    channel 1 its row index, each shifted by half of the matching image dimension.
    """
    h, w = 100, 100
    rows, cols = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    flow = torch.stack([cols, rows], dim=0).float()
    # Channel 0 varies along the width and channel 1 along the height, so centre each
    # by its own dimension (values are unchanged here because h == w).
    flow[0] -= w / 2
    flow[1] -= h / 2

    if batch:
        flow = torch.stack([flow, flow])

    img = utils.flow_to_image(flow)
    # Bind the expected shape first: `assert a == b if c else d` parses as
    # `assert (a == b) if c else d`, which would silently pass for batch=False.
    expected_shape = (2, 3, h, w) if batch else (3, h, w)
    assert img.shape == expected_shape

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
    expected_img = torch.load(path, map_location="cpu", weights_only=True)

    if batch:
        expected_img = torch.stack([expected_img, expected_img])

    assert_equal(expected_img, img)


@pytest.mark.parametrize(
    "input_flow, match",
    (
        (torch.zeros((3, 10, 10), dtype=torch.float), "Input flow should have shape"),
        (torch.zeros((5, 3, 10, 10), dtype=torch.float), "Input flow should have shape"),
        (torch.zeros((2, 10), dtype=torch.float), "Input flow should have shape"),
        (torch.zeros((5, 2, 10), dtype=torch.float), "Input flow should have shape"),
        (torch.zeros((2, 10, 30), dtype=torch.int), "Flow should be of dtype torch.float"),
    ),
)
def test_flow_to_image_errors(input_flow, match):
    """Flows with a wrong shape or a non-float dtype are rejected with ValueError."""
    expected = pytest.raises(ValueError, match=match)
    with expected:
        utils.flow_to_image(flow=input_flow)


if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly fails (non-zero
    # exit code) when any test fails, instead of always exiting with 0.
    sys.exit(pytest.main([__file__]))
