1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
|
import collections
import os
import urllib
import urllib.error

import pytest
import torch
import torchvision
from pytest import approx
from torchvision.datasets.utils import download_url
from torchvision.io import _HAS_VIDEO_OPT, VideoReader

try:
    import av

    # Do a version test too
    torchvision.io.video._check_av_available()
except ImportError:
    av = None
# Directory holding the sample clips, resolved relative to this test file.
VIDEO_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "assets",
    "videos",
)
# Names of the per-clip expected properties; these are also the GroundTruth fields.
CheckerConfig = ["duration", "video_fps", "audio_sample_rate"]
GroundTruth = collections.namedtuple("GroundTruth", CheckerConfig)
def fate(name, path="."):
    """Download and return a path to a sample from the FFmpeg test suite.

    See the `FFmpeg Automated Test Environment <https://www.ffmpeg.org/fate.html>`_

    Args:
        name: Path of the sample relative to the FATE suite root,
            e.g. ``"sub/MovText_capability_tester.mp4"``.
        path: Local directory to download into. Defaults to the current directory.

    Returns:
        ``os.path.join(path, file_name)``, where ``file_name`` is the last
        ``/``-separated component of ``name``.
    """
    # Take the last path component so that nested names ("a/b/c.mp4") and
    # bare names ("c.mp4") both resolve to the actual file name.
    file_name = name.rsplit("/", 1)[-1]
    download_url("http://fate.ffmpeg.org/fate-suite/" + name, path, file_name)
    return os.path.join(path, file_name)
