1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
|
# mypy: allow-untyped-defs
from dataclasses import dataclass, field
from enum import Enum
from typing import List
import torch
from torch.distributed._shard.metadata import ShardMetadata
class MEM_FORMAT_ENCODING(Enum):
    """Picklable stand-ins for the ``torch.memory_format`` values this module supports.

    ``torch.memory_format`` objects cannot be pickled, so ``TensorProperties``
    stores one of these members in its pickled state instead. ``Enum`` members
    pickle by value, so keep the integer values stable.
    """

    # Stands in for torch.contiguous_format.
    TORCH_CONTIGUOUS_FORMAT = 0
    # Stands in for torch.channels_last.
    TORCH_CHANNELS_LAST = 1
    # Stands in for torch.preserve_format.
    TORCH_PRESERVE_FORMAT = 2
@dataclass
class TensorProperties:
    """Properties used to create :class:`Tensor`.

    Instances can be pickled even though ``torch.memory_format`` objects
    cannot: ``__getstate__`` swaps ``memory_format`` for the matching
    ``MEM_FORMAT_ENCODING`` member and ``__setstate__`` swaps it back. Only
    ``torch.contiguous_format``, ``torch.channels_last`` and
    ``torch.preserve_format`` can be encoded.
    """

    # Regular tensor fields
    # ``default_factory`` (not ``default``) so the dtype is read from
    # ``torch.get_default_dtype()`` when an instance is created, rather than
    # being frozen at import time and ignoring later ``set_default_dtype`` calls.
    dtype: torch.dtype = field(default_factory=torch.get_default_dtype)
    layout: torch.layout = field(default=torch.strided)
    requires_grad: bool = False
    memory_format: torch.memory_format = field(default=torch.contiguous_format)
    pin_memory: bool = False

    @staticmethod
    def _memory_format_encodings():
        """Return the ``(torch.memory_format, MEM_FORMAT_ENCODING)`` pairs.

        This single table is shared by ``__getstate__`` and ``__setstate__`` so
        the encode and decode directions cannot drift apart. Entries are
        matched with ``==``.
        """
        return (
            (torch.contiguous_format, MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT),
            (torch.channels_last, MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST),
            (torch.preserve_format, MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT),
        )

    def __getstate__(self):
        """Return the picklable state as a 5-tuple.

        The tuple is ``(dtype, layout, requires_grad, encoded_memory_format,
        pin_memory)``. ``memory_format`` is replaced by its
        ``MEM_FORMAT_ENCODING`` member because ``torch.memory_format`` cannot
        be pickled.

        Raises:
            RuntimeError: if ``memory_format`` has no encoding.
        """
        for fmt, encoding in self._memory_format_encodings():
            if self.memory_format == fmt:
                mem_format_encoding = encoding
                break
        else:
            raise RuntimeError(f"Invalid torch.memory_format: {self.memory_format}")
        return (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        )

    def __setstate__(
        self,
        state,
    ):
        """Restore fields from a 5-tuple produced by ``__getstate__``.

        Raises:
            RuntimeError: if the encoded memory format is not recognized.
        """
        (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        ) = state
        for fmt, encoding in self._memory_format_encodings():
            if mem_format_encoding == encoding:
                self.memory_format = fmt
                return
        raise RuntimeError(
            f"Invalid torch.memory_format encoding: {mem_format_encoding}"
        )

    @staticmethod
    def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
        """Capture the creation properties of an existing ``tensor``.

        ``dtype``, ``layout``, ``requires_grad`` and pinned-ness are copied
        from the tensor. ``memory_format`` is always recorded as
        ``torch.contiguous_format``; the tensor's actual strides are not
        inspected.
        """
        return TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        )
@dataclass
class ShardedTensorMetadata:
    """
    Represents metadata for :class:`ShardedTensor`.

    Bundles the per-shard metadata with the overall size and creation
    properties of the full tensor. Every field has a default, so
    ``ShardedTensorMetadata()`` builds an empty instance.
    """

    # One ShardMetadata entry per shard of the tensor; each instance gets
    # its own fresh list.
    shards_metadata: List[ShardMetadata] = field(default_factory=list)

    # Size of each dimension of the overall (unsharded) tensor. torch.Size is
    # an immutable tuple subclass, so one shared default instance is safe.
    size: torch.Size = torch.Size([])

    # Properties (dtype, layout, ...) used to create the tensor; each instance
    # gets its own freshly constructed TensorProperties.
    tensor_properties: TensorProperties = field(default_factory=TensorProperties)
|