File: test_bitmask.py

package info (click to toggle)
compressed-tensors 0.9.4-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 908 kB
  • sloc: python: 7,543; makefile: 32
file content (120 lines) | stat: -rw-r--r-- 4,294 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import shutil

import pytest
import torch
from compressed_tensors import BitmaskCompressor, BitmaskConfig, BitmaskTensor
from safetensors.torch import save_file


@pytest.mark.parametrize(
    "shape,sparsity,dtype",
    [
        [(512, 1024), 0.5, torch.float32],
        [(830, 545), 0.8, torch.float32],
        [(342, 512), 0.3, torch.bfloat16],
        [(256, 700), 0.9, torch.float16],
    ],
)
def test_bitmask_sizes(shape, sparsity, dtype):
    """Check the sizes of the entries produced by ``BitmaskCompressor.compress``.

    A random tensor of the given ``shape``/``dtype`` is zeroed wherever its
    magnitude is not below ``1 - sparsity`` and then compressed. The test
    asserts that the compressed state dict holds four entries per dense weight:
    a stored shape equal to ``shape``, a bitmask with one row per dense row and
    ``ceil(columns / 8)`` columns, one stored value per non-zero dense element,
    and one row offset per dense row.
    """
    dense = torch.rand(shape, dtype=dtype)
    keep = (dense.abs() < (1 - sparsity)).int()
    dense *= keep

    originals = {"dummy.weight": dense}
    compressor = BitmaskCompressor(config=BitmaskConfig())
    compressed = compressor.compress(originals)

    # Four compressed entries are expected for every dense weight.
    assert len(compressed) == 4 * len(originals)

    # The stored shape must reproduce the dense shape exactly.
    stored_shape = compressed["dummy.shape"]
    assert torch.all(torch.eq(stored_shape, torch.tensor(shape)))

    # Bitmask: 1 bit per dense element, so each row is padded up to whole bytes.
    mask_dims = compressed["dummy.bitmask"].shape
    assert mask_dims[0] == stored_shape[0]
    assert mask_dims[1] == int(math.ceil(stored_shape[1] / 8.0))

    # Exactly one stored value per non-zero dense element.
    value_dims = compressed["dummy.compressed"].shape
    assert value_dims[0] == torch.sum(dense != 0)

    # One row offset per dense row.
    offset_dims = compressed["dummy.row_offsets"].shape
    assert offset_dims[0] == dense.shape[0]


@pytest.mark.parametrize(
    "shape,sparsity,dtype",
    [
        [(256, 512), 0.5, torch.float32],
        [(128, 280), 0.8, torch.float32],
        [(1024, 256), 0.3, torch.bfloat16],
        [(511, 350), 0.7, torch.float16],
    ],
)
def test_match(shape, sparsity, dtype):
    """Round-trip dense tensors through ``BitmaskTensor`` in memory.

    Two independent random tensors of ``shape``/``dtype`` are zeroed wherever
    their magnitude is not below ``1 - sparsity``. Each is converted with
    ``BitmaskTensor.from_dense`` and decompressed again. The result must keep
    the original dtype and be exactly equal (``torch.equal``) to the input.
    """
    test_tensor1 = torch.rand(shape, dtype=dtype)
    mask = (test_tensor1.abs() < (1 - sparsity)).int()
    test_tensor1 *= mask

    test_tensor2 = torch.rand(shape, dtype=dtype)
    mask = (test_tensor2.abs() < (1 - sparsity)).int()
    test_tensor2 *= mask

    dense_state_dict = {"dummy.weight": test_tensor1, "dummy2.weight": test_tensor2}

    # Iterate (name, tensor) pairs directly rather than re-indexing by key; the
    # name is attached to each assertion so a failure says which entry broke.
    for name, dense_tensor in dense_state_dict.items():
        sparse_tensor = BitmaskTensor.from_dense(dense_tensor)
        decompressed = sparse_tensor.decompress()
        assert decompressed.dtype == dense_tensor.dtype == dtype, name
        assert torch.equal(dense_tensor, decompressed), name


@pytest.mark.parametrize(
    "sparsity,dtype",
    [
        [0.5, torch.float32],
        [0.8, torch.float32],
        [0.3, torch.bfloat16],
        [0.7, torch.float16],
    ],
)
def test_reload_match(sparsity, dtype, tmp_path):
    """Compress, save to safetensors, reload with ``decompress`` and compare.

    Two random tensors of different shapes are zeroed wherever their magnitude
    is not below ``1 - sparsity``. They are compressed with
    ``BitmaskCompressor``, written to ``tmp_path / "model.safetensors"``, and
    reconstructed with ``compressor.decompress(tmp_path)``. Every dense weight
    must come back under its original name, with its original dtype, and be
    exactly equal to the input.
    """
    test_tensor1 = torch.rand((256, 512), dtype=dtype)
    mask = (test_tensor1.abs() < (1 - sparsity)).int()
    test_tensor1 *= mask

    test_tensor2 = torch.rand((360, 720), dtype=dtype)
    mask = (test_tensor2.abs() < (1 - sparsity)).int()
    test_tensor2 *= mask

    dense_state_dict = {"dummy.weight": test_tensor1, "dummy2.weight": test_tensor2}

    sparsity_config = BitmaskConfig()
    compressor = BitmaskCompressor(config=sparsity_config)

    sparse_state_dict = compressor.compress(dense_state_dict)
    save_file(sparse_state_dict, tmp_path / "model.safetensors")

    # decompress() is consumed as an iterable of (name, tensor) pairs. Collect
    # it so the test can verify every dense weight was reconstructed: merely
    # looping over it would pass vacuously if a weight were silently dropped.
    reconstructed_dense = dict(compressor.decompress(tmp_path))
    assert set(reconstructed_dense) == set(dense_state_dict)

    for key, reconstructed_tensor in reconstructed_dense.items():
        dense_tensor = dense_state_dict[key]
        assert dense_tensor.dtype == reconstructed_tensor.dtype == dtype, key
        assert torch.equal(dense_tensor, reconstructed_tensor), key

    shutil.rmtree(tmp_path)