# Expected properties of every clip under VIDEO_DIR, keyed by file name.
# Each row is (file name, duration, video fps, audio sample rate or None).
test_videos = {
    name: GroundTruth(duration=duration, video_fps=fps, audio_sample_rate=sample_rate)
    for name, duration, fps, sample_rate in (
        ("RATRACE_wave_f_nm_np1_fr_goo_37.avi", 2.0, 30.0, None),
        ("SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi", 2.0, 30.0, None),
        ("TrumanShow_wave_f_nm_np1_fr_med_26.avi", 2.0, 30.0, None),
        ("v_SoccerJuggling_g23_c01.avi", 8.0, 29.97, None),
        ("v_SoccerJuggling_g24_c01.avi", 8.0, 29.97, None),
        ("R6llTwEh07w.mp4", 10.0, 30.0, 44100),
        ("SOX5yA1l24A.mp4", 11.0, 29.97, 48000),
        ("WUzgd7C1pWA.mp4", 11.0, 29.97, 48000),
    )
}
@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
class TestVideoApi:
    """Tests for torchvision's ``VideoReader`` on the clips listed in ``test_videos``.

    Decoded frames, timestamps and keyframe seeks are cross-checked against
    PyAV (those tests are skipped when PyAV is unavailable); metadata is
    checked against the ``GroundTruth`` entries. The whole class is skipped
    when ``_HAS_VIDEO_OPT`` is false.
    """

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_frame_reading(self, test_video):
        """Video and audio frames (data and pts) must match what PyAV decodes."""
        full_path = os.path.join(VIDEO_DIR, test_video)

        # test video reading compared to PyAV
        with av.open(full_path) as av_reader:
            if av_reader.streams.video:
                av_frames, vr_frames = [], []
                av_pts, vr_pts = [], []
                # get av frames; permute(2, 0, 1) moves PyAV's last axis first,
                # presumably to match VideoReader's channel-first layout
                for av_frame in av_reader.decode(av_reader.streams.video[0]):
                    av_frames.append(torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1))
                    av_pts.append(av_frame.pts * av_frame.time_base)

                # get vr frames
                video_reader = VideoReader(full_path, "video")
                for vr_frame in video_reader:
                    vr_frames.append(vr_frame["data"])
                    vr_pts.append(vr_frame["pts"])

                # same number of frames
                assert len(vr_frames) == len(av_frames)
                assert len(vr_pts) == len(av_pts)

                # compare the frames and pts pairwise (all four lists have equal length here)
                for av_data, vr_data, av_time, vr_time in zip(av_frames, vr_frames, av_pts, vr_pts):
                    assert float(av_time) == approx(vr_time, abs=0.1)
                    mean_delta = torch.mean(torch.abs(av_data.float() - vr_data.float()))
                    # Decoders differ slightly; the mean absolute difference over a
                    # frame must stay below 2.55 (presumably 1% of a 0-255 pixel range).
                    # TODO: assess empirically how to set this threshold.
                    assert mean_delta.item() < 2.55

                del vr_frames, av_frames, vr_pts, av_pts

        # test audio reading compared to PyAV
        with av.open(full_path) as av_reader:
            if av_reader.streams.audio:
                av_frames, vr_frames = [], []
                av_pts, vr_pts = [], []
                # get av frames; permute(1, 0) swaps PyAV's two axes, presumably
                # to match VideoReader's audio layout
                for av_frame in av_reader.decode(av_reader.streams.audio[0]):
                    av_frames.append(torch.tensor(av_frame.to_ndarray()).permute(1, 0))
                    av_pts.append(av_frame.pts * av_frame.time_base)
                # close the PyAV container early, before VideoReader reads the file
                av_reader.close()

                # get vr frames
                video_reader = VideoReader(full_path, "audio")
                for vr_frame in video_reader:
                    vr_frames.append(vr_frame["data"])
                    vr_pts.append(vr_frame["pts"])

                # same number of frames
                assert len(vr_frames) == len(av_frames)
                assert len(vr_pts) == len(av_pts)

                # compare the frames and pts pairwise (all four lists have equal length here)
                for av_data, vr_data, av_time, vr_time in zip(av_frames, vr_frames, av_pts, vr_pts):
                    assert float(av_time) == approx(vr_time, abs=0.1)
                    max_delta = torch.max(torch.abs(av_data.float() - vr_data.float()))
                    # every sample must agree to within an absolute difference of 1e-3
                    assert max_delta.item() < 0.001

    @pytest.mark.parametrize("test_video,config", test_videos.items())
    def test_metadata(self, test_video, config):
        """
        Test that the video metadata (fps and duration) reported by the new
        video decoder API matches the ground truth recorded in ``test_videos``.
        """
        full_path = os.path.join(VIDEO_DIR, test_video)
        reader = VideoReader(full_path, "video")
        reader_md = reader.get_metadata()
        assert config.video_fps == approx(reader_md["video"]["fps"][0], abs=0.0001)
        assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)

    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_seek_start(self, test_video):
        """Seeking to 0, or to a negative time, must replay every frame."""
        full_path = os.path.join(VIDEO_DIR, test_video)
        video_reader = VideoReader(full_path, "video")
        num_frames = sum(1 for _ in video_reader)

        # Now seek the container to 0 and count again: a seek to the start
        # is often imprecise and may not actually land at 0.
        video_reader.seek(0)
        start_num_frames = sum(1 for _ in video_reader)
        assert start_num_frames == num_frames

        # now seek the container to < 0 to check for unexpected behaviour
        video_reader.seek(-1)
        start_num_frames = sum(1 for _ in video_reader)
        assert start_num_frames == num_frames

    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_accurateseek_middle(self, test_video):
        """Seeking to half the duration must yield about half the frames, starting near the midpoint."""
        full_path = os.path.join(VIDEO_DIR, test_video)
        stream = "video"
        video_reader = VideoReader(full_path, stream)
        md = video_reader.get_metadata()
        duration = md[stream]["duration"][0]
        if duration is not None:
            num_frames = sum(1 for _ in video_reader)

            video_reader.seek(duration / 2)
            middle_num_frames = sum(1 for _ in video_reader)

            assert middle_num_frames < num_frames
            assert middle_num_frames == approx(num_frames // 2, abs=1)

            # the first frame after the seek must lie within one frame period of the midpoint
            video_reader.seek(duration / 2)
            frame = next(video_reader)
            frame_period = 1 / md[stream]["fps"][0]
            lb = duration / 2 - frame_period
            ub = duration / 2 + frame_period
            assert lb <= frame["pts"] <= ub

    def test_fate_suite(self):
        """A downloaded FATE sample with subtitles must report a subtitle duration."""
        # TODO: remove the try-except statement once the connectivity issues are resolved
        try:
            video_path = fate("sub/MovText_capability_tester.mp4", VIDEO_DIR)
        except (urllib.error.URLError, ConnectionError) as error:
            pytest.skip(f"Skipping due to connectivity issues: {error}")
        try:
            vr = VideoReader(video_path)
            metadata = vr.get_metadata()
            assert metadata["subtitles"]["duration"] is not None
        finally:
            # Delete the downloaded sample even when the assertion fails, so it
            # does not linger in VIDEO_DIR.
            os.remove(video_path)

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    def test_keyframe_reading(self, test_video, config):
        """Keyframe seeks (``seek(t, True)``) must land on the keyframes PyAV reports."""
        full_path = os.path.join(VIDEO_DIR, test_video)

        av_keyframes = []
        vr_keyframes = []
        # get all keyframes using pyav; the context manager closes the container
        with av.open(full_path) as av_reader:
            if av_reader.streams.video:
                av_stream = av_reader.streams.video[0]
                # reduce decoding to keyframes only
                av_stream.codec_context.skip_frame = "NONKEY"
                for av_frame in av_reader.decode(av_stream):
                    av_keyframes.append(float(av_frame.pts * av_frame.time_base))

        if len(av_keyframes) > 1:
            video_reader = VideoReader(full_path, "video")
            # Seek to the midpoint between each pair of consecutive PyAV
            # keyframes; each result is compared below with the earlier keyframe.
            for prev_kf, next_kf in zip(av_keyframes, av_keyframes[1:]):
                data = next(video_reader.seek((next_kf + prev_kf) / 2, True))
                vr_keyframes.append(data["pts"])
            # a final seek at the clip's duration is compared with the last keyframe
            data = next(video_reader.seek(config.duration, True))
            vr_keyframes.append(data["pts"])

            assert len(av_keyframes) == len(vr_keyframes)
            # NOTE: this video gets different keyframe with different
            # loaders (0.333 pyav, 0.666 for us)
            if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi":
                for av_kf, vr_kf in zip(av_keyframes, vr_keyframes):
                    assert av_kf == approx(vr_kf, rel=0.001)
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file as a script
    # fails visibly (non-zero exit) when any test fails.
    raise SystemExit(pytest.main([__file__]))
|