File: text_preprocessing.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (173 lines) | stat: -rw-r--r-- 7,045 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# *****************************************************************************
# Copyright (c) 2017 Keith Ito
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# *****************************************************************************
"""
Modified from https://github.com/keithito/tacotron
"""

import re
from typing import List, Optional, Union

from torchaudio.datasets import CMUDict
from unidecode import unidecode

from .numbers import normalize_numbers


# Matches any run of one or more whitespace characters.
_whitespace_re = re.compile(r"\s+")

# Compiled (pattern, expansion) pairs used to spell out abbreviations.
# Each pattern matches the abbreviation at a word boundary, followed by a
# literal period, ignoring case. The pairs are kept in this order.
_abbreviations = [
    (re.compile(rf"\b{short_form}\.", re.IGNORECASE), expansion)
    for short_form, expansion in (
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    )
]

# Building blocks of the character-level vocabulary.
_pad = "_"
_special = "-"
_punctuation = "!'(),.:;? "
_letters = "abcdefghijklmnopqrstuvwxyz"

# Default symbol table: padding first, then the special symbol, the
# punctuation marks (including the space) and the lowercase letters.
symbols = [_pad, *_special, *_punctuation, *_letters]

# Phonemizer instance shared across calls; it stays ``None`` until
# ``word_to_phonemes`` loads one from a checkpoint.
_phonemizer = None


# Names of the built-in symbol sets and of the supported phonemizers.
available_symbol_set = {"english_characters", "english_phonemes"}
available_phonemizers = {"DeepPhonemizer"}


def get_symbol_list(symbol_list: str = "english_characters", cmudict_root: Optional[str] = "./") -> List[str]:
    r"""Return the symbol table for one of the built-in symbol sets.

    Args:
        symbol_list (str, optional): Name of the symbol set, either "english_characters" or
            "english_phonemes". (Default: "english_characters")
        cmudict_root (str or None, optional): Directory handed to ``CMUDict``. Only used when
            ``symbol_list`` is "english_phonemes". (Default: "./")

    Returns:
        List of str: The padding symbol, the special symbol and the punctuation marks, followed by
        either the lowercase letters or the symbols reported by ``CMUDict``.

    Raises:
        ValueError: If ``symbol_list`` is not one of the supported names.
    """
    # Both symbol sets share the same non-alphabetic prefix.
    shared_prefix = [_pad] + list(_special) + list(_punctuation)

    if symbol_list == "english_characters":
        return shared_prefix + list(_letters)
    if symbol_list == "english_phonemes":
        return shared_prefix + CMUDict(cmudict_root).symbols

    # The trailing space keeps the two sentences apart, and sorting makes the
    # message deterministic (the supported names are stored in a set).
    raise ValueError(
        f"The `symbol_list` {symbol_list} is not supported. "
        f"Supported `symbol_list` includes {sorted(available_symbol_set)}."
    )


# Checkpoint path the cached ``_phonemizer`` was loaded from by
# ``word_to_phonemes``; ``None`` until that function has loaded one itself.
_phonemizer_checkpoint = None


def word_to_phonemes(sent: str, phonemizer: str, checkpoint: str) -> List[str]:
    r"""Convert a sentence into a list of phoneme and punctuation symbols.

    Args:
        sent (str): The sentence to phonemize.
        phonemizer (str): Name of the phonemizer to use. Only "DeepPhonemizer" is supported.
        checkpoint (str): Path to the phonemizer checkpoint.

    Returns:
        List of str: Phoneme symbols with their surrounding brackets removed, interleaved with the
        punctuation/special characters found in the phonemizer output, e.g. "hello world!" gives
        ``['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!']``.

    Raises:
        ValueError: If ``phonemizer`` is not a supported name.
    """
    global _phonemizer, _phonemizer_checkpoint

    if phonemizer != "DeepPhonemizer":
        raise ValueError(
            f"The `phonemizer` {phonemizer} is not supported. Supported `phonemizer` includes `'DeepPhonemizer'`."
        )

    # Imported lazily so the dependency is only needed when phonemes are requested.
    from dp.phonemizer import Phonemizer

    # Matches either a bracketed phoneme such as "[HH]" or one punctuation/special
    # character. The characters are escaped so that they are always treated
    # literally inside the character class.
    _other_symbols = "".join(list(_special) + list(_punctuation))
    _phone_symbols_re = r"(\[[A-Z]+?\]|[" + re.escape(_other_symbols) + r"])"

    # The phonemizer is kept in a module-level variable so that the checkpoint is
    # not reloaded on every call. It is reloaded only when this function loaded
    # the cached instance itself and a different checkpoint is now requested.
    loaded_other_checkpoint = _phonemizer_checkpoint is not None and _phonemizer_checkpoint != checkpoint
    if _phonemizer is None or loaded_other_checkpoint:
        _phonemizer = Phonemizer.from_checkpoint(checkpoint)
        _phonemizer_checkpoint = checkpoint

    # Example:
    # sent = "hello world!"
    # '[HH][AH][L][OW] [W][ER][L][D]!'
    phonemized = _phonemizer(sent, lang="en_us")

    # ['[HH]', '[AH]', '[L]', '[OW]', ' ', '[W]', '[ER]', '[L]', '[D]', '!']
    tokens = re.findall(_phone_symbols_re, phonemized)

    # ['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!']
    return [token.replace("[", "").replace("]", "") for token in tokens]


