File: test_gcs.py

package info (click to toggle)
python-papermill 2.6.0-3.1
  • links: PTS, VCS
  • area: main
  • in suites: forky, trixie
  • size: 2,216 kB
  • sloc: python: 4,977; makefile: 17; sh: 5
file content (159 lines) | stat: -rw-r--r-- 6,452 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import unittest
from unittest.mock import patch

from papermill.exceptions import PapermillRateLimitException
from papermill.iorw import GCSHandler, fallback_gs_is_retriable

try:
    try:
        try:
            from gcsfs.retry import HttpError as GCSHttpError
        except ImportError:
            from gcsfs.utils import HttpError as GCSHttpError
    except ImportError:
        from gcsfs.utils import HtmlError as GCSHttpError
except ImportError:
    # Fall back to a sane import if gcsfs is missing
    GCSHttpError = Exception

try:
    from gcsfs.utils import RateLimitException as GCSRateLimitException
except ImportError:
    # Fall back to GCSHttpError when using older library
    GCSRateLimitException = GCSHttpError


def mock_gcs_fs_wrapper(exception=None, max_raises=1):
    """Build a fake ``GCSFileSystem`` class for patching ``papermill.iorw.GCSFileSystem``.

    Each instance of the returned class owns a single ``MockGCSFile`` configured
    with ``exception`` and ``max_raises``. ``open`` always hands back that same
    file object, and ``ls`` always reports an empty listing.

    :param exception: exception instance the fake file raises from ``read`` /
        ``write``, or ``None`` to never raise.
    :param max_raises: how many ``read`` calls (and, counted separately, how many
        ``write`` calls) raise ``exception`` before the calls start succeeding.
    :return: the ``MockGCSFileSystem`` class itself (not an instance), meant to be
        used as a ``patch(..., side_effect=...)`` value.
    """

    class MockGCSFileSystem:
        def __init__(self, *args, **kwargs):
            # ``patch(side_effect=cls)`` forwards whatever arguments the code under
            # test passes to the real constructor; accept and ignore them so the
            # fake does not break if the handler instantiates it with arguments.
            self._file = MockGCSFile(exception, max_raises)

        def open(self, *args, **kwargs):
            # Path and mode are ignored: every open() shares one file object, so
            # read/write counters accumulate across calls on this filesystem.
            return self._file

        def ls(self, *args, **kwargs):
            return []

    return MockGCSFileSystem


class MockGCSFile:
    """Stand-in for the file object returned by the fake filesystem's ``open``.

    ``read`` and ``write`` each keep an independent call counter. A call first
    bumps its counter. If an ``exception`` was supplied and the counter is still
    at or below ``max_raises``, the call raises that exception. Otherwise it
    returns the counter's new value.
    """

    def __init__(self, exception=None, max_raises=1):
        self.exception = exception
        self.max_raises = max_raises
        self.read_count = 0
        self.write_count = 0

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        # A falsy return value lets exceptions from the ``with`` body propagate.
        return None

    def _bump(self, counter_attr):
        """Increment ``counter_attr``; raise the configured exception while inside the raise budget."""
        calls = getattr(self, counter_attr) + 1
        setattr(self, counter_attr, calls)
        should_fail = self.exception and calls <= self.max_raises
        if should_fail:
            raise self.exception
        return calls

    def read(self):
        return self._bump('read_count')

    def write(self, buf):
        # The payload is ignored; callers only care how many writes happened.
        return self._bump('write_count')


