1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
|
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Generic, Iterable, Sequence, Type, TypeVar
from abc import ABC, abstractmethod
from itertools import chain
from attr import dataclass
import attr
from .formatted_string import EntityType, FormattedString
class AbstractEntity(ABC):
def __init__(
self, type: EntityType, offset: int, length: int, extra_info: dict[str, Any]
) -> None:
pass
@abstractmethod
def copy(self) -> AbstractEntity:
pass
@abstractmethod
def adjust_offset(self, offset: int, max_length: int = -1) -> AbstractEntity | None:
pass
class SemiAbstractEntity(AbstractEntity, ABC):
offset: int
length: int
def adjust_offset(self, offset: int, max_length: int = -1) -> SemiAbstractEntity | None:
entity = self.copy()
entity.offset += offset
if entity.offset < 0:
entity.length += entity.offset
if entity.length < 0:
return None
entity.offset = 0
elif entity.offset > max_length > -1:
return None
elif entity.offset + entity.length > max_length > -1:
entity.length = max_length - entity.offset
return entity
@dataclass
class SimpleEntity(SemiAbstractEntity):
type: EntityType
offset: int
length: int
extra_info: dict[str, Any] = attr.ib(factory=dict)
def copy(self) -> SimpleEntity:
return attr.evolve(self)
TEntity = TypeVar("TEntity", bound=AbstractEntity)
TEntityType = TypeVar("TEntityType")
class EntityString(Generic[TEntity, TEntityType], FormattedString):
text: str
_entities: list[TEntity]
entity_class: Type[AbstractEntity] = SimpleEntity
def __init__(self, text: str = "", entities: list[TEntity] = None) -> None:
self.text = text
self._entities = entities or []
def __repr__(self) -> str:
return f"{self.__class__.__name__}(text='{self.text}', entities={self.entities})"
def __str__(self) -> str:
return self.text
@property
def entities(self) -> list[TEntity]:
return self._entities
@entities.setter
def entities(self, val: Iterable[TEntity]) -> None:
self._entities = [entity for entity in val if entity is not None]
def _offset_entities(self, offset: int) -> EntityString:
self.entities = (entity.adjust_offset(offset, len(self.text)) for entity in self.entities)
return self
def append(self, *args: str | FormattedString) -> EntityString:
for msg in args:
if isinstance(msg, EntityString):
self.entities += (entity.adjust_offset(len(self.text)) for entity in msg.entities)
self.text += msg.text
else:
self.text += str(msg)
return self
def prepend(self, *args: str | FormattedString) -> EntityString:
for msg in args:
if isinstance(msg, EntityString):
self.text = msg.text + self.text
self.entities = chain(
msg.entities, (entity.adjust_offset(len(msg.text)) for entity in self.entities)
)
else:
text = str(msg)
self.text = text + self.text
self.entities = (entity.adjust_offset(len(text)) for entity in self.entities)
return self
def format(
self, entity_type: TEntityType, offset: int = None, length: int = None, **kwargs
) -> EntityString:
self.entities.append(
self.entity_class(
type=entity_type,
offset=offset or 0,
length=length or len(self.text),
extra_info=kwargs,
)
)
return self
def trim(self) -> EntityString:
orig_len = len(self.text)
self.text = self.text.lstrip()
diff = orig_len - len(self.text)
self.text = self.text.rstrip()
self._offset_entities(-diff)
return self
def split(self, separator, max_items: int = -1) -> list[EntityString]:
text_parts = self.text.split(separator, max_items - 1)
output: list[EntityString] = []
offset = 0
for part in text_parts:
msg = type(self)(part)
msg.entities = (entity.adjust_offset(-offset, len(part)) for entity in self.entities)
output.append(msg)
offset += len(part)
offset += len(separator)
return output
@classmethod
def join(cls, items: Sequence[str | EntityString], separator: str = " ") -> EntityString:
main = cls()
for msg in items:
if not isinstance(msg, EntityString):
msg = cls(text=str(msg))
main.entities += [entity.adjust_offset(len(main.text)) for entity in msg.entities]
main.text += msg.text + separator
if len(separator) > 0:
main.text = main.text[: -len(separator)]
return main
|