import collections
import os
import urllib

import pytest
import torch
import torchvision
from pytest import approx
from torchvision.datasets.utils import download_url
from torchvision.io import _HAS_VIDEO_OPT, VideoReader


# PyAV is an optional dependency: when it cannot be used, ``av`` is set to
# None and the tests decorated with ``skipif(av is None, ...)`` are skipped.
try:
    import av

    # Do a version test too
    # NOTE(review): presumably this raises ImportError when the installed PyAV
    # is unusable (e.g. too old) — confirm against torchvision.io.video.
    torchvision.io.video._check_av_available()
except ImportError:
    av = None


# Sample videos live in "assets/videos" next to this test module.
VIDEO_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "assets",
    "videos",
)

# Names of the per-video properties recorded as ground truth.
CheckerConfig = ["duration", "video_fps", "audio_sample_rate"]
# namedtuple accepts a sequence of field names directly.
GroundTruth = collections.namedtuple("GroundTruth", CheckerConfig)


def fate(name, path="."):
    """Download a sample from the FFmpeg test suite and return its local path.

    See the `FFmpeg Automated Test Environment <https://www.ffmpeg.org/fate.html>`_

    Args:
        name: Location of the sample relative to the fate-suite root, using
            "/" as separator (e.g. ``"sub/MovText_capability_tester.mp4"``).
        path: Directory the sample is downloaded into.

    Returns:
        ``path`` joined with the file name, i.e. the last "/"-separated
        component of ``name``.
    """
    # Take the last component rather than the second one, so that samples
    # nested more than one directory deep (or not nested at all) still map
    # to their real file name instead of a directory name / IndexError.
    file_name = name.rsplit("/", 1)[-1]
    download_url("http://fate.ffmpeg.org/fate-suite/" + name, path, file_name)
    return os.path.join(path, file_name)


# Ground truth for every bundled sample video, keyed by file name.
# Table columns: file name, duration, video fps, audio sample rate
# (None where no audio sample rate is recorded).
test_videos = {
    file_name: GroundTruth(duration=duration, video_fps=video_fps, audio_sample_rate=audio_sample_rate)
    for file_name, duration, video_fps, audio_sample_rate in (
        ("RATRACE_wave_f_nm_np1_fr_goo_37.avi", 2.0, 30.0, None),
        ("SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi", 2.0, 30.0, None),
        ("TrumanShow_wave_f_nm_np1_fr_med_26.avi", 2.0, 30.0, None),
        ("v_SoccerJuggling_g23_c01.avi", 8.0, 29.97, None),
        ("v_SoccerJuggling_g24_c01.avi", 8.0, 29.97, None),
        ("R6llTwEh07w.mp4", 10.0, 30.0, 44100),
        ("SOX5yA1l24A.mp4", 11.0, 29.97, 48000),
        ("WUzgd7C1pWA.mp4", 11.0, 29.97, 48000),
    )
}


