import torch
import transforms as T
class OpticalFlowPresetEval(torch.nn.Module):
    """Evaluation-time preprocessing for optical-flow models.

    Converts a pair of PIL images to ``float32`` tensors, rescales values
    from ``[0, 1]`` to ``[-1, 1]``, and validates the model input tuple.
    The flow and valid mask are passed through the same pipeline unchanged
    by the image-only transforms.
    """

    def __init__(self):
        super().__init__()
        pipeline = [
            T.PILToTensor(),
            T.ConvertImageDtype(torch.float32),
            # map [0, 1] into [-1, 1]
            T.Normalize(mean=0.5, std=0.5),
            T.ValidateModelInput(),
        ]
        self.transforms = T.Compose(pipeline)

    def forward(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)
class OpticalFlowPresetTrain(torch.nn.Module):
    """Training-time preprocessing/augmentation for optical-flow models.

    Pipeline: PIL -> tensor, asymmetric color jitter, random resize+crop,
    optional horizontal/vertical flips, dtype conversion, normalization to
    ``[-1, 1]``, random erasing, and construction of the valid-flow mask,
    followed by input validation.
    """

    def __init__(
        self,
        *,
        # RandomResizeAndCrop params
        crop_size,
        min_scale=-0.2,
        max_scale=0.5,
        stretch_prob=0.8,
        # AsymmetricColorJitter params
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.5 / 3.14,
        # Random[H,V]Flip params
        asymmetric_jitter_prob=0.2,
        do_flip=True,
    ):
        super().__init__()

        pipeline = [T.PILToTensor()]
        pipeline.append(
            T.AsymmetricColorJitter(
                brightness=brightness,
                contrast=contrast,
                saturation=saturation,
                hue=hue,
                p=asymmetric_jitter_prob,
            )
        )
        pipeline.append(
            T.RandomResizeAndCrop(
                crop_size=crop_size,
                min_scale=min_scale,
                max_scale=max_scale,
                stretch_prob=stretch_prob,
            )
        )

        # Flips are optional so that datasets with directional priors can opt out.
        if do_flip:
            pipeline.append(T.RandomHorizontalFlip(p=0.5))
            pipeline.append(T.RandomVerticalFlip(p=0.1))

        pipeline.extend(
            [
                T.ConvertImageDtype(torch.float32),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.RandomErasing(max_erase=2),
                T.MakeValidFlowMask(),
                T.ValidateModelInput(),
            ]
        )
        self.transforms = T.Compose(pipeline)

    def forward(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)