File: compression.py

package info (click to toggle)
smart-open 7.5.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 980 kB
  • sloc: python: 8,054; sh: 90; makefile: 14
file content (179 lines) | stat: -rw-r--r-- 5,533 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# -*- coding: utf-8 -*-
#
# Copyright (C) 2020 Radim Rehurek <me@radimrehurek.com>
#
# This code is distributed under the terms and conditions
# from the MIT License (MIT).
#
"""Implements the compression layer of the `smart_open` library."""
import io
import logging
import os.path

logger = logging.getLogger(__name__)

_COMPRESSOR_REGISTRY = {}

NO_COMPRESSION = 'disable'
"""Use no compression. Read/write the data as-is."""
INFER_FROM_EXTENSION = 'infer_from_extension'
"""Determine the compression to use from the file extension.

See get_supported_extensions().
"""


def get_supported_compression_types():
    """Return the list of supported compression types available to open.

    See compression paratemeter to smart_open.open().
    """
    return [NO_COMPRESSION, INFER_FROM_EXTENSION] + get_supported_extensions()


def get_supported_extensions():
    """Return the list of file extensions for which we have registered compressors."""
    return sorted(_COMPRESSOR_REGISTRY.keys())


def register_compressor(ext, callback):
    """Register a callback for transparently decompressing files with a specific extension.

    Parameters
    ----------
    ext: str
        The extension.  Must include the leading period, e.g. `.gz`.
    callback: callable
        The callback.  It must accept two position arguments, file_obj and mode.
        This function will be called when `smart_open` is opening a file with
        the specified extension.

    Examples
    --------

    Instruct smart_open to use the `lzma` module whenever opening a file
    with a .xz extension (see README.rst for the complete example showing I/O):

    >>> def _handle_xz(file_obj, mode):
    ...     import lzma
    ...     return lzma.LZMAFile(filename=file_obj, mode=mode)
    >>>
    >>> register_compressor('.xz', _handle_xz)

    This is just an example: `lzma` is in the standard library and is registered by default.

    """
    if not (ext and ext[0] == '.'):
        raise ValueError('ext must be a string starting with ., not %r' % ext)
    ext = ext.lower()
    if ext in _COMPRESSOR_REGISTRY:
        logger.warning('overriding existing compression handler for %r', ext)
    _COMPRESSOR_REGISTRY[ext] = callback


def tweak_close(outer, inner):
    """Ensure that closing the `outer` stream closes the `inner` stream as well.

    Deprecated: `smart_open.open().__exit__` now always calls `__exit__` on the
    underlying filestream.

    Use this when your compression library's `close` method does not
    automatically close the underlying filestream.  See
    https://github.com/piskvorky/smart_open/issues/630 for an
    explanation why that is a problem for smart_open.
    """
    outer_close = outer.close

    def close_both(*args):
        nonlocal inner
        try:
            outer_close()
        finally:
            if inner:
                inner, fp = None, inner
                fp.close()

    outer.close = close_both


def _maybe_wrap_buffered(file_obj, mode):
    # https://github.com/piskvorky/smart_open/issues/760#issuecomment-1553971657
    result = file_obj
    if "b" in mode and "w" in mode:
        result = io.BufferedWriter(result)
    elif "b" in mode and "r" in mode:
        result = io.BufferedReader(result)
    return result


def _handle_bz2(file_obj, mode):
    import bz2
    result = bz2.open(filename=file_obj, mode=mode)
    return _maybe_wrap_buffered(result, mode)


def _handle_gzip(file_obj, mode):
    import gzip
    result = gzip.open(filename=file_obj, mode=mode)
    return _maybe_wrap_buffered(result, mode)


def _handle_zstd(file_obj, mode):
    import sys
    if sys.version_info >= (3, 14):
        from compression import zstd
    else:
        from backports import zstd
    result = zstd.open(file_obj, mode=mode)
    return _maybe_wrap_buffered(result, mode)


def _handle_xz(file_obj, mode):
    import lzma
    result = lzma.open(filename=file_obj, mode=mode)
    return _maybe_wrap_buffered(result, mode)


def compression_wrapper(file_obj, mode, compression=INFER_FROM_EXTENSION, filename=None):
    """
    Wrap `file_obj` with an appropriate [de]compression mechanism based on its file extension.

    If the filename extension isn't recognized, simply return the original `file_obj` unchanged.

    `file_obj` must either be a filehandle object, or a class which behaves like one.

    If `filename` is specified, it will be used to extract the extension.
    If not, the `file_obj.name` attribute is used as the filename.

    """
    if compression == NO_COMPRESSION:
        return file_obj
    elif compression == INFER_FROM_EXTENSION:
        try:
            filename = (filename or file_obj.name).lower()
        except (AttributeError, TypeError):
            logger.warning(
                'unable to transparently decompress %r because it '
                'seems to lack a string-like .name', file_obj
            )
            return file_obj
        _, compression = os.path.splitext(filename)

    if compression in _COMPRESSOR_REGISTRY and mode.endswith('+'):
        raise ValueError('transparent (de)compression unsupported for mode %r' % mode)

    try:
        callback = _COMPRESSOR_REGISTRY[compression]
    except KeyError:
        return file_obj
    else:
        return callback(file_obj, mode)


#
# NB. avoid using lambda here to make stack traces more readable.
#
register_compressor('.bz2', _handle_bz2)
register_compressor('.gz', _handle_gzip)
register_compressor('.zst', _handle_zstd)
register_compressor('.xz', _handle_xz)