File: entity_string.py

package info (click to toggle)
mautrix-python 0.20.7-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 1,812 kB
  • sloc: python: 19,103; makefile: 16
file content (162 lines) | stat: -rw-r--r-- 5,224 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from typing import Any, Generic, Iterable, Sequence, Type, TypeVar
from abc import ABC, abstractmethod
from itertools import chain

from attr import dataclass
import attr

from .formatted_string import EntityType, FormattedString


class AbstractEntity(ABC):
    """Interface implemented by every formatting entity.

    The base constructor only fixes the common argument list; it stores nothing,
    so concrete subclasses are responsible for keeping the values they need.
    """

    def __init__(
        self, type: EntityType, offset: int, length: int, extra_info: dict[str, Any]
    ) -> None:
        """Accept the common entity fields without storing any of them."""

    @abstractmethod
    def copy(self) -> AbstractEntity:
        """Return a duplicate of this entity. Must be provided by subclasses."""

    @abstractmethod
    def adjust_offset(self, offset: int, max_length: int = -1) -> AbstractEntity | None:
        """Return an entity moved by ``offset``. Must be provided by subclasses.

        Args:
            offset: Amount by which the entity position should be shifted.
            max_length: Upper bound for the entity's range; defaults to ``-1``.

        Returns:
            The adjusted entity. The return type also allows ``None``.
        """


class SemiAbstractEntity(AbstractEntity, ABC):
    """Partial entity implementation that provides ``adjust_offset()``.

    Subclasses still have to implement ``copy()`` and must expose mutable
    ``offset`` and ``length`` attributes.
    """

    offset: int
    length: int

    def adjust_offset(self, offset: int, max_length: int = -1) -> SemiAbstractEntity | None:
        """Return a shifted copy of this entity, clipped to the range ``0..max_length``.

        All changes are applied to the object returned by ``copy()``.

        Args:
            offset: Amount to add to the entity's offset; may be negative.
            max_length: Length of the text the entity has to fit into. Any value
                below zero (the default) disables clipping at the end.

        Returns:
            The adjusted copy, or ``None`` when the shifted entity ends before
            position zero or starts after ``max_length``.
        """
        entity = self.copy()
        entity.offset += offset
        if entity.offset < 0:
            # The entity starts before the text: drop the part that is out of range.
            entity.length += entity.offset
            if entity.length < 0:
                return None
            entity.offset = 0
        if max_length > -1:
            if entity.offset > max_length:
                return None
            # This check must also run after the start was clipped above; otherwise an
            # entity that begins before the text and ends after it keeps a length that
            # runs past max_length.
            if entity.offset + entity.length > max_length:
                entity.length = max_length - entity.offset
        return entity


@dataclass
class SimpleEntity(SemiAbstractEntity):
    """Concrete attrs-based entity with a type, a range and free-form extra data."""

    # Kind of formatting this entity stands for.
    type: EntityType
    # Start position of the entity within the text.
    offset: int
    # Length of the range covered by the entity, counted from ``offset``.
    length: int
    # Extra keyword data attached to the entity; defaults to a new dict per instance.
    extra_info: dict[str, Any] = attr.ib(default=attr.Factory(dict))

    def copy(self) -> SimpleEntity:
        """Create a new instance carrying the same field values.

        The duplicate is produced by ``attr.evolve`` without any changes, so the
        field values (including the ``extra_info`` dict) are passed on as-is.
        """
        duplicate = attr.evolve(self)
        return duplicate


# Type of the entity objects stored by an EntityString; bound to AbstractEntity.
TEntity = TypeVar("TEntity", bound=AbstractEntity)
# Type of the ``entity_type`` value accepted by EntityString.format().
TEntityType = TypeVar("TEntityType")


