File: transforms.py

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F


class ValidateModelInput(torch.nn.Module):
    # Pass-through transform that checks the shapes and dtypes to make sure the model gets what it expects
    def forward(self, img1, img2, flow, valid_flow_mask):

        if not all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None):
            raise TypeError("This method expects all input arguments to be of type torch.Tensor.")
        if not all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None):
            raise TypeError("This method expects the tensors img1, img2 and flow of be of dtype torch.float32.")

        if img1.shape != img2.shape:
            raise ValueError("img1 and img2 should have the same shape.")
        h, w = img1.shape[-2:]
        if flow is not None and flow.shape != (2, h, w):
            raise ValueError(f"flow.shape should be (2, {h}, {w}) instead of {flow.shape}")
        if valid_flow_mask is not None:
            if valid_flow_mask.shape != (h, w):
                raise ValueError(f"valid_flow_mask.shape should be ({h}, {w}) instead of {valid_flow_mask.shape}")
            if valid_flow_mask.dtype != torch.bool:
                raise TypeError("valid_flow_mask should be of dtype torch.bool instead of {valid_flow_mask.dtype}")

        return img1, img2, flow, valid_flow_mask
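

# A minimal sketch (not part of the original file) of the shapes and dtypes
# ValidateModelInput expects; the sizes below are illustrative only.
def _example_validate_model_input():
    h, w = 368, 496
    img1 = torch.rand(3, h, w)  # float32 in [0, 1]
    img2 = torch.rand(3, h, w)
    flow = torch.rand(2, h, w)  # (u, v) displacement for each pixel
    valid_flow_mask = torch.ones(h, w, dtype=torch.bool)
    return ValidateModelInput()(img1, img2, flow, valid_flow_mask)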


class MakeValidFlowMask(torch.nn.Module):
    # This transform generates a valid_flow_mask if it doesn't exist.
    # The flow is considered valid if ||flow||_inf < threshold
    # This is a noop for Kitti and HD1K which already come with a built-in flow mask.
    def __init__(self, threshold=1000):
        super().__init__()
        self.threshold = threshold

    def forward(self, img1, img2, flow, valid_flow_mask):
        if flow is not None and valid_flow_mask is None:
            valid_flow_mask = (flow.abs() < self.threshold).all(dim=0)
        return img1, img2, flow, valid_flow_mask
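

# A minimal sketch (not part of the original file) of the masking rule above:
# a pixel is valid iff both flow components are below the threshold in
# absolute value. The values below are illustrative only.
def _example_make_valid_flow_mask():
    flow = torch.tensor([[[0.5, 2000.0]], [[1.0, 3.0]]])  # shape (2, 1, 2)
    mask = (flow.abs() < 1000).all(dim=0)
    return mask  # tensor([[True, False]]): the second pixel exceeds the threshold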


class ConvertImageDtype(torch.nn.Module):
    def __init__(self, dtype):
        super().__init__()
        self.dtype = dtype

    def forward(self, img1, img2, flow, valid_flow_mask):
        img1 = F.convert_image_dtype(img1, dtype=self.dtype)
        img2 = F.convert_image_dtype(img2, dtype=self.dtype)

        img1 = img1.contiguous()
        img2 = img2.contiguous()

        return img1, img2, flow, valid_flow_mask
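

# Note that F.convert_image_dtype rescales the values along with the dtype:
# e.g. uint8 values in [0, 255] map to float32 values in [0.0, 1.0]. A minimal
# sketch (not part of the original file), with illustrative values:
def _example_convert_image_dtype():
    img = torch.tensor([[[0, 255]]], dtype=torch.uint8)  # shape (1, 1, 2)
    return F.convert_image_dtype(img, dtype=torch.float32)  # tensor([[[0., 1.]]])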


class Normalize(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, img1, img2, flow, valid_flow_mask):
        img1 = F.normalize(img1, mean=self.mean, std=self.std)
        img2 = F.normalize(img2, mean=self.mean, std=self.std)

        return img1, img2, flow, valid_flow_mask


class PILToTensor(torch.nn.Module):
    # Converts all inputs to tensors
    # Technically the flow and the valid mask are numpy arrays, not PIL images, but we keep that naming
    # for consistency with the rest, e.g. the segmentation reference.
    def forward(self, img1, img2, flow, valid_flow_mask):
        img1 = F.pil_to_tensor(img1)
        img2 = F.pil_to_tensor(img2)
        if flow is not None:
            flow = torch.from_numpy(flow)
        if valid_flow_mask is not None:
            valid_flow_mask = torch.from_numpy(valid_flow_mask)

        return img1, img2, flow, valid_flow_mask


class AsymmetricColorJitter(T.ColorJitter):
    # p determines the probability of doing asymmetric vs symmetric color jittering
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.2):
        super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
        self.p = p

    def forward(self, img1, img2, flow, valid_flow_mask):

        if torch.rand(1) < self.p:
            # asymmetric: different transform for img1 and img2
            img1 = super().forward(img1)
            img2 = super().forward(img2)
        else:
            # symmetric: same transform for img1 and img2
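            # (T.ColorJitter samples its jitter parameters once per forward
            # call, so batching the two images guarantees they receive the
            # exact same transformation)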
            batch = torch.stack([img1, img2])
            batch = super().forward(batch)
            img1, img2 = batch[0], batch[1]

        return img1, img2, flow, valid_flow_mask


class RandomErasing(T.RandomErasing):
    # This only erases img2, and with an extra max_erase param
    # This max_erase is needed because the RAFT training ref erases:
    # 0 times with probability .5
    # 1 time with probability .25
    # 2 times with probability .25
    # and there's no accurate way to achieve this otherwise.
    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
        super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
        self.max_erase = max_erase
        if self.max_erase <= 0:
            raise ValueError("max_raise should be greater than 0")

    def forward(self, img1, img2, flow, valid_flow_mask):
        if torch.rand(1) > self.p:
            return img1, img2, flow, valid_flow_mask

        for _ in range(torch.randint(self.max_erase, size=(1,)).item()):
            x, y, h, w, v = self.get_params(img2, scale=self.scale, ratio=self.ratio, value=[self.value])
            img2 = F.erase(img2, x, y, h, w, v, self.inplace)

        return img1, img2, flow, valid_flow_mask
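

# A minimal usage sketch (not part of the original file): only the second
# image is erased, per the class comment above. The values are illustrative.
def _example_random_erasing():
    img1 = torch.rand(3, 64, 64)
    img2 = torch.rand(3, 64, 64)
    transform = RandomErasing(p=0.5, max_erase=2)
    return transform(img1, img2, flow=None, valid_flow_mask=None)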


class RandomHorizontalFlip(T.RandomHorizontalFlip):
    def forward(self, img1, img2, flow, valid_flow_mask):
        if torch.rand(1) > self.p:
            return img1, img2, flow, valid_flow_mask

        img1 = F.hflip(img1)
        img2 = F.hflip(img2)
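        # flipping the images horizontally also reverses the x axis, so on top
        # of mirroring the flow field we must change the sign of its
        # horizontal (first) component: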
        flow = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None]
        if valid_flow_mask is not None:
            valid_flow_mask = F.hflip(valid_flow_mask)
        return img1, img2, flow, valid_flow_mask


class RandomVerticalFlip(T.RandomVerticalFlip):
    def forward(self, img1, img2, flow, valid_flow_mask):
        if torch.rand(1) > self.p:
            return img1, img2, flow, valid_flow_mask

        img1 = F.vflip(img1)
        img2 = F.vflip(img2)
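        # same reasoning as in RandomHorizontalFlip: the y axis is reversed,
        # so the vertical (second) flow component must change sign: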
        flow = F.vflip(flow) * torch.tensor([1, -1])[:, None, None]
        if valid_flow_mask is not None:
            valid_flow_mask = F.vflip(valid_flow_mask)
        return img1, img2, flow, valid_flow_mask


class RandomResizeAndCrop(torch.nn.Module):
    # This transform will resize the input with a given proba, and then crop it.
    # These are the reversed operations of the built-in RandomResizedCrop,
    # although the order of the operations doesn't matter too much: resizing a
    # crop would give the same result as cropping a resized image, up to
    # interpolation artifacts at the borders of the output.
    #
    # The reason we don't rely on RandomResizedCrop is because of a significant
    # difference in the parametrization of both transforms, in particular,
    # because of the way the random parameters are sampled in both transforms,
    # which leads to fairly different results (and a different epe). For more details see
    # https://github.com/pytorch/vision/pull/5026/files#r762932579
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, stretch_prob=0.8):
        super().__init__()
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.stretch_prob = stretch_prob
        self.resize_prob = 0.8
        self.max_stretch = 0.2

    def forward(self, img1, img2, flow, valid_flow_mask):
        # randomly sample scale
        h, w = img1.shape[-2:]
        # Note: the original code uses + 1 instead of + 8 for sparse datasets (e.g. Kitti).
        # It shouldn't matter much.
        min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)
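        # (min_scale guarantees that both sides of the resized image are at
        # least crop_size + 8, so that the random crop below always fits)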

        scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
        scale_x = scale
        scale_y = scale
        if torch.rand(1) < self.stretch_prob:
            scale_x *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()
            scale_y *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()

        scale_x = max(scale_x, min_scale)
        scale_y = max(scale_y, min_scale)

        new_h, new_w = round(h * scale_y), round(w * scale_x)

        if torch.rand(1).item() < self.resize_prob:
            # rescale the images
            img1 = F.resize(img1, size=(new_h, new_w))
            img2 = F.resize(img2, size=(new_h, new_w))
            if valid_flow_mask is None:
                flow = F.resize(flow, size=(new_h, new_w))
                flow = flow * torch.tensor([scale_x, scale_y])[:, None, None]
            else:
                flow, valid_flow_mask = self._resize_sparse_flow(
                    flow, valid_flow_mask, scale_x=scale_x, scale_y=scale_y
                )

        # Note: For sparse datasets (Kitti), the original code uses a "margin"
        # See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
        # We don't, and it's not clear whether it matters much.
        y0 = torch.randint(0, img1.shape[1] - self.crop_size[0], size=(1,)).item()
        x0 = torch.randint(0, img1.shape[2] - self.crop_size[1], size=(1,)).item()

        img1 = F.crop(img1, y0, x0, self.crop_size[0], self.crop_size[1])
        img2 = F.crop(img2, y0, x0, self.crop_size[0], self.crop_size[1])
        flow = F.crop(flow, y0, x0, self.crop_size[0], self.crop_size[1])
        if valid_flow_mask is not None:
            valid_flow_mask = F.crop(valid_flow_mask, y0, x0, self.crop_size[0], self.crop_size[1])

        return img1, img2, flow, valid_flow_mask

    def _resize_sparse_flow(self, flow, valid_flow_mask, scale_x=1.0, scale_y=1.0):
        # This resizes both the flow and the valid_flow_mask (which is assumed to be reasonably sparse).
        # Up to out-of-bounds values, the resized flow has as many non-zero values as the original one,
        # so e.g. if scale_x = scale_y = 2, the output flow is 4 times sparser: the same number of
        # valid values is spread over 4 times as many pixels.

        h, w = flow.shape[-2:]

        h_new = int(round(h * scale_y))
        w_new = int(round(w * scale_x))
        flow_new = torch.zeros(size=[2, h_new, w_new], dtype=flow.dtype)
        valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)

        jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")

        ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]

        ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
        jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)

        within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)

        ii_valid = ii_valid[within_bounds_mask]
        jj_valid = jj_valid[within_bounds_mask]
        ii_valid_new = ii_valid_new[within_bounds_mask]
        jj_valid_new = jj_valid_new[within_bounds_mask]

        valid_flow_new = flow[:, ii_valid, jj_valid]
        valid_flow_new[0] *= scale_x
        valid_flow_new[1] *= scale_y

        flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
        valid_new[ii_valid_new, jj_valid_new] = 1

        return flow_new, valid_new
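

# A toy walk-through (not part of the original file) of _resize_sparse_flow:
# with scale_x = scale_y = 2, the single valid pixel at (i, j) = (1, 1) moves
# to (2, 2), its flow values are doubled, and all other output pixels stay
# invalid. The values are illustrative only.
def _example_resize_sparse_flow():
    t = RandomResizeAndCrop(crop_size=(2, 2))
    flow = torch.zeros(2, 2, 2)
    flow[:, 1, 1] = torch.tensor([3.0, 4.0])
    valid_flow_mask = torch.zeros(2, 2, dtype=torch.bool)
    valid_flow_mask[1, 1] = True
    flow_new, valid_new = t._resize_sparse_flow(flow, valid_flow_mask, scale_x=2.0, scale_y=2.0)
    # flow_new[:, 2, 2] is tensor([6., 8.]) and valid_new.sum() is 1
    return flow_new, valid_new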


class Compose(torch.nn.Module):
    def __init__(self, transforms):
        super().__init__()
        self.transforms = transforms

    def forward(self, img1, img2, flow, valid_flow_mask):
        for t in self.transforms:
            img1, img2, flow, valid_flow_mask = t(img1, img2, flow, valid_flow_mask)
        return img1, img2, flow, valid_flow_mask
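

# A minimal sketch of how these transforms compose into a training pipeline
# (not part of the original file; the parameter values are illustrative, not
# the reference RAFT training recipe):
def _example_train_transforms():
    return Compose(
        [
            PILToTensor(),
            AsymmetricColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.2),
            RandomResizeAndCrop(crop_size=(368, 496)),
            RandomHorizontalFlip(p=0.5),
            ConvertImageDtype(torch.float32),
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            MakeValidFlowMask(),
            ValidateModelInput(),
        ]
    )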