File: test_pruning_op.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (78 lines) | stat: -rw-r--r-- 3,625 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Owner(s): ["module: unknown"]

import hypothesis.strategies as st
from hypothesis import given
import numpy as np
import torch
from torch.testing._internal.common_utils import TestCase
import torch.testing._internal.hypothesis_utils as hu
# NOTE(review): module-level guard from hypothesis_utils; presumably it checks
# that hypothesis' per-example deadline is disabled for these tests — confirm
# against torch.testing._internal.hypothesis_utils.
hu.assert_deadline_disabled()


class PruningOpTest(TestCase):
    """Compare ``torch._rowwise_prune`` against a pure-Python reference.

    Each test draws a random weight matrix and a random boolean row mask,
    runs the op, and checks both of its outputs (the pruned weights and the
    compressed index map) against the reference implementation below.
    """

    # Weight dtypes exercised by both test variants. They are shared so the
    # two ``@given`` decorators below cannot drift apart.
    _WEIGHTS_DTYPES = (torch.float64, torch.float32, torch.float16,
                       torch.int8, torch.int16, torch.int32, torch.int64)

    def _generate_rowwise_mask(self, embedding_rows):
        """Return a random boolean keep-mask with one entry per weight row.

        Each row gets an "indicator" value that represents its importance,
        and a single threshold is drawn as well (both via
        ``np.random.random_sample``). Rows whose indicator is below the
        threshold are pruned (``False``); all other rows are kept (``True``).

        Args:
            embedding_rows: number of rows in the weight matrix.

        Returns:
            A 1-D ``torch.bool`` tensor of length ``embedding_rows``.
        """
        indicator = torch.from_numpy(
            np.random.random_sample(embedding_rows).astype(np.float32))
        threshold = float(np.random.random_sample())
        # One element-wise comparison yields the bool mask directly.
        return indicator >= threshold

    def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype):
        """Run ``torch._rowwise_prune`` on random inputs and check its outputs.

        Args:
            embedding_rows: number of rows in the weight matrix.
            embedding_dims: number of columns in the weight matrix.
            indices_type: dtype requested for the compressed index map.
            weights_dtype: dtype of the weight matrix.
        """
        shape = (embedding_rows, embedding_dims)
        if weights_dtype.is_floating_point:
            embedding_weights = torch.rand(shape, dtype=weights_dtype)
        else:
            # Values in [0, 100) fit every integer dtype tested, including int8.
            embedding_weights = torch.randint(0, 100, shape, dtype=weights_dtype)
        mask = self._generate_rowwise_mask(embedding_rows)

        def get_reference_result(embedding_weights, mask, indices_type):
            """Pure-Python model of the op.

            Returns the rows of ``embedding_weights`` selected by ``mask`` and,
            for every original row, its position among the kept rows, or -1
            if the row was pruned.
            """
            compressed_indices = []
            next_idx = 0
            for keep in mask.tolist():
                if keep:
                    compressed_indices.append(next_idx)
                    next_idx += 1
                else:
                    compressed_indices.append(-1)
            return (embedding_weights[mask],
                    torch.tensor(compressed_indices, dtype=indices_type))

        pt_pruned_weights, pt_compressed_indices_map = torch._rowwise_prune(
            embedding_weights, mask, indices_type)
        ref_pruned_weights, ref_compressed_indices_map = get_reference_result(
            embedding_weights, mask, indices_type)

        torch.testing.assert_close(pt_pruned_weights, ref_pruned_weights)
        self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map)
        self.assertEqual(pt_compressed_indices_map.dtype, indices_type)

    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=st.sampled_from(_WEIGHTS_DTYPES),
    )
    def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int32, weights_dtype)

    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=st.sampled_from(_WEIGHTS_DTYPES),
    )
    def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype)