1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
import numpy as np
import torch
from PIL import Image
try:
from image_dataset_viz import render_datapoint
except ImportError:
raise ModuleNotFoundError(
"Please install image-dataset-viz via pip install --upgrade git+https://github.com/vfdev-5/ImageDatasetViz.git"
)
def _getvocpallete(num_cls):
n = num_cls
pallete = [0] * (n * 3)
for j in range(0, n):
lab = j
pallete[j * 3 + 0] = 0
pallete[j * 3 + 1] = 0
pallete[j * 3 + 2] = 0
i = 0
while lab > 0:
pallete[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
pallete[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
pallete[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
i = i + 1
lab >>= 3
return pallete
# Module-level palette holding one RGB triplet for each of 256 label values;
# it is built once when the module is imported and applied by render_mask.
vocpallete = _getvocpallete(256)
def render_mask(mask):
    """Colorize a label mask with the module-level ``vocpallete`` palette.

    Args:
        mask (np.ndarray or PIL.Image.Image): map of label indices. Arrays
            whose dtype is not ``uint8`` (e.g. int64 produced by
            ``torch.argmax``) are converted to ``uint8`` first; the palette
            holds 256 entries, so valid labels always fit.

    Returns:
        PIL.Image.Image: the mask converted to "RGB" mode.
    """
    if isinstance(mask, np.ndarray):
        # PIL cannot create an image from int64 data, and a palette can only
        # be attached to 8-bit images, so normalise the dtype up front.
        if mask.dtype != np.uint8:
            mask = mask.astype(np.uint8)
        mask = Image.fromarray(mask)
    # NOTE: when a PIL image is passed in, putpalette modifies it in place.
    mask.putpalette(vocpallete)
    return mask.convert(mode="RGB")
def tensor_to_rgb(t):
    """Convert a tensor into a ``uint8`` numpy array with its first axis moved last.

    Args:
        t (torch.Tensor): 3D tensor; presumably laid out as (C, H, W), since
            the axes are permuted with ``(1, 2, 0)``.

    Returns:
        np.ndarray: array with the axes reordered to ``(1, 2, 0)`` and the
        values cast to ``uint8`` (no rescaling or clipping is applied).
    """
    channels_first = t.cpu().numpy()
    channels_last = np.transpose(channels_first, (1, 2, 0))
    return channels_last.astype(np.uint8)
def make_grid(batch_img, batch_mask, img_denormalize_fn, batch_gt_mask=None):
    """Assemble a batch of images and masks into one visualisation grid.

    Every sample occupies one column; rows are stacked top to bottom as::

        img1  | img2  | img3  | img4  | ...
        i+m1  | i+m2  | i+m3  | i+m4  | ...
        mask1 | mask2 | mask3 | mask4 | ...
        i+M1  | i+M2  | i+M3  | i+M4  | ...
        Mask1 | Mask2 | Mask3 | Mask4 | ...

    where ``i+m`` is the image blended with a mask (alpha=0.4), ``maskN`` is
    the predicted mask and ``MaskN`` is the ground-truth mask. The last two
    rows are only present when ``batch_gt_mask`` is given.

    Args:
        batch_img (torch.Tensor): batch of images of any type.
        batch_mask (torch.Tensor): batch of predicted masks.
        img_denormalize_fn (Callable): applied to each image of the batch
            before it is converted to an RGB array.
        batch_gt_mask (torch.Tensor, optional): batch of ground-truth masks.

    Returns:
        np.ndarray: ``uint8`` array of shape ``(rows * H, B * W, 3)`` with
        ``rows`` equal to 3, or 5 when ground truth is provided.
    """
    assert isinstance(batch_img, torch.Tensor) and isinstance(batch_mask, torch.Tensor)
    assert len(batch_img) == len(batch_mask)

    has_gt = batch_gt_mask is not None
    if has_gt:
        assert isinstance(batch_gt_mask, torch.Tensor)
        assert len(batch_mask) == len(batch_gt_mask)

    batch_size = batch_img.shape[0]
    height, width = batch_img.shape[2:]
    n_rows = 5 if has_gt else 3
    grid = np.zeros((height * n_rows, width * batch_size, 3), dtype="uint8")

    def paste(row, col, tile):
        # Write one (H, W, 3) tile into the cell at (row, col) of the grid.
        grid[row * height : (row + 1) * height, col * width : (col + 1) * width, :] = tile

    for col in range(batch_size):
        img_tensor, mask_tensor = batch_img[col], batch_mask[col]

        rgb = tensor_to_rgb(img_denormalize_fn(img_tensor))
        pred = render_mask(mask_tensor.cpu().numpy())

        paste(0, col, rgb)
        paste(1, col, render_datapoint(rgb, pred, blend_alpha=0.4))
        paste(2, col, pred)

        if has_gt:
            gt = render_mask(batch_gt_mask[col].cpu().numpy())
            paste(3, col, render_datapoint(rgb, gt, blend_alpha=0.4))
            paste(4, col, gt)

    return grid
def predictions_gt_images_handler(img_denormalize_fn, n_images=None, another_engine=None, prefix_tag=None):
    """Create a handler that logs a grid of predictions next to ground truth.

    The returned callable takes ``(engine, logger, event_name)``. It reads the
    current batch (keys ``"image"`` and ``"mask"``) and the first element of
    the output from ``engine.state``, builds a grid with ``make_grid`` and
    writes it through ``logger.writer.add_image``.

    Args:
        img_denormalize_fn (Callable): forwarded to ``make_grid`` to
            denormalize the images.
        n_images (int, optional): if given, only the first ``n_images``
            samples of the batch are rendered.
        another_engine (optional): if given, the epoch used as global step is
            read from ``another_engine.state`` instead of ``engine.state``.
        prefix_tag (str, optional): if given, the image tag becomes
            ``"{prefix_tag}: predictions_with_gt - epoch={epoch}"``.

    Returns:
        Callable: the handler described above.
    """

    def wrapper(engine, logger, event_name):
        batch = engine.state.batch
        output = engine.state.output
        x = batch["image"]
        y = batch["mask"]
        y_pred = output[0]

        if y.shape == y_pred.shape and y.ndim == 4:
            # Case of y of shape (B, C, H, W): collapse the channel dimension
            # into label indices. Cast to uint8 like the prediction below,
            # because argmax yields int64, which PIL cannot turn into a
            # palettised image when the mask is rendered.
            y = torch.argmax(y, dim=1).byte()

        y_pred = torch.argmax(y_pred, dim=1).byte()

        if n_images is not None:
            x = x[:n_images, ...]
            y = y[:n_images, ...]
            y_pred = y_pred[:n_images, ...]

        grid_pred_gt = make_grid(x, y_pred, img_denormalize_fn, batch_gt_mask=y)

        state = engine.state if another_engine is None else another_engine.state
        global_step = state.epoch

        # NOTE: the epoch is only embedded in the tag when prefix_tag is given.
        tag = "predictions_with_gt"
        if prefix_tag is not None:
            tag = f"{prefix_tag}: {tag} - epoch={global_step}"
        logger.writer.add_image(tag=tag, img_tensor=grid_pred_gt, global_step=global_step, dataformats="HWC")

    return wrapper
|