File: whisper.py

package info (click to toggle)
python-pysubs2 1.8.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,840 kB
  • sloc: python: 4,016; makefile: 163
file content (106 lines) | stat: -rw-r--r-- 3,344 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
Support for the OpenAI Whisper speech recognition library.

See https://github.com/openai/whisper

"""
from .base import FormatBase
from ..ssaevent import SSAEvent
from ..ssafile import SSAFile
from ..time import make_time, timestamp_to_ms
from typing import Union, List, Dict, Any, Optional, Sequence, TextIO
import re


def load_from_whisper(result_or_segments: Union[Dict[str, Any], List[Dict[str, Any]]]) -> SSAFile:
    """
    Load subtitle file from OpenAI Whisper transcript

    Example:
        >>> import whisper
        >>> import pysubs2
        >>> model = whisper.load_model("base")
        >>> result = model.transcribe("audio.mp3")
        >>> subs = pysubs2.load_from_whisper(result)
        >>> subs.save("audio.ass")

    See also:
        https://github.com/openai/whisper

    Arguments:
        result_or_segments: Either a dict with a ``"segments"`` key
            that holds a list of segment dicts, or the segment list-of-dicts.
            Each segment is a dict with keys ``"start"``, ``"end"`` (float, timestamps
            in seconds) and ``"text"`` (str with caption text).

    Returns:
        :class:`pysubs2.SSAFile`

    """
    if isinstance(result_or_segments, dict):
        segments = result_or_segments["segments"]
    elif isinstance(result_or_segments, list):
        segments = result_or_segments
    else:
        raise TypeError("Expected either a dict with 'segments' key, or a list of segments")

    subs = SSAFile()
    for segment in segments:
        event = SSAEvent(start=make_time(s=segment["start"]), end=make_time(s=segment["end"]))
        event.plaintext = segment["text"].strip()
        subs.append(event)

    return subs


class WhisperJAXFormat(FormatBase):
    """
    Parser for Whisper JAX transcription, one event per line, eg. ``[00:02.880 -> 00:07.240]  transcribed text``

    """
    TIMESTAMP = re.compile(r"(?:(\d{1,2}):)?(\d{2}):(\d{2})[.](\d{3})")
    LINE = re.compile(r"\[([^]]+) -> ([^]]+)] (.*)")

    @classmethod
    def guess_format(cls, text: str) -> Optional[str]:
        """See :meth:`pysubs2.formats.FormatBase.guess_format()`"""
        for line in text.lstrip().splitlines():
            if cls.parse_line(line):
                return "whisper_jax"
            else:
                return None

        return None

    @classmethod
    def parse_line(cls, line: str) -> Optional[SSAEvent]:
        m = cls.LINE.match(line)
        if m is None:
            return None

        m_start = cls.TIMESTAMP.fullmatch(m.group(1))
        m_end = cls.TIMESTAMP.fullmatch(m.group(2))
        text = m.group(3).strip()

        if m_start is None or m_end is None:
            return None

        start_ms = cls.timestamp_to_ms(m_start.groups())
        end_ms = cls.timestamp_to_ms(m_end.groups())

        return SSAEvent(start=start_ms, end=end_ms, text=text)

    @staticmethod
    def timestamp_to_ms(groups: Sequence[str]) -> int:
        return timestamp_to_ms([x or "0" for x in groups])

    @classmethod
    def from_file(cls, subs: "SSAFile", fp: TextIO, format_: str, **kwargs: Any) -> None:
        """
        See :meth:`pysubs2.formats.FormatBase.from_file()`
        """
        for line in fp.readlines():
            line = line.strip()
            e = cls.parse_line(line)
            if e is not None:
                subs.append(e)