1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
|
import unittest
from unittest.mock import patch
from papermill.exceptions import PapermillRateLimitException
from papermill.iorw import GCSHandler, fallback_gs_is_retriable
# Resolve the HTTP error class exported by gcsfs. Different gcsfs versions
# expose it from different places, so each known location is probed in turn;
# the first successful import wins.
try:
    from gcsfs.retry import HttpError as GCSHttpError
except ImportError:
    try:
        from gcsfs.utils import HttpError as GCSHttpError
    except ImportError:
        try:
            # Presumably the spelling used by some older gcsfs releases
            # (note: HtmlError, not HttpError) — verify before removing.
            from gcsfs.utils import HtmlError as GCSHttpError
        except ImportError:
            # Fall back to a sane import if gcsfs is missing
            GCSHttpError = Exception

# Prefer gcsfs's dedicated rate-limit exception. When the installed gcsfs does
# not provide it (older library), alias it to GCSHttpError so the tests can
# still construct rate-limit errors.
try:
    from gcsfs.utils import RateLimitException as GCSRateLimitException
except ImportError:
    GCSRateLimitException = GCSHttpError
def mock_gcs_fs_wrapper(exception=None, max_raises=1):
    """Build a stand-in class for the ``GCSFileSystem`` used by ``papermill.iorw``.

    Every instance of the returned class owns a single :class:`MockGCSFile`,
    configured with *exception* and *max_raises*. ``open`` hands back that same
    file whatever its arguments, and ``ls`` always reports an empty listing.

    The class itself (not an instance) is returned so it can be passed as the
    ``side_effect`` of a patched ``GCSFileSystem``. Each call to the patched
    name then constructs a fresh mock filesystem.
    """

    class MockGCSFileSystem:
        def __init__(self):
            # A single file per filesystem instance, so repeated ``open`` calls
            # keep accumulating read/write counts on the same object.
            self._file = MockGCSFile(exception=exception, max_raises=max_raises)

        def open(self, *_args, **_kwargs):
            return self._file

        def ls(self, *_args, **_kwargs):
            return []

    return MockGCSFileSystem
class MockGCSFile:
    """In-memory stand-in for the file handle returned by a mocked ``open``.

    ``read`` and ``write`` each keep their own call counter. While a counter is
    at most *max_raises*, the configured *exception* (if any) is raised.
    Afterwards the call returns the updated counter value. Data passed to
    ``write`` is discarded.
    """

    def __init__(self, exception=None, max_raises=1):
        self.exception = exception
        self.max_raises = max_raises
        self.read_count = 0
        self.write_count = 0

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        # A falsy return lets any exception raised inside the ``with`` propagate.
        return None

    def _raise_if_failing(self, attempt):
        """Raise the configured exception while *attempt* is within the failure budget."""
        if self.exception and attempt <= self.max_raises:
            raise self.exception

    def read(self):
        self.read_count += 1
        self._raise_if_failing(self.read_count)
        return self.read_count

    def write(self, buf):
        self.write_count += 1
        self._raise_if_failing(self.write_count)
        return self.write_count
