1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
|
import contextlib
import datetime
import io
import tempfile
from typing import Iterator, List, TypeVar, Union

import atheris
# Generic element type used by EnhancedFuzzedDataProvider.ConsumeSublist.
T = TypeVar("T")
class EnhancedFuzzedDataProvider(atheris.FuzzedDataProvider):
    """``atheris.FuzzedDataProvider`` with higher-level convenience consumers.

    Adds helpers that turn fuzzer input into variable-length bytes/strings,
    shuffled sub-lists, datetimes, and file inputs (in memory or on disk).
    """

    def ConsumeRandomBytes(self) -> bytes:
        """Consume a fuzzer-chosen number of bytes (0 up to what remains)."""
        return self.ConsumeBytes(self.ConsumeIntInRange(0, self.remaining_bytes()))

    def ConsumeRandomString(self) -> str:
        """Consume a surrogate-free string whose length argument is fuzzer-chosen."""
        return self.ConsumeUnicodeNoSurrogates(
            self.ConsumeIntInRange(0, self.remaining_bytes())
        )

    def ConsumeRemainingString(self) -> str:
        """Consume all remaining input as a surrogate-free string."""
        return self.ConsumeUnicodeNoSurrogates(self.remaining_bytes())

    def ConsumeRemainingBytes(self) -> bytes:
        """Consume all remaining input as bytes."""
        return self.ConsumeBytes(self.remaining_bytes())

    def ConsumeSublist(self, source: List[T]) -> List[T]:
        """Return a shuffled sub-list of *source* with length in [1, len(source)].

        Each element is kept or dropped by one consumed bool. The kept
        elements are then shuffled by a Fisher-Yates pass whose swap indices
        come from the fuzzer input. If nothing was kept, a single element
        chosen with ``PickValueInList`` is returned, so the result is never
        empty.

        Assumes *source* is non-empty (the length range above is otherwise
        empty); the fallback ``PickValueInList`` call is not guarded for it.
        """
        chosen = [elem for elem in source if self.ConsumeBool()]
        # Fisher-Yates: i must run down to 1 inclusive so the first two
        # positions can also be swapped; a stop of 2 would leave 2-element
        # sub-lists unshuffled.
        for i in range(len(chosen) - 1, 0, -1):
            j = self.ConsumeIntInRange(0, i)
            chosen[i], chosen[j] = chosen[j], chosen[i]
        return chosen or [self.PickValueInList(source)]

    def ConsumeDate(self) -> datetime.datetime:
        """Consume a float and interpret it as a POSIX timestamp.

        Returns the naive local-time result of
        ``datetime.datetime.fromtimestamp``. Timestamps it rejects (it raises
        OverflowError, OSError or ValueError, e.g. for NaN, infinities or
        out-of-range values) fall back to the naive ``datetime(1970, 1, 1)``.
        """
        try:
            return datetime.datetime.fromtimestamp(self.ConsumeFloat())
        except (OverflowError, OSError, ValueError):
            return datetime.datetime(year=1970, month=1, day=1)

    def _consume_file_data(self, all_data: bool, as_bytes: bool) -> Union[bytes, str]:
        """Consume file contents: all remaining input, or a random-length chunk."""
        if all_data:
            return (
                self.ConsumeRemainingBytes()
                if as_bytes
                else self.ConsumeRemainingString()
            )
        return self.ConsumeRandomBytes() if as_bytes else self.ConsumeRandomString()

    @contextlib.contextmanager
    def ConsumeMemoryFile(
        self, all_data: bool = False, as_bytes: bool = True
    ) -> Iterator[Union[io.BytesIO, io.StringIO]]:
        """Yield an in-memory file holding consumed fuzzer data.

        Args:
            all_data: consume all remaining input instead of a fuzzer-chosen
                amount.
            as_bytes: yield an ``io.BytesIO`` over bytes when true, otherwise
                an ``io.StringIO`` over a surrogate-free string.

        The file is closed when the ``with`` block exits, including when the
        block raises.
        """
        file_data = self._consume_file_data(all_data, as_bytes)
        mem_file = io.BytesIO(file_data) if as_bytes else io.StringIO(file_data)
        # `with` closes the file even if the caller's block raises; a plain
        # close() placed after `yield` would be skipped in that case.
        with mem_file:
            yield mem_file

    @contextlib.contextmanager
    def ConsumeTemporaryFile(
        self, suffix: str, all_data: bool = False, as_bytes: bool = True
    ) -> Iterator[str]:
        """Write consumed fuzzer data to a named temporary file and yield its path.

        Args:
            suffix: file-name suffix passed to ``tempfile.NamedTemporaryFile``.
            all_data: consume all remaining input instead of a fuzzer-chosen
                amount.
            as_bytes: write bytes when true, otherwise write a surrogate-free
                string encoded as UTF-8.

        The file is created with the default ``delete=True``. It is therefore
        removed when it is closed on exit from the ``with`` block, including
        when that block raises. Per the ``tempfile`` docs, reopening the file
        by name while it is still open is not supported on every platform
        (notably Windows).
        """
        file_data = self._consume_file_data(all_data, as_bytes)
        mode = "w+b" if as_bytes else "w+"
        # Explicit UTF-8 for text mode: the locale default encoding may be
        # unable to represent arbitrary fuzzer-generated code points.
        # Binary mode requires encoding=None.
        encoding = None if as_bytes else "utf-8"
        with tempfile.NamedTemporaryFile(
            mode=mode, suffix=suffix, encoding=encoding
        ) as tfile:
            tfile.write(file_data)
            # Flush so readers opening the path see the full contents.
            tfile.flush()
            tfile.seek(0)
            yield tfile.name
|