File: delayed_mul_tensor.py

Source: pytorch 1.13.1+dfsg-4 (Debian bookworm)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import _Tensor, Tensor
from .reference import _dims, _enable_layers, llist, ltuple

class DelayedMulTensor(_Tensor):
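    """A lazy elementwise product of two first-class-dim tensors.

    The multiply is not materialized up front: it is either fused with a
    following ``sum`` into a single ``torch.einsum`` call, or evaluated as
    a plain batched multiply when the dense result is actually needed.
    """
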
    def __init__(self, lhs, rhs):
        self._lhs, self._rhs = lhs, rhs
        self._data = None
        self._levels_data = None
        self._has_device = lhs._has_device or rhs._has_device
        self._batchtensor_data = None
        self._tensor_data = None

    @property
    def _levels(self):
        if self._levels_data is None:
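            # Union of the operands' levels: lhs levels first, then any
            # rhs-only levels, preserving order of first appearance.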
            levels = llist(self._lhs._levels)
            for l in self._rhs._levels:
                if l not in levels:
                    levels.append(l)
            self._levels_data = ltuple(levels)
        return self._levels_data

    @property
    def _batchtensor(self):
        if self._batchtensor_data is None:
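            # No reduction consumed the delayed multiply, so fall back to
            # actually materializing the product as a batched multiply.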
            with _enable_layers(self._levels):
                print("bt multiply fallback")
                self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor
        return self._batchtensor_data

    @property
    def _tensor(self):
        if self._tensor_data is None:
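            # Unwrap the materialized batched product into a plain
            # positional tensor and cache it.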
            self._tensor_data = Tensor.from_batched(self._batchtensor, self._has_device)._tensor
        return self._tensor_data

    @property
    def ndim(self):
        return self._batchtensor.ndim

    @property
    def dims(self):
        return ltuple(super().dims)
    def sum(self, dim):
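        # Instead of materializing lhs * rhs and then reducing, fuse both
        # steps into a single einsum over the original operands.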
        dims = _dims(dim, 0, False, False)
        n = ord('a')
        all_levels = self._levels

        def to_char(d):
            return chr(n + all_levels.index(d))
        plhs, levelslhs = self._lhs._tensor, self._lhs._levels
        prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
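        # Dims and levels that are not reduced over form the einsum output.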
        new_dims = tuple(d for d in self.dims if d not in dims)
        new_levels = [l for l in self._levels if l not in dims]
        fmt = ''.join([*(to_char(d) for d in levelslhs), ',',
                       *(to_char(d) for d in levelsrhs), '->',
                       *(to_char(d) for d in new_levels)])
        result_data = torch.einsum(fmt, (plhs, prhs))
        return Tensor.from_positional(result_data, new_levels, self._has_device)
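
For context, a minimal sketch of how this class gets exercised through the
public first-class-dims front end (assuming functorch 1.13's `functorch.dim`
API, where multiplying two dim tensors yields a DelayedMulTensor so the
trailing `sum` lowers to the single einsum built above; the dim names i, j, k
are illustrative):

import torch
from functorch.dim import dims

i, j, k = dims(3)
A = torch.randn(4, 5)
B = torch.randn(5, 6)

# A[i, k] * B[k, j] stays lazy; summing over k fuses multiply + reduce
# into one matmul-shaped einsum (e.g. "ab,bc->ac") instead of
# materializing the full i x k x j intermediate product.
C = (A[i, k] * B[k, j]).sum(k)
print(C.order(i, j).shape)  # torch.Size([4, 6])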