class EntityString(Generic[TEntity, TEntityType], FormattedString):
    """A piece of text together with formatting entities that refer to ranges of it.

    The mutating methods (``append``, ``prepend``, ``format``, ``trim``) change the
    instance in place and return it, so calls can be chained.
    """

    text: str
    _entities: list[TEntity]
    # Class that format() instantiates when it adds a new entity.
    entity_class: Type[AbstractEntity] = SimpleEntity

    def __init__(self, text: str = "", entities: list[TEntity] | None = None) -> None:
        """Create a string from plain text and an optional list of entities.

        Args:
            text: The plain text content.
            entities: Entities referring to ``text``. ``None`` or an empty list
                results in a new empty list being used.
        """
        self.text = text
        self._entities = entities or []

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(text='{self.text}', entities={self.entities})"

    def __str__(self) -> str:
        """Return the plain text without any entity information."""
        return self.text

    @property
    def entities(self) -> list[TEntity]:
        """The list of entities attached to the text."""
        return self._entities

    @entities.setter
    def entities(self, val: Iterable[TEntity]) -> None:
        # adjust_offset() may return None for entities that fall out of range, so
        # the iterable is materialized here with those values filtered out.
        self._entities = [entity for entity in val if entity is not None]

    def _offset_entities(self, offset: int) -> EntityString:
        """Shift all entities by ``offset``, bounded by the current text length.

        Entities for which ``adjust_offset()`` returns ``None`` are removed.
        """
        self.entities = (entity.adjust_offset(offset, len(self.text)) for entity in self.entities)
        return self

    def append(self, *args: str | FormattedString) -> EntityString:
        """Add the given strings to the end of this string, in order.

        Args:
            *args: Items to append. Entities of ``EntityString`` items are shifted
                behind the existing text; anything else is converted with ``str()``.

        Returns:
            This instance.
        """
        for msg in args:
            if isinstance(msg, EntityString):
                base = len(self.text)
                # Build the shifted list before extending. With a lazy generator,
                # appending a string to itself would iterate the very list that is
                # being extended and never terminate.
                shifted = [entity.adjust_offset(base) for entity in msg.entities]
                self.entities += shifted
                self.text += msg.text
            else:
                self.text += str(msg)
        return self

    def prepend(self, *args: str | FormattedString) -> EntityString:
        """Add the given strings to the beginning of this string.

        Each item is put in front of the current text in turn, so with several
        arguments the last one ends up first.

        Args:
            *args: Items to prepend. Entities of ``EntityString`` items are reused
                as-is, while the existing entities are shifted behind the new
                prefix; anything else is converted with ``str()``.

        Returns:
            This instance.
        """
        for msg in args:
            if isinstance(msg, EntityString):
                # Snapshot the prefix before touching self, so that prepending a
                # string to itself does not measure the already-extended text.
                prefix = msg.text
                leading = list(msg.entities)
                shifted = [entity.adjust_offset(len(prefix)) for entity in self.entities]
                self.text = prefix + self.text
                self.entities = chain(leading, shifted)
            else:
                text = str(msg)
                self.text = text + self.text
                self.entities = (entity.adjust_offset(len(text)) for entity in self.entities)
        return self

    def format(
        self,
        entity_type: TEntityType,
        offset: int | None = None,
        length: int | None = None,
        **kwargs,
    ) -> EntityString:
        """Attach a new entity of ``entity_class`` to the text.

        Args:
            entity_type: Value passed to the entity as its ``type``.
            offset: Start of the entity. ``None`` means the start of the text.
            length: Length of the entity. ``None`` means everything from
                ``offset`` to the end of the text.
            **kwargs: Passed to the entity as its ``extra_info``.

        Returns:
            This instance.
        """
        # Compare against None explicitly: an explicit length of 0 must not be
        # replaced by the default.
        start = 0 if offset is None else offset
        # The default span covers the remainder after ``start`` so the entity
        # does not run past the end of the text.
        span = len(self.text) - start if length is None else length
        self.entities.append(
            self.entity_class(
                type=entity_type,
                offset=start,
                length=span,
                extra_info=kwargs,
            )
        )
        return self

    def trim(self) -> EntityString:
        """Strip whitespace from both ends of the text and move the entities along.

        Returns:
            This instance.
        """
        orig_len = len(self.text)
        self.text = self.text.lstrip()
        # Only characters removed from the front move the entities; the removed
        # tail is handled by the length bound passed in _offset_entities().
        diff = orig_len - len(self.text)
        self.text = self.text.rstrip()
        self._offset_entities(-diff)
        return self

    def split(self, separator: str, max_items: int = -1) -> list[EntityString]:
        """Split the string at ``separator`` into new strings of the same class.

        Args:
            separator: The delimiter to split the text at.
            max_items: Maximum number of parts to produce. ``max_items - 1`` is
                passed to ``str.split`` as ``maxsplit``, so values below 1 do not
                limit the number of parts.

        Returns:
            The parts in order, each carrying the entities adjusted to its range.
            This instance is left unchanged.
        """
        text_parts = self.text.split(separator, max_items - 1)
        output: list[EntityString] = []

        # Position of the current part within the original text.
        offset = 0
        for part in text_parts:
            msg = type(self)(part)
            msg.entities = (entity.adjust_offset(-offset, len(part)) for entity in self.entities)
            output.append(msg)

            offset += len(part)
            offset += len(separator)

        return output

    @classmethod
    def join(cls, items: Sequence[str | EntityString], separator: str = " ") -> EntityString:
        """Concatenate ``items`` into a new string with ``separator`` between them.

        Args:
            items: Parts to join. Items that are not an ``EntityString`` are
                converted with ``str()`` and contribute no entities.
            separator: Plain text inserted between consecutive items.

        Returns:
            A new instance of ``cls`` with the combined text and shifted entities.
        """
        main = cls()
        for msg in items:
            if not isinstance(msg, EntityString):
                msg = cls(text=str(msg))
            main.entities += [entity.adjust_offset(len(main.text)) for entity in msg.entities]
            main.text += msg.text + separator
        # The loop adds the separator after every item, so cut off the last one.
        if len(separator) > 0:
            main.text = main.text[: -len(separator)]
        return main