1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
|
from typing import List, Tuple
import torch
from torch import Tensor
class ImageList:
    """Container pairing a padded image batch with each image's original size.

    Images of possibly different sizes are padded to a common size so they
    can live in a single tensor; the size each image had before padding is
    recorded alongside it so it can be recovered later.

    Args:
        tensors (tensor): Tensor containing images.
        image_sizes (list[tuple[int, int]]): List of Tuples each containing size of images.
    """

    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
        # Stored exactly as passed in: no copy, no validation, no padding here.
        self.tensors, self.image_sizes = tensors, image_sizes

    def to(self, device: torch.device) -> "ImageList":
        """Return a new ``ImageList`` whose tensor has been moved to ``device``.

        Only ``tensors`` is transferred. The returned instance shares the very
        same ``image_sizes`` object with ``self`` rather than holding a copy.
        """
        return ImageList(self.tensors.to(device), self.image_sizes)
|