File: metadata.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (67 lines) | stat: -rw-r--r-- 2,194 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from dataclasses import dataclass, field
from typing import Dict, List, Union, Optional, Sequence, Any
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties

import torch
from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
)

@dataclass
class ChunkStorageMetadata:
    """
    Geometry of one stored chunk of a tensor.

    A chunk records only where it sits and how large it is; every other
    property is expected to match the TensorStorageMetadata whose ``chunks``
    list contains it.
    """

    # Per-dimension offset of this chunk (presumably its starting coordinate
    # inside the full tensor — confirm against writers of this metadata).
    offsets: torch.Size
    # Per-dimension extent of this chunk.
    sizes: torch.Size

@dataclass
class TensorStorageMetadata:
    """Storage description of a tensor entry that may be split into chunks."""

    # Properties shared by the tensor; each chunk is expected to match them.
    properties: TensorProperties
    # Overall tensor size (presumably the full, unsplit shape — confirm).
    size: torch.Size
    # The stored pieces of the tensor, one ChunkStorageMetadata per piece.
    chunks: List[ChunkStorageMetadata]

@dataclass
class BytesStorageMetadata:
    """
    Field-less marker for state dict entries stored as bytes.

    It is the non-tensor alternative to TensorStorageMetadata; the type
    itself is the only information it carries.
    """

# Tensor kinds accepted as state dict values: plain or sharded tensors.
TENSOR_TYPE = Union[torch.Tensor, ShardedTensor]
# Every possible per-key entry of Metadata.state_dict_metadata.
STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
# A state dict: string keys mapped to arbitrary values.
STATE_DICT_TYPE = Dict[str, Any]

@dataclass
class Metadata:
    """
    Top-level record holding one storage description per state_dict key,
    plus two optional opaque payloads.
    """

    # Keyed exactly like the `state_dict` this metadata was produced from.
    state_dict_metadata: Dict[str, STORAGE_TYPES]
    # Opaque payload, not interpreted here (presumably owned by a planner).
    planner_data: Any = None
    # Opaque payload, not interpreted here (presumably owned by a storage layer).
    storage_data: Any = None

@dataclass(frozen=True)
class MetadataIndex:
    """
    This class represents a lookup key for items in a state dict or Metadata.

    Equality and hashing use ``fqn`` and ``offset`` only; ``index`` is a pure
    lookup hint and never affects which item a key identifies.
    """
    fqn: str
    """Fully Qualified Name of the object"""

    offset: Optional[torch.Size] = None
    """If the object is a tensor, offset into the tensor we're looking for"""

    index: Optional[int] = field(hash=False, compare=False, default=None)
    """
    Index hint when searching for tensor chunk to speedup lookups (optional)

    A common representation of a sharded tensor is as a list of chunks so to
    find the index in such a list you need to linear search it.

    When constructing an instance of MetadataIndex that points to that list,
    one can provide the index as a hint and it will be probed first before
    the linear search and thus making it significantly faster.
    """

    def __init__(
        self,
        fqn: str,
        offset: Optional[Sequence[int]] = None,
        index: Optional[int] = None,
    ):
        """
        Build a lookup key.

        Args:
            fqn: Fully qualified name of the state dict entry.
            offset: Optional offset into a tensor. Any sequence of ints is
                accepted and normalized to ``torch.Size``; because ``offset``
                participates in ``__hash__``, this also keeps keys hashable
                when callers pass a list.
            index: Optional position hint into a chunk list; excluded from
                equality and hashing.
        """
        # The dataclass is frozen, so its generated __setattr__ raises;
        # construction must go through object.__setattr__ instead.
        object.__setattr__(self, "fqn", fqn)
        # Always store offset on the instance (even when None) rather than
        # silently falling back to the class-level default @dataclass leaves
        # behind, so every instance carries the same set of attributes.
        object.__setattr__(
            self, "offset", None if offset is None else torch.Size(offset)
        )
        object.__setattr__(self, "index", index)