def text_to_sequence(
    sent: str,
    symbol_list: Union[str, List[str]] = "english_characters",
    phonemizer: Optional[str] = "DeepPhonemizer",
    checkpoint: Optional[str] = "./en_us_cmudict_forward.pt",
    cmudict_root: Optional[str] = "./",
) -> List[int]:
    r"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text is converted to ASCII, lower-cased, has its numbers and abbreviations expanded and its
    whitespace collapsed before being encoded. Symbols that are not part of the symbol table are
    silently dropped from the output.

    Args:
        sent (str): The input sentence to convert to a sequence.
        symbol_list (str or List of string, optional): When the input is a string, available options include
            "english_characters" and "english_phonemes". When the input is a list (or tuple) of string,
            ``symbol_list`` will directly be used as the symbol to encode. (Default: "english_characters")
        phonemizer (str or None, optional): The phonemizer to use. Only used when ``symbol_list`` is "english_phonemes".
            Available options include "DeepPhonemizer". (Default: "DeepPhonemizer")
        checkpoint (str or None, optional): The path to the checkpoint of the phonemizer. Only used when
            ``symbol_list`` is "english_phonemes". (Default: "./en_us_cmudict_forward.pt")
        cmudict_root (str or None, optional): The path to the directory where the CMUDict dataset is found or
            downloaded. Only used when ``symbol_list`` is "english_phonemes". (Default: "./")

    Returns:
        List of integers corresponding to the symbols in the sentence.

    Raises:
        ValueError: If ``symbol_list`` is "english_phonemes" and any of ``phonemizer``, ``checkpoint``
            or ``cmudict_root`` is ``None``, or if ``symbol_list`` is an unsupported name.
        TypeError: If ``symbol_list`` is neither a string nor a list/tuple of strings.

    Examples:
        >>> text_to_sequence("hello world!", "english_characters")
        [19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2]
        >>> text_to_sequence("hello world!", "english_phonemes")
        [54, 20, 65, 69, 11, 92, 44, 65, 38, 2]
    """
    use_phonemes = isinstance(symbol_list, str) and symbol_list == "english_phonemes"

    # Validate the arguments up front so that bad input fails before any work is done.
    if use_phonemes and any(param is None for param in (phonemizer, checkpoint, cmudict_root)):
        raise ValueError(
            "When `symbol_list` is 'english_phonemes', "
            "all of `phonemizer`, `checkpoint`, and `cmudict_root` must be provided."
        )
    if not isinstance(symbol_list, (str, list, tuple)):
        raise TypeError(
            f"`symbol_list` must be a str or a list of str, but got {type(symbol_list).__name__}."
        )

    sent = unidecode(sent)  # convert to ascii
    sent = sent.lower()  # lower case
    sent = normalize_numbers(sent)  # expand numbers
    for regex, replacement in _abbreviations:  # expand abbreviations
        sent = regex.sub(replacement, sent)
    sent = _whitespace_re.sub(" ", sent)  # collapse whitespace

    if isinstance(symbol_list, str):
        vocabulary = get_symbol_list(symbol_list, cmudict_root=cmudict_root)
        # In phoneme mode the sentence is encoded phoneme by phoneme instead of
        # character by character.
        tokens = word_to_phonemes(sent, phonemizer=phonemizer, checkpoint=checkpoint) if use_phonemes else sent
    else:
        # A user-supplied symbol table is used as is.
        vocabulary = symbol_list
        tokens = sent

    symbol_to_id = {symbol: index for index, symbol in enumerate(vocabulary)}

    # Symbols missing from the table are skipped rather than raising.
    return [symbol_to_id[token] for token in tokens if token in symbol_to_id]