1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
|
#
# Copyright (c), 2024, SISSA (International School for Advanced Studies).
# All rights reserved.
# This file is distributed under the terms of the MIT License.
# See the file 'LICENSE' in the root directory of the present
# distribution, or http://opensource.org/licenses/MIT.
#
# @author Davide Brunato <brunato@sissa.it>
#
from io import BufferedIOBase
from threading import Lock
from typing import Any, Optional, Union
DEFAULT_BUFFER_SIZE = 8 * 1024
def is_file_object(obj: object) -> bool:
return hasattr(obj, 'read') and hasattr(obj, 'seekable') \
and hasattr(obj, 'tell') and hasattr(obj, 'seek') \
and hasattr(obj, 'closed') and hasattr(obj, 'close')
class DefusableReader(BufferedIOBase):
"""
A class for wrapping a not seekable buffered IO stream in a partially seekable
stream that can be defused. The initial buffer size is 64KiB and can't be lower
than io.DEFAULT_BUFFER_SIZE.
"""
def __init__(self, fp: BufferedIOBase, initial_buffer_size: int = 64 * 1024):
if not isinstance(fp, BufferedIOBase):
raise TypeError(
f'"fp" argument must an instance of {BufferedIOBase} or a derived class'
)
if not fp.readable():
raise OSError('"fp" argument must be readable')
if fp.closed:
raise OSError('"fp" argument must be a not closed file descriptor')
if initial_buffer_size < DEFAULT_BUFFER_SIZE:
initial_buffer_size = DEFAULT_BUFFER_SIZE
buf = bytearray()
buf += fp.read(initial_buffer_size)
self._buffer = buf
self._buffer_size = len(buf)
self._fp = fp
self._pos = 0
self._fp_lock = Lock()
def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
state.pop('_fp_lock', None)
return state
def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__.update(state)
self._xpath_lock = Lock()
def readable(self) -> bool:
return self._fp.readable()
def seekable(self) -> bool:
self._checkClosed()
return self._pos < self._buffer_size or self._fp.seekable()
def seek(self, pos: int, whence: int = 0) -> int:
if self.closed:
raise ValueError("seek on closed file")
if not isinstance(pos, int):
raise TypeError(f"{pos!r} is not an integer")
with self._fp_lock:
if whence == 0:
if pos < 0:
raise ValueError(f"negative seek position {pos!r}")
elif whence == 1:
pos = max(0, self._pos + pos)
elif whence == 2:
pos = self._fp.seek(pos, 2)
else:
raise ValueError("unsupported whence value")
if pos > self._buffer_size and whence != 2:
pos = self._fp.seek(pos)
elif self._pos > self._buffer_size:
self._fp.seek(self._buffer_size)
self._pos = pos
return self._pos
def tell(self) -> int:
return self._pos
def getbuffer(self) -> memoryview:
if self.closed:
raise ValueError("getbuffer on closed file")
return memoryview(self._buffer)
def close(self) -> None:
with self._fp_lock:
self._buffer.clear()
self._fp.close()
def read(self, size: Optional[int] = None) -> bytes:
self._checkClosed()
if size is not None and size < -1:
raise ValueError("invalid number of bytes to read")
with self._fp_lock:
return self._read_unlocked(size)
def _read_unlocked(self, size: Optional[int] = None) -> bytes:
if self._pos >= self._buffer_size:
data = self._fp.read(size)
self._pos += len(data)
return data
buffer = self._buffer[self._pos:]
if size is not None and size > -1:
if size <= len(buffer):
data = buffer[:size]
else:
chunk = self._fp.read(size - len(buffer))
data = buffer if not chunk else buffer + chunk
elif hasattr(self._fp, 'readall'):
chunk = self._fp.readall()
data = buffer if not chunk else buffer + chunk
else:
chunks: list[Union[bytes, bytearray]] = [buffer]
while True:
chunk = self._fp.read()
if not chunk:
break
chunks.append(chunk)
data = b"".join(chunks)
self._pos += len(data)
return data
def read1(self, size: int = -1) -> bytes:
return self.read(size)
|