@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
class TestVideoApi:
    """Tests of the ``VideoReader`` API against PyAV and the ``test_videos`` ground truth."""

    @staticmethod
    def _read_all(full_path, stream):
        """Iterate a fresh ``VideoReader`` on ``stream`` and return ``(frames, pts)`` lists."""
        frames, pts = [], []
        for frame in VideoReader(full_path, stream):
            frames.append(frame["data"])
            pts.append(frame["pts"])
        return frames, pts

    @staticmethod
    def _count_frames(reader):
        """Consume ``reader`` from its current position and return how many frames it yielded."""
        return sum(1 for _ in reader)

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_frame_reading(self, test_video):
        """Frames and timestamps read by ``VideoReader`` must match PyAV (video, then audio)."""
        full_path = os.path.join(VIDEO_DIR, test_video)

        # test video reading compared to PYAV
        with av.open(full_path) as av_reader:
            if av_reader.streams.video:
                # get av frames
                av_frames, av_pts = [], []
                for av_frame in av_reader.decode(av_reader.streams.video[0]):
                    av_frames.append(torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1))
                    av_pts.append(av_frame.pts * av_frame.time_base)

                # get vr frames
                vr_frames, vr_pts = self._read_all(full_path, "video")

                # same number of frames
                assert len(vr_frames) == len(av_frames)
                assert len(vr_pts) == len(av_pts)

                # compare the frames and ptss
                for av_data, av_time, vr_data, vr_time in zip(av_frames, av_pts, vr_frames, vr_pts):
                    assert float(av_time) == approx(vr_time, abs=0.1)
                    mean_delta = torch.mean(torch.abs(av_data.float() - vr_data.float()))
                    # on average the difference is very small and caused
                    # by decoding (around 1%)
                    # TODO: assess empirically how to set this? atm it's 1%
                    # (2.55 out of 255) averaged over each frame
                    assert mean_delta.item() < 2.55

                # drop the decoded frames before decoding the audio stream
                del vr_frames, av_frames, vr_pts, av_pts

        # test audio reading compared to PYAV
        with av.open(full_path) as av_reader:
            if av_reader.streams.audio:
                # get av frames; the ``with`` block closes the container,
                # so no explicit close() is needed
                av_frames, av_pts = [], []
                for av_frame in av_reader.decode(av_reader.streams.audio[0]):
                    av_frames.append(torch.tensor(av_frame.to_ndarray()).permute(1, 0))
                    av_pts.append(av_frame.pts * av_frame.time_base)

                # get vr frames
                vr_frames, vr_pts = self._read_all(full_path, "audio")

                # same number of frames
                assert len(vr_frames) == len(av_frames)
                assert len(vr_pts) == len(av_pts)

                # compare the frames and ptss
                for av_data, av_time, vr_data, vr_time in zip(av_frames, av_pts, vr_frames, vr_pts):
                    assert float(av_time) == approx(vr_time, abs=0.1)
                    max_delta = torch.max(torch.abs(av_data.float() - vr_data.float()))
                    # no sample may differ by more than an absolute 0.001
                    assert max_delta.item() < 0.001

    @pytest.mark.parametrize("test_video,config", test_videos.items())
    def test_metadata(self, test_video, config):
        """
        Test that the fps and duration reported by the new video decoder API
        correspond to the ground truth recorded in ``test_videos``
        """
        full_path = os.path.join(VIDEO_DIR, test_video)
        video_md = VideoReader(full_path, "video").get_metadata()["video"]
        assert config.video_fps == approx(video_md["fps"][0], abs=0.0001)
        assert config.duration == approx(video_md["duration"][0], abs=0.5)

    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_seek_start(self, test_video):
        """Seeking to 0 (or before it) must replay the same number of frames as a fresh read."""
        full_path = os.path.join(VIDEO_DIR, test_video)
        video_reader = VideoReader(full_path, "video")
        num_frames = self._count_frames(video_reader)

        # now seek the container to 0 and do it again
        # It's often that starting seek can be imprecise
        # this way and it doesn't start at 0
        video_reader.seek(0)
        assert self._count_frames(video_reader) == num_frames

        # now seek the container to < 0 to check for unexpected behaviour
        video_reader.seek(-1)
        assert self._count_frames(video_reader) == num_frames

    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_accurateseek_middle(self, test_video):
        """Seeking to half the duration must leave about half the frames, starting near that time."""
        full_path = os.path.join(VIDEO_DIR, test_video)
        stream = "video"
        video_reader = VideoReader(full_path, stream)
        md = video_reader.get_metadata()
        duration = md[stream]["duration"][0]
        if duration is None:
            # nothing to check when the stream reports no duration
            return

        num_frames = self._count_frames(video_reader)

        middle = duration / 2
        video_reader.seek(middle)
        middle_num_frames = self._count_frames(video_reader)

        assert middle_num_frames < num_frames
        assert middle_num_frames == approx(num_frames // 2, abs=1)

        # the first frame after the seek must lie within one frame period
        # (1 / fps) of the requested time
        video_reader.seek(middle)
        frame = next(video_reader)
        frame_period = 1 / md[stream]["fps"][0]
        assert middle - frame_period <= frame["pts"] <= middle + frame_period

    def test_fate_suite(self):
        """A sample downloaded from the FFmpeg FATE suite must report a subtitles duration."""
        # TODO: remove the try-except statement once the connectivity issues are resolved
        try:
            video_path = fate("sub/MovText_capability_tester.mp4", VIDEO_DIR)
        except (urllib.error.URLError, ConnectionError) as error:
            pytest.skip(f"Skipping due to connectivity issues: {error}")

        # always delete the downloaded sample, even when the check below
        # fails, so it does not linger in the assets directory
        try:
            vr = VideoReader(video_path)
            metadata = vr.get_metadata()

            assert metadata["subtitles"]["duration"] is not None
        finally:
            os.remove(video_path)

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    def test_keyframe_reading(self, test_video, config):
        """Keyframe-only seeks in ``VideoReader`` must land on the keyframes PyAV reports."""
        full_path = os.path.join(VIDEO_DIR, test_video)

        # get all keyframes using pyav. Then, seek in between them with the
        # video reader and assert that the returned pts match AV_KEYFRAMES
        av_keyframes = []
        # the context manager closes the container once decoding is done
        with av.open(full_path) as av_reader:
            # only touch streams.video[0] after checking a video stream exists
            if av_reader.streams.video:
                av_stream = av_reader.streams.video[0]
                # reduce streams to only keyframes
                av_stream.codec_context.skip_frame = "NONKEY"
                av_keyframes = [
                    float(av_frame.pts * av_frame.time_base) for av_frame in av_reader.decode(av_stream)
                ]

        if len(av_keyframes) <= 1:
            # a single keyframe leaves no interval to seek into
            return

        video_reader = VideoReader(full_path, "video")
        vr_keyframes = []
        # seek to the midpoint between each pair of consecutive keyframes
        for previous, current in zip(av_keyframes, av_keyframes[1:]):
            seek_val = (current + previous) / 2
            data = next(video_reader.seek(seek_val, keyframes_only=True))
            vr_keyframes.append(data["pts"])

        # one more seek, to the ground-truth duration of the video
        data = next(video_reader.seek(config.duration, keyframes_only=True))
        vr_keyframes.append(data["pts"])

        assert len(av_keyframes) == len(vr_keyframes)
        # NOTE: this video gets different keyframe with different
        # loaders (0.333 pyav, 0.666 for us)
        if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi":
            for av_keyframe, vr_keyframe in zip(av_keyframes, vr_keyframes):
                assert av_keyframe == approx(vr_keyframe, rel=0.001)


# Allow running this test module directly as a script by handing this file
# to pytest.
if __name__ == "__main__":
    pytest.main([__file__])
