1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370
|
import asyncio
import json
import logging
import random
import uuid
from collections import namedtuple
from json import JSONDecodeError
from os import getenv
from typing import Any, Callable, List, Optional, Dict, Union
from AnyQt.QtCore import QSettings
from httpx import AsyncClient, NetworkError, ReadTimeout, Response, AsyncHTTPTransport
from numpy import linspace
from Orange.misc.utils.embedder_utils import (
EmbedderCache,
EmbeddingConnectionError,
get_proxies,
)
from Orange.util import dummy_callback
log = logging.getLogger(__name__)
# A single unit of work in the embedding queue:
#   id         - index of the item in the input list (and in the results list)
#   item       - the raw data instance to embed
#   no_repeats - number of embedding attempts already made for this item
TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))
def _rewrite_proxies_to_mounts(proxies):
if proxies is None:
return None
return {c: AsyncHTTPTransport(proxy=url) for c, url in get_proxies().items()}
class ServerEmbedderCommunicator:
    """
    This class needs to be inherited by the class which re-implements
    _encode_data_instance and defines self.content_type. For sending a list
    with data items use embedd_table function.

    Attributes
    ----------
    model_name
        The name of the model. Name is used in url to define what server model
        gets data to embedd and as a caching keyword.
    max_parallel_requests
        Number of image that can be sent to the server at the same time.
    server_url
        The base url of the server (without any additional subdomains)
    embedder_type
        The type of embedder (e.g. image). It is used as a part of url (e.g.
        when embedder_type is image url is api.garaza.io/image)
    """

    # maximum number of embedding attempts per item
    MAX_REPEATS = 3

    # counters of consecutive errors; reset to 0 after every successful
    # request (assignment via `self.x += 1` shadows these class defaults
    # with instance attributes)
    count_connection_errors = 0
    count_read_errors = 0
    # number of consecutive errors after which embedding is aborted
    max_errors = 10

    def __init__(
        self,
        model_name: str,
        max_parallel_requests: int,
        server_url: str,
        embedder_type: str,
    ) -> None:
        # environment variable may override the server url (e.g. for testing)
        self.server_url = getenv("ORANGE_EMBEDDING_API_URL", server_url)
        self._model = model_name
        self.embedder_type = embedder_type
        self.machine_id = None
        try:
            self.machine_id = QSettings().value(
                "error-reporting/machine-id", "", type=str
            ) or str(uuid.getnode())
        except TypeError:
            # QSettings value may fail - fall back to hardware-address id
            self.machine_id = str(uuid.getnode())
        self.session_id = str(random.randint(1, int(1e10)))

        self._cache = EmbedderCache(model_name)

        # default embedding timeouts are too small we need to increase them
        self.timeout = 180
        self.max_parallel_requests = max_parallel_requests

        self.content_type = None  # need to be set in a class inheriting

    def embedd_data(
        self,
        data: List[Any],
        *,
        callback: Callable = dummy_callback,
    ) -> List[Optional[List[float]]]:
        """
        This function repeats calling embedding function until all items
        are embedded. It prevents skipped items due to network issues.
        The process is repeated for each item maximally MAX_REPEATS times.

        Parameters
        ----------
        data
            List with data that needs to be embedded.
        callback
            Callback for reporting the progress in share of embedded items

        Returns
        -------
        List of float list (embeddings) for successfully embedded
        items and Nones for skipped items.

        Raises
        ------
        EmbeddingConnectionError
            Error which indicate that the embedding is not possible due to
            connection error.
        EmbeddingCancelledException:
            If cancelled attribute is set to True (default=False).
        """
        # if there is less items than 10 connection error should be raised earlier
        self.max_errors = min(len(data) * self.MAX_REPEATS, 10)

        return asyncio.run(
            self.embedd_batch(data, callback=callback)
        )

    async def embedd_batch(
        self,
        data: List[Any],
        *,
        callback: Callable = dummy_callback,
    ) -> List[Optional[List[float]]]:
        """
        Function perform embedding of a batch of data items.

        Parameters
        ----------
        data
            A list of data that must be embedded.
        callback
            Callback for reporting the progress in share of embedded items

        Returns
        -------
        List of float list (embeddings) for successfully embedded
        items and Nones for skipped items.

        Raises
        ------
        EmbeddingCancelledException:
            If cancelled attribute is set to True (default=False).
        """
        progress_items = iter(linspace(0, 1, len(data)))

        def advance_callback():
            """Report progress - called exactly once per finished item
            (successfully embedded or permanently skipped)"""
            callback(next(progress_items))

        results = [None] * len(data)
        queue = asyncio.Queue()

        # fill the queue with items to embedd
        for i, item in enumerate(data):
            queue.put_nowait(TaskItem(id=i, item=item, no_repeats=0))

        proxy_mounts = _rewrite_proxies_to_mounts(get_proxies())
        async with AsyncClient(
            timeout=self.timeout, base_url=self.server_url, mounts=proxy_mounts
        ) as client:
            tasks = self._init_workers(client, queue, results, advance_callback)

            try:
                # wait for workers to stop - they stop when queue is empty
                # if one worker raises exception wait will raise it further
                await asyncio.gather(*tasks)
            finally:
                await self._cancel_workers(tasks)

        self._cache.persist_cache()
        return results

    def _init_workers(self, client, queue, results, callback):
        """Init required number of workers"""
        t = [
            asyncio.create_task(self._send_to_server(client, queue, results, callback))
            # when number of instances less than max_parallel_requests create
            # only required number of workers
            for _ in range(min(self.max_parallel_requests, len(results)))
        ]
        log.debug("Created %d workers", self.max_parallel_requests)
        return t

    @staticmethod
    async def _cancel_workers(tasks):
        """Cancel worker at the end"""
        log.debug("Canceling workers")
        # cancel all tasks in both cases
        for task in tasks:
            task.cancel()
        # Wait until all worker tasks are cancelled.
        await asyncio.gather(*tasks, return_exceptions=True)
        log.debug("All workers canceled")

    async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
        """
        The reimplementation of this function must implement the procedure
        to encode the data item in a string format that will be sent to the
        server. For images it is the byte string with an image. The encoding
        must be always equal for same data instance.

        Parameters
        ----------
        data_instance
            The row of an Orange data table.

        Returns
        -------
        Bytes encoding the data instance.
        """
        raise NotImplementedError

    async def _send_to_server(
        self,
        client: AsyncClient,
        queue: asyncio.Queue,
        results: List,
        proc_callback: Callable,
    ):
        """
        Worker that embedds data. It is pulling items from the queue until
        it is empty. It is runs until anything is in the queue, or it is canceled

        Parameters
        ----------
        client
            HTTPX client that communicates with the server
        queue
            The queue with items of type TaskItem to be embedded
        results
            The list to append results in. The list has length equal to numbers
            of all items to embedd. The result need to be inserted at the index
            defined in queue items.
        proc_callback
            A function that is called after each item is fully processed
            by either getting a successful response from the server,
            getting the result from cache or skipping the item.
        """
        while not queue.empty():
            # get item from the queue
            i, data_instance, num_repeats = await queue.get()

            # load bytes
            data_bytes = await self._encode_data_instance(data_instance)
            if data_bytes is None:
                # fix: the item is permanently skipped - report it as
                # processed (per proc_callback's contract) and mark the
                # queue entry done before continuing
                proc_callback()
                queue.task_done()
                continue

            # retrieve embedded item from the local cache
            cache_key = self._cache.md5_hash(data_bytes)
            log.debug("Embedding %s", cache_key)
            emb = self._cache.get_cached_result_or_none(cache_key)

            if emb is None:
                # send the item to the server for embedding if not in the local cache
                log.debug("Sending to the server: %s", cache_key)
                url = (
                    f"/{self.embedder_type}/{self._model}?machine={self.machine_id}"
                    f"&session={self.session_id}&retry={num_repeats+1}"
                )
                emb = await self._send_request(client, data_bytes, url)
                if emb is not None:
                    self._cache.add(cache_key, emb)

            if emb is not None:
                # store result if embedding is successful
                log.debug("Successfully embedded: %s", cache_key)
                results[i] = emb
                proc_callback()
            elif num_repeats+1 < self.MAX_REPEATS:
                log.debug("Embedding unsuccessful - reading to queue: %s", cache_key)
                # if embedding not successful put the item to queue to be handled at
                # the end - the item is put to the end since it is possible that server
                # still process the request and the result will be in the cache later
                # repeating the request immediately may result in another fail when
                # processing takes longer
                queue.put_nowait(TaskItem(i, data_instance, no_repeats=num_repeats+1))
            else:
                # fix: all repeats exhausted - the result stays None, but the
                # item must still count toward progress so it can reach 100 %
                proc_callback()
            queue.task_done()

    async def _send_request(
        self, client: AsyncClient, data: Union[bytes, Dict], url: str
    ) -> Optional[List[float]]:
        """
        This function sends a single request to the server.

        Parameters
        ----------
        client
            HTTPX client that communicates with the server
        data
            Data packed in the sequence of bytes.
        url
            Rest of the url string.

        Returns
        -------
        embedding
            Embedding. For items that are not successfully embedded returns
            None.
        """
        headers = {
            "Content-Type": self.content_type,
            # NOTE(review): for dict payloads len(data) is the number of
            # keys, not the body size - confirm the server ignores this
            # header for form-encoded requests
            "Content-Length": str(len(data)),
        }
        try:
            # bytes are sent as content parameter and dictionary as data
            kwargs = dict(content=data) if isinstance(data, bytes) else dict(data=data)
            response = await client.post(url, headers=headers, **kwargs)
        except ReadTimeout as ex:
            log.debug("Read timeout", exc_info=True)
            # it happens when server do not respond in time defined by timeout
            # return None and items will be resend later

            # if it happens more than in ten consecutive cases it means
            # sth is wrong with embedder we stop embedding
            self.count_read_errors += 1

            if self.count_read_errors >= self.max_errors:
                raise EmbeddingConnectionError from ex
            return None
        except (OSError, NetworkError) as ex:
            log.debug("Network error", exc_info=True)
            # it happens when no connection and items cannot be sent to server
            # if more than 10 consecutive errors it means there is no
            # connection so we stop embedding with EmbeddingConnectionError
            self.count_connection_errors += 1
            if self.count_connection_errors >= self.max_errors:
                raise EmbeddingConnectionError from ex
            return None
        except Exception:
            log.debug("Embedding error", exc_info=True)
            raise
        # we reset the counter at successful embedding
        self.count_connection_errors = 0
        self.count_read_errors = 0
        return self._parse_response(response)

    @staticmethod
    def _parse_response(response: Response) -> Optional[List[float]]:
        """
        This function get response and extract embeddings out of them.

        Parameters
        ----------
        response
            Response by the server

        Returns
        -------
        Embedding. For items that are not successfully embedded returns None.
        """
        if response.content:
            try:
                cont = json.loads(response.content.decode("utf-8"))
                return cont.get("embedding", None)
            except JSONDecodeError:
                # in case that embedding was not successful response is not
                # valid JSON
                return None
        else:
            return None

    def clear_cache(self):
        # drop all locally cached embeddings for this model
        self._cache.clear_cache()
|