"""
===============================================================
Transforms on KeyPoints
===============================================================

This example illustrates how to define and use keypoints.
For this tutorial, we use this picture of a ceramic figure from the pre-Columbian period.
The image is listed as "public domain" (https://www.metmuseum.org/art/collection/search/502727).

.. note::
    Support for keypoints was released in TorchVision 0.23 and is
    currently a BETA feature. We don't expect the API to change, but there may
    be some rare edge cases. If you find any issues, please report them on
    our bug tracker: https://github.com/pytorch/vision/issues?q=is:open+is:issue

First, a bit of setup code:
"""

# %%
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt


import torch
from torchvision.tv_tensors import KeyPoints
from torchvision.transforms import v2
from helpers import plot

plt.rcParams["figure.figsize"] = [10, 5]
plt.rcParams["savefig.bbox"] = "tight"

# If you change the seed, make sure that the transformed outputs
# still make sense.
torch.manual_seed(0)

# If you're running this on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
orig_img = Image.open(Path('../assets') / 'pottery.jpg')

# %%
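# Creating KeyPoints
# ------------------
# Keypoints are created by instantiating the
# :class:`~torchvision.tv_tensors.KeyPoints` class. Each point is an ``[x, y]``
# pair, and ``canvas_size`` is ``(height, width)``, which is why the PIL
# ``(width, height)`` size is swapped below.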


orig_pts = KeyPoints(
    [
        [
            [445, 700],  # nose
            [320, 660],
            [370, 660],
            [420, 660],  # left eye
            [300, 620],
            [420, 620],  # left eyebrow
            [475, 665],
            [515, 665],
            [555, 655],  # right eye
            [460, 625],
            [560, 600],  # right eyebrow
            [370, 780],
            [450, 760],
            [540, 780],
            [450, 820],  # mouth
        ],
    ],
    canvas_size=(orig_img.size[1], orig_img.size[0]),
)

plot([(orig_img, orig_pts)])
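
# %%
# A minimal sketch of the two conventions used above, written in plain Python
# rather than with TorchVision: each keypoint is an ``(x, y)`` pair, while
# ``canvas_size`` is ``(height, width)``. PIL reports ``Image.size`` as
# ``(width, height)``, so the two values must be swapped. The helpers
# ``pil_size_to_canvas_size`` and ``keypoints_inside_canvas`` are hypothetical,
# and treating the canvas bounds as half-open is an assumption.


def pil_size_to_canvas_size(pil_size):
    """Convert a PIL ``(width, height)`` size to a ``(height, width)`` canvas size."""
    width, height = pil_size
    return (height, width)


def keypoints_inside_canvas(points, canvas_size):
    """Return True if every ``(x, y)`` point lies inside a ``(height, width)`` canvas."""
    height, width = canvas_size
    return all(0 <= x < width and 0 <= y < height for x, y in points)


# Example with a hypothetical image that is 1000 pixels wide and 800 pixels high
example_canvas = pil_size_to_canvas_size((1000, 800))
print(example_canvas)  # (800, 1000)
print(keypoints_inside_canvas([[445, 700], [560, 600]], example_canvas))  # True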

# %%
# Transforms illustrations
# ------------------------
#
# Using :class:`~torchvision.transforms.v2.RandomRotation` with ``expand=True``,
# which enlarges the output so that the whole rotated image fits:
rotater = v2.RandomRotation(degrees=(0, 180), expand=True)
rotated_imgs = [rotater((orig_img, orig_pts)) for _ in range(4)]
plot([(orig_img, orig_pts)] + rotated_imgs)
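
# %%
# Conceptually, rotating keypoints is plain 2D geometry. The sketch below is
# an illustration, not TorchVision's implementation: it rotates ``(x, y)``
# points counter-clockwise (as seen on screen, with y pointing down) around
# the image center and, in the spirit of ``expand=True``, shifts them into an
# enlarged canvas sized to the rotated image's bounding box. The helper
# ``rotate_keypoints_expand`` is hypothetical, and TorchVision may round the
# expanded canvas size differently.
import math


def rotate_keypoints_expand(points, canvas_size, angle_deg):
    height, width = canvas_size
    theta = math.radians(angle_deg)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    cx, cy = width / 2, height / 2
    # Bounding box of the rotated image gives the expanded canvas size
    new_width = abs(width * cos_t) + abs(height * sin_t)
    new_height = abs(width * sin_t) + abs(height * cos_t)
    # Offset that keeps the rotated content centered in the larger canvas
    dx, dy = (new_width - width) / 2, (new_height - height) / 2
    rotated = []
    for x, y in points:
        # Counter-clockwise rotation in image coordinates (y axis points down)
        rx = cx + (x - cx) * cos_t + (y - cy) * sin_t
        ry = cy - (x - cx) * sin_t + (y - cy) * cos_t
        rotated.append((rx + dx, ry + dy))
    return rotated, (new_height, new_width)


# Right-middle and top-left points of a 200x100 (width x height) canvas, rotated by 90 degrees
example_pts, example_canvas = rotate_keypoints_expand([(200, 50), (0, 0)], (100, 200), 90)
print([(round(x), round(y)) for x, y in example_pts])  # [(50, 0), (0, 200)]
print(tuple(round(v) for v in example_canvas))  # (200, 100)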

# %%
# Using :class:`~torchvision.transforms.v2.Pad`:
padded_imgs_and_points = [
    v2.Pad(padding=padding)(orig_img, orig_pts)
    for padding in (30, 50, 100, 200)
]
plot([(orig_img, orig_pts)] + padded_imgs_and_points)
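
# %%
# Padding only translates keypoints: every point moves right by the left
# padding and down by the top padding, and the canvas grows accordingly.
# The sketch below follows the ``padding`` conventions documented for
# :class:`~torchvision.transforms.v2.Pad` (an int pads all sides; two values
# mean left/right and top/bottom; four values mean left, top, right, bottom).
# The helper ``pad_keypoints`` is hypothetical and not part of TorchVision.


def pad_keypoints(points, canvas_size, padding):
    if isinstance(padding, int):
        left = top = right = bottom = padding
    elif len(padding) == 2:
        left = right = padding[0]
        top = bottom = padding[1]
    elif len(padding) == 4:
        left, top, right, bottom = padding
    else:
        raise ValueError("padding must be an int or a sequence of length 2 or 4")
    height, width = canvas_size
    # Only the left and top padding shift the points
    shifted = [(x + left, y + top) for x, y in points]
    return shifted, (height + top + bottom, width + left + right)


# Hypothetical (height, width) = (1000, 800) canvas padded by 30 on every side
print(pad_keypoints([(445, 700)], (1000, 800), 30))
# ([(475, 730)], (1060, 860))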

# %%
# Using :class:`~torchvision.transforms.v2.Resize`:
resized_imgs = [
    v2.Resize(size=size)(orig_img, orig_pts)
    # Sizes are (height, width) in torchvision; PIL's .size is (width, height)
    for size in (300, 500, 1000, (orig_img.size[1], orig_img.size[0]))
]
plot([(orig_img, orig_pts)] + resized_imgs)
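
# %%
# Resizing scales ``x`` by the width ratio and ``y`` by the height ratio.
# When ``size`` is a single int, the shorter image edge is matched to it and
# the aspect ratio is preserved. This is a sketch under those assumptions,
# using simple proportional scaling of the raw coordinates; TorchVision's own
# output-size computation (including ``max_size`` handling and its exact
# rounding) may differ. ``resize_keypoints`` is a hypothetical helper.


def resize_keypoints(points, canvas_size, size):
    height, width = canvas_size
    if isinstance(size, int):
        # Match the shorter edge to `size` and keep the aspect ratio (truncating)
        if height <= width:
            new_height, new_width = size, int(size * width / height)
        else:
            new_height, new_width = int(size * height / width), size
    else:
        # A sequence is interpreted as (height, width)
        new_height, new_width = size
    scale_x, scale_y = new_width / width, new_height / height
    resized = [(x * scale_x, y * scale_y) for x, y in points]
    return resized, (new_height, new_width)


# Hypothetical (height, width) = (400, 800) canvas resized so its shorter edge is 200
print(resize_keypoints([(400, 200)], (400, 800), 200))
# ([(200.0, 100.0)], (200, 400))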

# %%
# Using :class:`~torchvision.transforms.v2.RandomPerspective`:
perspective_transformer = v2.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_imgs = [perspective_transformer(orig_img, orig_pts) for _ in range(4)]
plot([(orig_img, orig_pts)] + perspective_imgs)
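
# %%
# A perspective transform maps each point through a 3x3 homography,
# ``x' = (a*x + b*y + c) / (g*x + h*y + 1)`` and
# ``y' = (d*x + e*y + f) / (g*x + h*y + 1)``, which is fully determined by
# where four corners end up. ``RandomPerspective`` picks those displaced
# corners at random, controlled by ``distortion_scale``. The sketch below
# (hypothetical helpers in plain Python, not TorchVision's implementation)
# solves for the eight coefficients from four point pairs with Gaussian
# elimination and then applies them to keypoints.


def solve_homography(src, dst):
    """Return [a, b, c, d, e, f, g, h] mapping 4 src points onto 4 dst points."""
    rows, rhs = [], []
    for (x, y), (u, v) in zip(src, dst):
        # u * (g*x + h*y + 1) = a*x + b*y + c, rearranged to be linear in the unknowns
        rows.append([x, y, 1, 0, 0, 0, -u * x, -u * y])
        rhs.append(u)
        rows.append([0, 0, 0, x, y, 1, -v * x, -v * y])
        rhs.append(v)
    n = 8
    m = [row + [b] for row, b in zip(rows, rhs)]
    # Forward elimination with partial pivoting on the 8x8 system
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    # Back substitution
    coeffs = [0.0] * n
    for r in range(n - 1, -1, -1):
        acc = m[r][n] - sum(m[r][c] * coeffs[c] for c in range(r + 1, n))
        coeffs[r] = acc / m[r][r]
    return coeffs


def apply_homography(points, coeffs):
    a, b, c, d, e, f, g, h = coeffs
    out = []
    for x, y in points:
        denom = g * x + h * y + 1
        out.append(((a * x + b * y + c) / denom, (d * x + e * y + f) / denom))
    return out


# Map the corners of a hypothetical 100x100 square onto a distorted quadrilateral
demo_src = [(0, 0), (100, 0), (100, 100), (0, 100)]
demo_dst = [(10, 5), (90, 0), (100, 100), (0, 95)]
demo_coeffs = solve_homography(demo_src, demo_dst)
print([(round(x), round(y)) for x, y in apply_homography(demo_src, demo_coeffs)])
# [(10, 5), (90, 0), (100, 100), (0, 95)]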

# %%
# Using :class:`~torchvision.transforms.v2.CenterCrop`:
center_crops_and_points = [
    v2.CenterCrop(size=size)(orig_img, orig_pts)
    # Sizes are (height, width) in torchvision; PIL's .size is (width, height)
    for size in (300, 500, 1000, (orig_img.size[1], orig_img.size[0]))
]
plot([(orig_img, orig_pts)] + center_crops_and_points)
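
# %%
# Center cropping translates keypoints by the crop's top-left offset. The
# sketch below computes that offset as ``round((H - crop_h) / 2)`` and
# ``round((W - crop_w) / 2)`` and subtracts it from every point. When the crop
# is larger than the image, TorchVision pads first; here that simply shows up
# as a negative offset, which may differ from TorchVision's result by a pixel.
# ``center_crop_keypoints`` is a hypothetical helper, not a TorchVision function.


def center_crop_keypoints(points, canvas_size, crop_size):
    height, width = canvas_size
    if isinstance(crop_size, int):
        crop_size = (crop_size, crop_size)
    crop_height, crop_width = crop_size
    top = int(round((height - crop_height) / 2))
    left = int(round((width - crop_width) / 2))
    # Points outside the crop window end up with negative or too-large coordinates
    cropped = [(x - left, y - top) for x, y in points]
    return cropped, (crop_height, crop_width)


# Hypothetical (height, width) = (1000, 800) canvas center-cropped to 500x500
print(center_crop_keypoints([(445, 700)], (1000, 800), 500))
# ([(295, 450)], (500, 500))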

# %%
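# Using :class:`~torchvision.transforms.v2.RandomRotation` without ``expand``:
# the output keeps the original size, so the corners of the rotated image are
# cut off: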
rotater = v2.RandomRotation(degrees=(0, 180))
rotated_imgs = [rotater((orig_img, orig_pts)) for _ in range(4)]
plot([(orig_img, orig_pts)] + rotated_imgs)
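
# %%
# Without ``expand=True`` the canvas keeps its original size, so keypoints near
# the image corners can be rotated off the canvas. The hypothetical helper
# below (plain Python, not a TorchVision function) rotates ``(x, y)`` points
# counter-clockwise about the image center and reports which ones land
# outside a ``(height, width)`` canvas. How TorchVision itself handles such
# out-of-canvas keypoints is not covered by this sketch.
import math


def rotated_keypoints_outside(points, canvas_size, angle_deg):
    height, width = canvas_size
    theta = math.radians(angle_deg)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    cx, cy = width / 2, height / 2
    outside = []
    for x, y in points:
        # Same counter-clockwise rotation as with expand, but no canvas shift
        rx = cx + (x - cx) * cos_t + (y - cy) * sin_t
        ry = cy - (x - cx) * sin_t + (y - cy) * cos_t
        outside.append(not (0 <= rx < width and 0 <= ry < height))
    return outside


# The center stays put; a point near the top-right corner leaves the canvas
print(rotated_keypoints_outside([(100, 50), (190, 10)], (100, 200), 90))
# [False, True]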