class GCSTest(unittest.TestCase):
    """Tests for `GCS`: ``GCSHandler`` run against a patched ``GCSFileSystem``."""

    def setUp(self):
        self.gcs_handler = GCSHandler()

    @staticmethod
    def _no_retry_backoff():
        """Return a patcher that zeroes ``GCSHandler``'s retry delay settings.

        Used as a context manager so retry tests do not wait between attempts.
        """
        return patch.multiple(GCSHandler, RETRY_DELAY=0, RETRY_MULTIPLIER=0, RETRY_MAX_DELAY=0)

    @patch('papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper())
    def test_gcs_read(self, mock_gcs_filesystem):
        """read() returns what the underlying file returns, reusing one client."""
        client = self.gcs_handler._get_client()
        self.assertEqual(self.gcs_handler.read('gs://bucket/test.ipynb'), 1)
        # Check that client is only generated once
        self.assertIs(client, self.gcs_handler._get_client())
        mock_gcs_filesystem.assert_called_once()

    @patch('papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper())
    def test_gcs_write(self, mock_gcs_filesystem):
        """write() returns what the underlying file returns, reusing one client."""
        client = self.gcs_handler._get_client()
        self.assertEqual(self.gcs_handler.write('new value', 'gs://bucket/test.ipynb'), 1)
        # Check that client is only generated once
        self.assertIs(client, self.gcs_handler._get_client())
        mock_gcs_filesystem.assert_called_once()

    @patch('papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper())
    def test_gcs_listdir(self, mock_gcs_filesystem):
        """listdir() returns the filesystem's listing, reusing one client."""
        client = self.gcs_handler._get_client()
        # The fake filesystem's ``ls`` always reports an empty listing.
        self.assertEqual(self.gcs_handler.listdir('testdir'), [])
        # Check that client is only generated once
        self.assertIs(client, self.gcs_handler._get_client())
        mock_gcs_filesystem.assert_called_once()

    @patch(
        'papermill.iorw.GCSFileSystem',
        side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({"message": "test", "code": 429}), 10),
    )
    def test_gcs_handle_exception(self, mock_gcs_filesystem):
        """A rate limit that outlasts every retry surfaces as PapermillRateLimitException."""
        with self._no_retry_backoff():
            with self.assertRaises(PapermillRateLimitException):
                self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb')

    @patch(
        'papermill.iorw.GCSFileSystem',
        side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({"message": "test", "code": 429}), 1),
    )
    def test_gcs_retry(self, mock_gcs_filesystem):
        """A single 429 failure is retried; the second write attempt succeeds."""
        with self._no_retry_backoff():
            self.assertEqual(self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb'), 2)

    @patch(
        'papermill.iorw.GCSFileSystem',
        side_effect=mock_gcs_fs_wrapper(GCSHttpError({"message": "test", "code": 429}), 1),
    )
    def test_gcs_retry_older_exception(self, mock_gcs_filesystem):
        """The older HttpError type with code 429 is also retried."""
        with self._no_retry_backoff():
            self.assertEqual(self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb'), 2)

    @patch('papermill.iorw.gs_is_retriable', side_effect=fallback_gs_is_retriable)
    @patch(
        'papermill.iorw.GCSFileSystem',
        side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({"message": "test", "code": None}), 1),
    )
    def test_gcs_fallback_retry_unknown_failure_code(self, mock_gcs_filesystem, mock_gcs_retriable):
        """The fallback retriable check retries a failure whose code is unknown (None)."""
        with self._no_retry_backoff():
            self.assertEqual(self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb'), 2)

    @patch('papermill.iorw.gs_is_retriable', return_value=False)
    @patch(
        'papermill.iorw.GCSFileSystem',
        side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({"message": "test", "code": 500}), 1),
    )
    def test_gcs_invalid_code(self, mock_gcs_filesystem, mock_gcs_retriable):
        """A failure judged non-retriable propagates as the original exception type."""
        with self.assertRaises(GCSRateLimitException):
            self.gcs_handler.write('fatal_exception', 'gs://bucket/test.ipynb')

    @patch('papermill.iorw.gs_is_retriable', side_effect=fallback_gs_is_retriable)
    @patch(
        'papermill.iorw.GCSFileSystem',
        side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({"message": "test", "code": 500}), 1),
    )
    def test_fallback_gcs_invalid_code(self, mock_gcs_filesystem, mock_gcs_retriable):
        """Under the fallback retriable check, code 500 propagates as the original exception."""
        with self.assertRaises(GCSRateLimitException):
            self.gcs_handler.write('fatal_exception', 'gs://bucket/test.ipynb')

    @patch(
        'papermill.iorw.GCSFileSystem',
        side_effect=mock_gcs_fs_wrapper(ValueError("not-a-retry"), 1),
    )
    def test_gcs_unretryable(self, mock_gcs_filesystem):
        """Exceptions unrelated to GCS errors pass straight through."""
        with self.assertRaises(ValueError):
            self.gcs_handler.write('no_a_rate_limit', 'gs://bucket/test.ipynb')