#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
import torchaudio
import torchvision


class AVSRDataLoader:
    def __init__(self, modality, detector="retinaface", resize=None):
        # modality: "audio" or "video"; video mode additionally needs a face-landmark detector.
        self.modality = modality
        if modality == "video":
            if detector == "retinaface":
                from detectors.retinaface.detector import LandmarksDetector
                from detectors.retinaface.video_process import VideoProcess

                self.landmarks_detector = LandmarksDetector(device="cuda:0")
                self.video_process = VideoProcess(resize=resize)
            if detector == "mediapipe":
                from detectors.mediapipe.detector import LandmarksDetector
                from detectors.mediapipe.video_process import VideoProcess

                self.landmarks_detector = LandmarksDetector()
                self.video_process = VideoProcess(resize=resize)

    def load_data(self, data_filename, transform=True):
        # Read the file, then run the modality-specific preprocessing pipeline.
        if self.modality == "audio":
            audio, sample_rate = self.load_audio(data_filename)
            audio = self.audio_process(audio, sample_rate)
            return audio
        if self.modality == "video":
            video = self.load_video(data_filename)
            landmarks = self.landmarks_detector(video)
            video = self.video_process(video, landmarks)
            video = torch.tensor(video)
            return video

    def load_audio(self, data_filename):
        waveform, sample_rate = torchaudio.load(data_filename, normalize=True)
        return waveform, sample_rate

    def load_video(self, data_filename):
        # Decode all frames as a (T, H, W, C) uint8 numpy array.
        return torchvision.io.read_video(data_filename, pts_unit="sec")[0].numpy()

    def audio_process(self, waveform, sample_rate, target_sample_rate=16000):
        # Resample to the target rate if needed, then downmix to a single (mono) channel.
        if sample_rate != target_sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
        waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform
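

# Minimal usage sketch (not part of the original script). The file paths below are
# hypothetical placeholders; video mode assumes the detector packages and, for
# retinaface, a CUDA device are available.
if __name__ == "__main__":
    # Audio: returns a (1, num_samples) mono waveform resampled to 16 kHz.
    audio_loader = AVSRDataLoader(modality="audio")
    audio = audio_loader.load_data("example.wav")  # hypothetical path
    print(audio.shape)

    # Video: returns a tensor of face-cropped frames produced via detected landmarks.
    video_loader = AVSRDataLoader(modality="video", detector="mediapipe")
    video = video_loader.load_data("example.mp4")  # hypothetical path
    print(video.shape)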