File: delayed_mul_tensor.py

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from . import _Tensor, Tensor
from .reference import _dims, _enable_layers, llist, ltuple


class DelayedMulTensor(_Tensor):
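    """Lazy product of two first-class-dim tensors.

    Multiplying two dim-bound tensors does not materialize the result;
    this wrapper records both operands so that a following reduction
    (see ``sum``) can be fused into a single ``torch.einsum`` contraction.
    Anything else falls back to actually computing the batched product
    (see ``_batchtensor``).
    """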
    def __init__(self, lhs, rhs):
        self._lhs, self._rhs = lhs, rhs
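        # the caches below are filled lazily by the corresponding properties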
        self._data = None
        self._levels_data = None
        self._has_device = lhs._has_device or rhs._has_device
        self._batchtensor_data = None
        self._tensor_data = None

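    # ordered union of the operands' levels (lhs levels first), built lazily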
    @property
    def _levels(self):
        if self._levels_data is None:
            levels = llist(self._lhs._levels)
            for l in self._rhs._levels:
                if l not in levels:
                    levels.append(l)
            self._levels_data = ltuple(levels)
        return self._levels_data

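    # fallback path: materialize the delayed product as a plain batched
    # tensor for any operation not handled by the fused ``sum``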
    @property
    def _batchtensor(self):
        if self._batchtensor_data is None:
            with _enable_layers(self._levels):
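                # debug trace: the delayed multiply could not be fused away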
                print("bt multiply fallback")
                self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor
        return self._batchtensor_data

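    # positional-tensor form of the materialized product, cached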
    @property
    def _tensor(self):
        if self._tensor_data is None:
            self._tensor_data = Tensor.from_batched(
                self._batchtensor, self._has_device
            )._tensor
        return self._tensor_data

    @property
    def ndim(self):
        return self._batchtensor.ndim

    @property
    def dims(self):
        return ltuple(super().dims)

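    # the fused path: summing a delayed multiply lowers to a single einsum
    # contraction, so the elementwise product is never materialized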
    def sum(self, dim):
        dims = _dims(dim, 0, False, False)
        n = ord("a")
        all_levels = self._levels

        def to_char(d):
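            # each level gets a distinct lowercase einsum subscript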
            return chr(n + all_levels.index(d))

        plhs, levelslhs = self._lhs._tensor, self._lhs._levels
        prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
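        # levels being summed over are contracted away; the rest survive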
        new_levels = [l for l in self._levels if l not in dims]
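        # e.g. lhs levels (a, b), rhs levels (b, c), sum over b -> "ab,bc->ac"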
        fmt = "".join(
            [
                *(to_char(d) for d in levelslhs),
                ",",
                *(to_char(d) for d in levelsrhs),
                "->",
                *(to_char(d) for d in new_levels),
            ]
        )
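        # one contraction computes multiply-and-sum without the intermediate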
        result_data = torch.einsum(fmt, (plhs, prhs))
        return Tensor.from_positional(result_data, new_levels, True)
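

# A minimal usage sketch (illustrative only; assumes the public
# functorch.dim API). Multiplying two tensors whose dims are all bound
# can yield a DelayedMulTensor, and the following .sum over the shared
# dim is fused into a single einsum instead of materializing the product:
#
#     import torch
#     from functorch.dim import dims
#
#     i, j, k = dims(3)
#     a, b = torch.rand(3, 4), torch.rand(4, 5)
#     prod = a[i, k] * b[k, j]   # delayed: no elementwise product yet
#     c = prod.sum(k)            # one einsum, "ik,kj->ij" up to letters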