File: test_s3.py

package info (click to toggle)
smart-open 7.5.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 980 kB
  • sloc: python: 8,054; sh: 90; makefile: 14
file content (145 lines) | stat: -rw-r--r-- 4,402 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Radim Rehurek <me@radimrehurek.com>
#
# This code is distributed under the terms and conditions
# from the MIT License (MIT).
#

from __future__ import unicode_literals
import contextlib
import io
import os
import random
import string

import boto3
import smart_open

# Both settings are required: every test writes under s3://$SO_BUCKET/$SO_KEY/...
# Explicit checks rather than `assert`, because asserts are stripped under
# `python -O` and the tests would then run against an "s3://None/..." URL.
_BUCKET = os.environ.get('SO_BUCKET')
if _BUCKET is None:
    raise RuntimeError('please set the SO_BUCKET environment variable')

_KEY = os.environ.get('SO_KEY')
if _KEY is None:
    raise RuntimeError('please set the SO_KEY environment variable')


#
# https://stackoverflow.com/questions/13484726/safe-enough-8-character-short-unique-random-string
#
def _random_string(length=8):
    """Return a string of ``length`` characters drawn from ``[a-z0-9]``.

    Used to give each temporary S3 key a random suffix.
    """
    charset = string.ascii_lowercase + string.digits
    picks = random.choices(charset, k=length)
    return ''.join(picks)


@contextlib.contextmanager
def temporary():
    """Yield an S3 URL that can be used for temporary writing.

    The URL points at ``<_KEY>/<random suffix>`` inside ``_BUCKET``.
    On exit, every object whose key starts with that key is deleted.
    This also happens when the body of the ``with`` statement raises, so
    failing tests do not leave stray objects behind.
    """
    key = '%s/%s' % (_KEY, _random_string())
    try:
        yield 's3://%s/%s' % (_BUCKET, key)
    finally:
        # An exception from the with-body is re-raised at the `yield`;
        # `finally` guarantees the cleanup still runs. The prefix filter
        # also removes any object whose key merely starts with `key`.
        boto3.resource('s3').Bucket(_BUCKET).objects.filter(Prefix=key).delete()


def _test_case(function):
    """Adapt ``function(benchmark, uri)`` into a pytest test taking only ``benchmark``.

    The returned wrapper opens a fresh ``temporary()`` S3 location, passes
    its URL as ``uri``, and returns whatever ``function`` returns.
    """
    # NOTE: functools.wraps is deliberately not applied. It would set
    # `__wrapped__`, and pytest's signature introspection follows that
    # attribute, so it would then try (and fail) to resolve `uri` as a fixture.
    def run_with_temporary_uri(benchmark):
        with temporary() as temp_uri:
            return function(benchmark, temp_uri)

    return run_with_temporary_uri


def write_read(uri, content, write_mode, read_mode, encoding=None, s3_upload=None, **kwargs):
    """Write ``content`` to ``uri``, then reopen it and read everything back.

    ``kwargs`` are passed as ``transport_params`` to both opens. The write
    additionally receives ``s3_upload`` (passed through even when ``None``).
    ``encoding`` is used for both the write and the read.

    Returns the result of a single ``read()`` on the reopened object.
    """
    upload_params = {**kwargs, 's3_upload': s3_upload}
    with smart_open.open(uri, write_mode, encoding=encoding, transport_params=upload_params) as writer:
        writer.write(content)

    with smart_open.open(uri, read_mode, encoding=encoding, transport_params=kwargs) as reader:
        return reader.read()


def read_length_prefixed_messages(uri, read_mode, encoding=None, **kwargs):
    """Consume ``uri`` as a stream of length-prefixed messages using small reads.

    Each message is one length byte ``n`` followed by ``n`` payload bytes.
    The object is read with a 1-byte ``read`` per prefix and an ``n``-byte
    ``read`` per payload, until a prefix read returns nothing (EOF).
    ``kwargs`` are passed as ``transport_params``.

    Returns every byte read (prefixes included) as ``bytes``, i.e. the
    object's full contents. ``read_mode`` must be binary (e.g. ``'rb'``).
    """
    # bytearray grows in place; the previous `bytes += ...` copied the whole
    # buffer on every message, making this benchmark quadratic and masking
    # the read cost it is meant to measure.
    buf = bytearray()
    with smart_open.open(uri, read_mode, encoding=encoding, transport_params=kwargs) as fin:
        while True:
            length_byte = fin.read(1)
            if not length_byte:
                break
            buf += length_byte
            buf += fin.read(ord(length_byte))
    return bytes(buf)