class GCSTest(unittest.TestCase):
    """Unit tests for ``papermill.iorw.GCSHandler`` backed by a mocked ``GCSFileSystem``."""

    # Dotted targets patched throughout; evaluated in the class body by the decorators below.
    _FS_TARGET = 'papermill.iorw.GCSFileSystem'
    _RETRIABLE_TARGET = 'papermill.iorw.gs_is_retriable'

    def setUp(self):
        self.gcs_handler = GCSHandler()

    @staticmethod
    def _no_backoff():
        """Zero GCSHandler's RETRY_DELAY / RETRY_MULTIPLIER / RETRY_MAX_DELAY inside a ``with`` block."""
        return patch.multiple(GCSHandler, RETRY_DELAY=0, RETRY_MULTIPLIER=0, RETRY_MAX_DELAY=0)

    @patch(_FS_TARGET, side_effect=mock_gcs_fs_wrapper())
    def test_gcs_read(self, mock_gcs_filesystem):
        first_client = self.gcs_handler._get_client()
        content = self.gcs_handler.read('gs://bucket/test.ipynb')
        self.assertEqual(content, 1)
        # The handler must reuse its client rather than build a new one per call.
        self.assertIs(first_client, self.gcs_handler._get_client())

    @patch(_FS_TARGET, side_effect=mock_gcs_fs_wrapper())
    def test_gcs_write(self, mock_gcs_filesystem):
        first_client = self.gcs_handler._get_client()
        written = self.gcs_handler.write('new value', 'gs://bucket/test.ipynb')
        self.assertEqual(written, 1)
        # The handler must reuse its client rather than build a new one per call.
        self.assertIs(first_client, self.gcs_handler._get_client())

    @patch(_FS_TARGET, side_effect=mock_gcs_fs_wrapper())
    def test_gcs_listdir(self, mock_gcs_filesystem):
        first_client = self.gcs_handler._get_client()
        self.gcs_handler.listdir('testdir')
        # The handler must reuse its client rather than build a new one per call.
        self.assertIs(first_client, self.gcs_handler._get_client())

    @patch(
        _FS_TARGET,
        side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({"message": "test", "code": 429}), 10),
    )
    def test_gcs_handle_exception(self, mock_gcs_filesystem):
        # The mock keeps failing for 10 writes, presumably beyond the handler's
        # retry budget, so the rate-limit error must surface to the caller.
        with self._no_backoff(), self.assertRaises(PapermillRateLimitException):
            self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb')

    @patch(
        _FS_TARGET,
        side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({"message": "test", "code": 429}), 1),
    )
    def test_gcs_retry(self, mock_gcs_filesystem):
        # One failure followed by a success: the second write attempt returns 2.
        with self._no_backoff():
            written = self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb')
        self.assertEqual(written, 2)

    @patch(
        _FS_TARGET,
        side_effect=mock_gcs_fs_wrapper(GCSHttpError({"message": "test", "code": 429}), 1),
    )
    def test_gcs_retry_older_exception(self, mock_gcs_filesystem):
        with self._no_backoff():
            written = self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb')
        self.assertEqual(written, 2)

    @patch(_RETRIABLE_TARGET, side_effect=fallback_gs_is_retriable)
    @patch(
        _FS_TARGET,
        side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({"message": "test", "code": None}), 1),
    )
    def test_gcs_fallback_retry_unknown_failure_code(self, mock_gcs_filesystem, mock_gcs_retriable):
        with self._no_backoff():
            written = self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb')
        self.assertEqual(written, 2)

    @patch(_RETRIABLE_TARGET, return_value=False)
    @patch(
        _FS_TARGET,
        side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({"message": "test", "code": 500}), 1),
    )
    def test_gcs_invalid_code(self, mock_gcs_filesystem, mock_gcs_retriable):
        with self.assertRaises(GCSRateLimitException):
            self.gcs_handler.write('fatal_exception', 'gs://bucket/test.ipynb')

    @patch(_RETRIABLE_TARGET, side_effect=fallback_gs_is_retriable)
    @patch(
        _FS_TARGET,
        side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({"message": "test", "code": 500}), 1),
    )
    def test_fallback_gcs_invalid_code(self, mock_gcs_filesystem, mock_gcs_retriable):
        with self.assertRaises(GCSRateLimitException):
            self.gcs_handler.write('fatal_exception', 'gs://bucket/test.ipynb')

    @patch(
        _FS_TARGET,
        side_effect=mock_gcs_fs_wrapper(ValueError("not-a-retry"), 1),
    )
    def test_gcs_unretryable(self, mock_gcs_filesystem):
        with self.assertRaises(ValueError):
            self.gcs_handler.write('no_a_rate_limit', 'gs://bucket/test.ipynb')
|