# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
import atheris # pragma: no cover
@atheris.instrument_func
def is_expected_exception(
    error_message_list: list[str], exception: Exception
) -> bool:  # pragma: no cover
    """Report whether an exception's message contains any expected substring.

    Args:
        error_message_list: Substrings to look for in the exception's message.
        exception: The exception object raised during execution.

    Returns:
        True if ``str(exception)`` contains at least one entry of
        ``error_message_list``; False otherwise (including for an empty list).
    """
    # Render the message once instead of once per candidate substring.
    message = str(exception)
    return any(expected in message for expected in error_message_list)
class EnhancedFuzzedDataProvider(atheris.FuzzedDataProvider):  # pragma: no cover
    """Extends atheris.FuzzedDataProvider to offer additional methods to make fuzz testing slightly more DRY."""

    def __init__(self, data: bytes) -> None:
        """Initialize the provider with the fuzzing data passed to TestOneInput.

        Args:
            data: The binary data used for fuzzing.
        """
        super().__init__(data)

    def _bounded_length(self, max_length) -> int:
        """Cap a requested maximum length at the number of remaining bytes.

        Shared by the ``ConsumeRandom*`` helpers so the clamping rule lives in
        one place.

        Args:
            max_length: Requested upper bound, or None to allow every
                remaining byte.

        Returns:
            ``self.remaining_bytes()`` when ``max_length`` is None, otherwise
            the smaller of ``max_length`` and ``self.remaining_bytes()``.
        """
        remaining = self.remaining_bytes()
        if max_length is None:
            return remaining
        return min(max_length, remaining)

    def ConsumeRemainingBytes(self) -> bytes:
        """Consume every byte still left in the bytes container.

        Returns:
            Zero or more bytes.
        """
        return self.ConsumeBytes(self.remaining_bytes())

    def ConsumeRandomBytes(self, max_length=None) -> bytes:
        """Consume a fuzzer-chosen count of bytes from the bytes container.

        Args:
            max_length (int, optional): The maximum number of bytes to
                consume. Defaults to the number of remaining bytes, and is
                never allowed to exceed it.

        Returns:
            Zero or more bytes.
        """
        limit = self._bounded_length(max_length)
        # The count itself is drawn from the fuzz data, then that many bytes
        # are consumed.
        return self.ConsumeBytes(self.ConsumeIntInRange(0, limit))

    def ConsumeRandomString(
        self, max_length=None, without_surrogates: bool = False
    ) -> str:
        """Consume bytes to produce a Unicode string.

        Args:
            max_length (int, optional): The maximum count passed to the
                underlying Unicode consumer. Defaults to the number of
                remaining bytes, and is never allowed to exceed it.
            without_surrogates: If True, use ``ConsumeUnicodeNoSurrogates`` so
                surrogate pair characters are never generated. Defaults to
                False.

        Returns:
            A Unicode string.
        """
        limit = self._bounded_length(max_length)
        # Draw the count before picking the consumer so the order in which
        # fuzz data is consumed stays fixed.
        count = self.ConsumeIntInRange(0, limit)
        if without_surrogates:
            return self.ConsumeUnicodeNoSurrogates(count)
        return self.ConsumeUnicode(count)

    def ConsumeRandomInt(self, minimum: int = 0, maximum: int = 1234567890) -> int:
        """Consume bytes to produce an integer.

        Args:
            minimum: The minimum value of the integer. Defaults to 0.
            maximum: The maximum value of the integer. Defaults to 1234567890.

        Returns:
            The integer produced by ``ConsumeIntInRange(minimum, maximum)``.
        """
        return self.ConsumeIntInRange(minimum, maximum)
|