@_test_case
def test_s3_readwrite_text(benchmark, uri):
    """Round-trip non-ASCII text through S3 in text mode with UTF-8 encoding."""
    expected = 'с гранатою в кармане, с чекою в руке'
    result = benchmark(write_read, uri, expected, 'w', 'r', 'utf-8')
    assert result == expected


@_test_case
def test_s3_readwrite_text_gzip(benchmark, uri):
    """Round-trip UTF-8 text through a gzip-compressed S3 object.

    smart_open infers compression from the file extension, so the ``.gz``
    suffix is what makes this exercise gzip; the bare ``uri`` would be
    stored uncompressed. ``temporary()`` deletes by key prefix, so the
    suffixed object is still cleaned up.
    """
    text = 'не чайки здесь запели на знакомом языке'
    actual = benchmark(write_read, uri + '.gz', text, 'w', 'r', 'utf-8')
    assert actual == text


@_test_case
def test_s3_readwrite_binary(benchmark, uri):
    """Round-trip a short bytes payload through S3 in binary mode."""
    payload = b'this is a test'
    result = benchmark(write_read, uri, payload, 'wb', 'rb')
    assert result == payload


@_test_case
def test_s3_readwrite_binary_gzip(benchmark, uri):
    """Round-trip a short bytes payload through a gzip-compressed S3 object.

    The ``.gz`` suffix makes smart_open apply gzip (compression is inferred
    from the extension). ``temporary()`` deletes by key prefix, so the
    suffixed object is still cleaned up.
    """
    binary = b'this is a test'
    actual = benchmark(write_read, uri + '.gz', binary, 'wb', 'rb')
    assert actual == binary


@_test_case
def test_s3_performance(benchmark, uri):
    """Benchmark an uncompressed binary write/read round trip of 1 MiB."""
    # 8 bytes repeated 128 Ki times == 1 MiB.
    payload = b'01234567' * (1024 * 128)

    result = benchmark(write_read, uri, payload, 'wb', 'rb')
    assert result == payload


@_test_case
def test_s3_performance_gz(benchmark, uri):
    """Benchmark a gzip-compressed binary write/read round trip of 1 MiB.

    The ``.gz`` suffix makes smart_open apply gzip (compression is inferred
    from the extension); the bare ``uri`` would just repeat
    ``test_s3_performance``. ``temporary()`` deletes by key prefix, so the
    suffixed object is still cleaned up.
    """
    # 8 bytes repeated 128 Ki times == 1 MiB.
    one_megabyte = b'01234567' * (1024 * 128)

    actual = benchmark(write_read, uri + '.gz', one_megabyte, 'wb', 'rb')
    assert actual == one_megabyte


@_test_case
def test_s3_performance_small_reads(benchmark, uri):
    """Benchmark reading 1 MiB of 16-byte length-prefixed messages via tiny reads.

    The object is uploaded once outside the benchmark; only the
    message-by-message read is timed.
    """
    one_mib = 1024 ** 2
    # Prefix byte 0x0f (15) announces the 15-byte payload that follows.
    msg = b'\x0f' + b'0123456789abcde'
    # 16-byte messages tile 1 MiB exactly.
    payload = msg * (one_mib // len(msg))

    with smart_open.open(uri, 'wb') as writer:
        writer.write(payload)

    result = benchmark(read_length_prefixed_messages, uri, 'rb', buffer_size=one_mib)
    assert result == payload


@_test_case
def test_s3_encrypted_file(benchmark, uri):
    """Round-trip UTF-8 text uploaded with ``ServerSideEncryption='AES256'``."""
    upload_args = {'ServerSideEncryption': 'AES256'}
    expected = 'с гранатою в кармане, с чекою в руке'
    result = benchmark(write_read, uri, expected, 'w', 'r', 'utf-8', s3_upload=upload_args)
    assert result == expected