File: index.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (826 lines) | stat: -rw-r--r-- 24,054 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
import functools
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Type,
    Union,
)

import torch
import torch.utils._pytree as pytree
from torch import Tensor

from torch_geometric.typing import INDEX_DTYPES

# Shorthand for the ATen operator namespace, used below to name specific
# operator overloads (e.g. `aten.cat.default`, `aten.slice.Tensor`):
aten = torch.ops.aten

# Registry mapping ATen operator overloads to their `Index`-aware
# implementations. It is filled by the `@implements(...)` decorator and
# consulted first by `Index.__torch_dispatch__`:
HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}


def ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:
    r"""Expands a compressed pointer vector into an uncompressed index.

    Entry :obj:`i` is repeated :obj:`ptr[i + 1] - ptr[i]` times, so the
    output is the COO-style index described by :obj:`ptr`.

    Args:
        ptr (torch.Tensor): The compressed (CSR-style) pointer vector.
        output_size (int, optional): The total number of output entries, if
            known ahead of time (forwarded to :meth:`repeat_interleave`).
    """
    num_segments = ptr.numel() - 1
    segment_ids = torch.arange(num_segments, dtype=ptr.dtype,
                               device=ptr.device)
    segment_lengths = ptr.diff()
    return segment_ids.repeat_interleave(segment_lengths,
                                         output_size=output_size)


def index2ptr(index: Tensor, size: Optional[int] = None) -> Tensor:
    r"""Compresses a sorted index vector into a CSR-style pointer vector.

    Args:
        index (torch.Tensor): The (sorted) index vector.
        size (int, optional): The number of segments. If not given, it is
            inferred as :obj:`index.max() + 1` (or :obj:`0` if empty).

    Returns:
        A pointer vector of length :obj:`size + 1`. It is :obj:`int32`
        unless :obj:`index` is :obj:`int64`.
    """
    if size is None:
        size = 0 if index.numel() == 0 else int(index.max()) + 1

    use_int32 = index.dtype != torch.int64
    return torch._convert_indices_from_coo_to_csr(index, size,
                                                  out_int32=use_int32)


# Per-input bookkeeping recorded when `Index` objects are concatenated:
# the number of entries (`nnz`), the registered `dim_size` and the
# `is_sorted` flag of every concatenated input, in input order.
CatMetadata = NamedTuple('CatMetadata', [
    ('nnz', List[int]),
    ('dim_size', List[Optional[int]]),
    ('is_sorted', List[bool]),
])


def implements(torch_function: Callable) -> Callable:
    r"""Registers a :pytorch:`PyTorch` function override.

    Returns a decorator that stores the decorated function in
    :obj:`HANDLED_FUNCTIONS` under :obj:`torch_function` and returns it
    unchanged.
    """
    def register(my_function: Callable) -> Callable:
        HANDLED_FUNCTIONS[torch_function] = my_function
        return my_function

    # Mirror the metadata of the overridden function onto the decorator:
    return functools.wraps(torch_function)(register)


def assert_valid_dtype(tensor: Tensor) -> None:
    r"""Raises a :class:`ValueError` unless :obj:`tensor` has one of the
    supported index data types in :obj:`INDEX_DTYPES`.
    """
    dtype = tensor.dtype
    if dtype in INDEX_DTYPES:
        return

    raise ValueError(f"'Index' holds an unsupported data type "
                     f"(got '{dtype}', but expected one of "
                     f"{INDEX_DTYPES})")


def assert_one_dimensional(tensor: Tensor) -> None:
    r"""Raises a :class:`ValueError` unless :obj:`tensor` is 1-dimensional."""
    num_dims = tensor.dim()
    if num_dims == 1:
        return

    raise ValueError(f"'Index' needs to be one-dimensional "
                     f"(got {num_dims} dimensions)")


def assert_contiguous(tensor: Tensor) -> None:
    r"""Raises a :class:`ValueError` unless :obj:`tensor` is contiguous."""
    if tensor.is_contiguous():
        return

    raise ValueError("'Index' needs to be contiguous. Please call "
                     "`index.contiguous()` before proceeding.")


def assert_sorted(func: Callable) -> Callable:
    r"""Decorates a method so that it raises a :class:`ValueError` when
    called on an instance whose :obj:`is_sorted` flag is not set.
    """
    @functools.wraps(func)
    def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:
        if self.is_sorted:
            return func(self, *args, **kwargs)

        name = self.__class__.__name__
        raise ValueError(
            f"Cannot call '{func.__name__}' since '{name}' is not "
            f"sorted. Please call `{name}.sort()` first.")

    return wrapper


class Index(Tensor):
    r"""A one-dimensional :obj:`index` tensor with additional (meta)data
    attached.

    :class:`Index` is a :pytorch:`null` :class:`torch.Tensor` that holds
    indices of shape :obj:`[num_indices]`.

    While :class:`Index` sub-classes a general :pytorch:`null`
    :class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:

    * :obj:`dim_size`: The size of the underlying sparse vector, *i.e.*,
      the size of a dimension that can be indexed via :obj:`index`.
      By default, it is inferred as :obj:`dim_size=index.max() + 1`.
    * :obj:`is_sorted`: Whether indices are sorted in ascending order.

    Additionally, :class:`Index` caches data via :obj:`indptr` for fast CSR
    conversion in case its representation is sorted.
    Caches are filled based on demand (*e.g.*, when calling
    :meth:`Index.get_indptr`), or when explicitly requested via
    :meth:`Index.fill_cache_`, and are maintained and adjusted over its
    lifespan.

    This representation ensures optimal computation in GNN message passing
    schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
    workflows.

    .. code-block:: python

        from torch_geometric import Index

        index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
        >>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
        assert index.dim_size == 3
        assert index.is_sorted

        # Flipping order:
        flipped = index.flip(0)
        >>> Index([2, 1, 1, 0], dim_size=3)
        assert not flipped.is_sorted

        # Filtering:
        mask = torch.tensor([True, True, True, False])
        filtered = index[mask]
        >>> Index([0, 1, 1], dim_size=3, is_sorted=True)
        assert filtered.is_sorted
    """
    # See "https://pytorch.org/docs/stable/notes/extending.html"
    # for a basic tutorial on how to subclass `torch.Tensor`.

    # The underlying tensor representation:
    _data: Tensor

    # The size of the underlying sparse vector, e.g. `_data.max() + 1` :
    _dim_size: Optional[int] = None

    # Whether the `index` representation is sorted:
    _is_sorted: bool = False

    # A cache for its compressed representation:
    _indptr: Optional[Tensor] = None

    # Whenever we perform a concatenation of indices, we cache the original
    # metadata to be able to reconstruct individual indices:
    _cat_metadata: Optional[CatMetadata] = None

    @staticmethod
    def __new__(
        cls: Type,
        data: Any,
        *args: Any,
        dim_size: Optional[int] = None,
        is_sorted: bool = False,
        **kwargs: Any,
    ) -> 'Index':
        r"""Creates a new :class:`Index`.

        Non-tensor :obj:`data` is converted via :meth:`torch.tensor`, with
        :obj:`*args` and :obj:`**kwargs` forwarded. If :obj:`data` is
        already an :class:`Index`, its metadata and :obj:`indptr` cache are
        inherited unless explicitly overridden.

        Raises:
            TypeError: If extra positional or keyword arguments are passed
                alongside a tensor.
            ValueError: If :obj:`data` has an unsupported data type, is not
                one-dimensional, or is not contiguous.
        """
        if not isinstance(data, Tensor):
            data = torch.tensor(data, *args, **kwargs)
        elif len(args) > 0:
            raise TypeError(
                f"new() received an invalid combination of arguments - got "
                f"(Tensor, {', '.join(str(type(arg)) for arg in args)})")
        elif len(kwargs) > 0:
            raise TypeError(f"new() received invalid keyword arguments - got "
                            f"{set(kwargs.keys())}")

        assert isinstance(data, Tensor)

        indptr: Optional[Tensor] = None

        if isinstance(data, cls):  # If passed `Index`, inherit metadata:
            indptr = data._indptr
            # Compare against `None` so that an explicit `dim_size=0` is not
            # silently replaced by the inherited size:
            if dim_size is None:
                dim_size = data.dim_size
            is_sorted = is_sorted or data.is_sorted

        assert_valid_dtype(data)
        assert_one_dimensional(data)
        assert_contiguous(data)

        out = Tensor._make_wrapper_subclass(  # type: ignore
            cls,
            size=data.size(),
            strides=data.stride(),
            dtype=data.dtype,
            device=data.device,
            layout=data.layout,
            requires_grad=False,
        )
        assert isinstance(out, Index)

        # Attach metadata:
        out._data = data
        out._dim_size = dim_size
        out._is_sorted = is_sorted
        out._indptr = indptr

        if isinstance(data, cls):
            # Never nest wrappers; always hold the plain underlying tensor:
            out._data = data._data

            # Reset metadata if cache is invalidated:
            if dim_size is not None and dim_size != data.dim_size:
                out._indptr = None

        return out

    # Validation ##############################################################

    def validate(self) -> 'Index':
        r"""Validates the :class:`Index` representation.

        In particular, it ensures that

        * it only holds valid indices.
        * the sort order is correctly set.

        Raises:
            ValueError: If the data type, shape or memory layout is invalid,
                if any index is negative or not smaller than
                :obj:`dim_size`, or if :obj:`is_sorted` is set although the
                indices are not in ascending order.
        """
        assert_valid_dtype(self._data)
        assert_one_dimensional(self._data)
        assert_contiguous(self._data)

        if self.numel() > 0 and self._data.min() < 0:
            raise ValueError(f"'{self.__class__.__name__}' contains negative "
                             f"indices (got {int(self._data.min())})")

        if (self.numel() > 0 and self.dim_size is not None
                and self._data.max() >= self.dim_size):
            raise ValueError(f"'{self.__class__.__name__}' contains larger "
                             f"indices than its registered size "
                             f"(got {int(self._data.max())}, but expected "
                             f"values smaller than {self.dim_size})")

        if self.is_sorted and (self._data.diff() < 0).any():
            raise ValueError(f"'{self.__class__.__name__}' is not sorted")

        return self

    # Properties ##############################################################

    @property
    def dim_size(self) -> Optional[int]:
        r"""The size of the underlying sparse vector."""
        return self._dim_size

    @property
    def is_sorted(self) -> bool:
        r"""Returns whether indices are sorted in ascending order."""
        return self._is_sorted

    @property
    def dtype(self) -> torch.dtype:  # type: ignore
        # TODO Remove once PyTorch does not override `dtype` in `DataLoader`.
        return self._data.dtype

    # Cache Interface #########################################################

    def get_dim_size(self) -> int:
        r"""The size of the underlying sparse vector.
        Automatically computed and cached when not explicitly set.
        """
        if self._dim_size is None:
            dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0
            self._dim_size = dim_size

        assert isinstance(self._dim_size, int)
        return self._dim_size

    def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
        r"""Assigns or re-assigns the size of the underlying sparse vector.

        If a cached :obj:`indptr` exists, it is dropped (for
        :obj:`dim_size=None`), truncated, or padded with its last value so
        that it matches the new size.
        """
        if self.is_sorted and self._indptr is not None:
            if dim_size is None:
                self._indptr = None

            elif self._indptr.numel() - 1 >= dim_size:
                self._indptr = self._indptr[:dim_size + 1]

            else:
                fill_value = self._indptr.new_full(
                    (dim_size - self._indptr.numel() + 1, ),
                    fill_value=self._indptr[-1],  # type: ignore
                )
                self._indptr = torch.cat([self._indptr, fill_value], dim=0)

        self._dim_size = dim_size

        return self

    @assert_sorted
    def get_indptr(self) -> Tensor:
        r"""Returns the compressed index representation in case :class:`Index`
        is sorted.
        """
        if self._indptr is None:
            self._indptr = index2ptr(self._data, self.get_dim_size())

        assert isinstance(self._indptr, Tensor)
        return self._indptr

    def fill_cache_(self) -> 'Index':
        r"""Fills the cache with (meta)data information."""
        self.get_dim_size()

        if self.is_sorted:
            self.get_indptr()

        return self

    # Methods #################################################################

    def share_memory_(self) -> 'Index':
        """"""  # noqa: D419
        self._data.share_memory_()
        if self._indptr is not None:
            self._indptr.share_memory_()
        return self

    def is_shared(self) -> bool:
        """"""  # noqa: D419
        return self._data.is_shared()

    def as_tensor(self) -> Tensor:
        r"""Zero-copies the :class:`Index` representation back to a
        :class:`torch.Tensor` representation.
        """
        return self._data

    # PyTorch/Python builtins #################################################

    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
        # Inner tensors are listed by attribute name; plain metadata travels
        # in the context tuple (consumed by `__tensor_unflatten__`):
        attrs = ['_data']
        if self._indptr is not None:
            attrs.append('_indptr')

        ctx = (
            self._dim_size,
            self._is_sorted,
            self._cat_metadata,
        )

        return attrs, ctx

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: Dict[str, Any],
        ctx: Tuple[Any, ...],
        outer_size: Tuple[int, ...],
        outer_stride: Tuple[int, ...],
    ) -> 'Index':
        index = Index(
            inner_tensors['_data'],
            dim_size=ctx[0],
            is_sorted=ctx[1],
        )

        index._indptr = inner_tensors.get('_indptr', None)
        index._cat_metadata = ctx[2]

        return index

    # Prevent auto-wrapping outputs back into the proper subclass type:
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(
        cls: Type,
        func: Callable[..., Any],
        types: Iterable[Type[Any]],
        args: Iterable[Tuple[Any, ...]] = (),
        kwargs: Optional[Dict[Any, Any]] = None,
    ) -> Any:
        # `Index` should be treated as a regular PyTorch tensor for all
        # standard PyTorch functionalities. However,
        # * some of its metadata can be transferred to new functions, e.g.,
        #   `torch.narrow()` can inherit the `is_sorted` property.
        # * not all operations lead to valid `Index` tensors again, e.g.,
        #   `torch.sum()` does not yield a `Index` as its output, or
        #   `torch.stack() violates the [*] shape assumption.

        # To account for this, we hold a number of `HANDLED_FUNCTIONS` that
        # implement specific functions for valid `Index` routines.
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))

        # For all other PyTorch functions, we treat them as vanilla tensors.
        args = pytree.tree_map_only(Index, lambda x: x._data, args)
        if kwargs is not None:
            kwargs = pytree.tree_map_only(Index, lambda x: x._data, kwargs)
        return func(*args, **(kwargs or {}))

    def __repr__(self) -> str:  # type: ignore
        prefix = f'{self.__class__.__name__}('
        indent = len(prefix)
        tensor_str = torch._tensor_str._tensor_str(self._data, indent)

        suffixes = []
        if self.dim_size is not None:
            suffixes.append(f'dim_size={self.dim_size}')
        if (self.device.type != torch._C._get_default_device()
                or (self.device.type == 'cuda'
                    and torch.cuda.current_device() != self.device.index)
                or (self.device.type == 'mps')):
            suffixes.append(f"device='{self.device}'")
        if self.dtype != torch.int64:
            suffixes.append(f'dtype={self.dtype}')
        if self.is_sorted:
            suffixes.append('is_sorted=True')

        return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
                                               indent, force_newline=False)

    # Helpers #################################################################

    def _shallow_copy(self) -> 'Index':
        # New wrapper sharing the same data, metadata and caches:
        out = Index(self._data)
        out._dim_size = self._dim_size
        out._is_sorted = self._is_sorted
        out._indptr = self._indptr
        out._cat_metadata = self._cat_metadata
        return out

    def _clear_metadata(self) -> 'Index':
        # Reset all metadata and caches to their "unknown" defaults:
        self._dim_size = None
        self._is_sorted = False
        self._indptr = None
        self._cat_metadata = None
        return self


def apply_(
    tensor: Index,
    fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> Union[Index, Tensor]:
    r"""Applies :obj:`fn` to the underlying data of :obj:`tensor` (and to its
    cached :obj:`indptr`, if any), carrying the metadata over to the result.

    If :obj:`fn` returns a non-index data type (*e.g.*, a floating point
    conversion), the plain :class:`torch.Tensor` is returned instead.
    If :obj:`fn` operated in-place, :obj:`tensor` itself is updated and
    returned; otherwise, a new :class:`Index` is created.
    """
    data = fn(tensor._data, *args, **kwargs)

    if data.dtype not in INDEX_DTYPES:
        return data

    # Detect in-place operations. `data_ptr()` is `0` for every empty tensor,
    # so pointer equality alone would treat e.g. a clone of an empty index as
    # in-place and mutate the original. Only trust pointer equality for
    # non-empty tensors on the same device:
    in_place = data is tensor._data or (
        data.numel() > 0 and data.device == tensor._data.device
        and data.data_ptr() == tensor._data.data_ptr())

    if in_place:
        tensor._data = data
        out = tensor
    else:
        out = Index(data)

    # Copy metadata:
    out._dim_size = tensor._dim_size
    out._is_sorted = tensor._is_sorted
    out._cat_metadata = tensor._cat_metadata

    # Convert cache:
    if tensor._indptr is not None:
        out._indptr = fn(tensor._indptr, *args, **kwargs)

    return out


@implements(aten.clone.default)
def _clone(
    tensor: Index,
    *,
    memory_format: torch.memory_format = torch.preserve_format,
) -> Index:
    r"""Clones the data of an :class:`Index` via :func:`apply_`, carrying
    over its metadata and converting its cached :obj:`indptr`.
    """
    cloned = apply_(tensor, aten.clone.default, memory_format=memory_format)
    # Cloning never changes the dtype, so the result stays an `Index`:
    assert isinstance(cloned, Index)
    return cloned


@implements(aten._to_copy.default)
def _to_copy(
    tensor: Index,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    non_blocking: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> Union[Index, Tensor]:
    r"""Implements :meth:`Index.to` and friends by copying the underlying
    data (and cached :obj:`indptr`) via :func:`apply_`.
    """
    options = dict(
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        non_blocking=non_blocking,
        memory_format=memory_format,
    )
    # Non-index target dtypes yield a plain `Tensor` (handled in `apply_`):
    return apply_(tensor, aten._to_copy.default, **options)


@implements(aten.alias.default)
def _alias(tensor: Index) -> Index:
    r"""Returns a new :class:`Index` wrapper that shares the data, metadata
    and caches of :obj:`tensor`.
    """
    alias = tensor._shallow_copy()
    return alias


@implements(aten._pin_memory.default)
def _pin_memory(tensor: Index) -> Index:
    r"""Pins the data (and cached :obj:`indptr`) of an :class:`Index` via
    :func:`apply_`, carrying over its metadata.
    """
    pinned = apply_(tensor, aten._pin_memory.default)
    assert isinstance(pinned, Index)
    return pinned


@implements(aten.sort.default)
def _sort(
    tensor: Index,
    dim: int = -1,
    descending: bool = False,
) -> Tuple[Index, Tensor]:
    r"""Sorts an :class:`Index`.

    If it is already known to be sorted in ascending order and an ascending
    sort is requested, :obj:`tensor` is returned as-is together with an
    identity permutation.
    """
    if not descending and tensor.is_sorted:
        identity = torch.arange(tensor._data.numel(),
                                device=tensor._data.device)
        return tensor, identity

    values, perm = aten.sort.default(tensor._data, dim, descending)
    # Only an ascending sort establishes the `is_sorted` invariant:
    out = Index(values, dim_size=tensor._dim_size, is_sorted=not descending)
    return out, perm


@implements(aten.sort.stable)
def _sort_stable(
    tensor: Index,
    *,
    stable: bool = False,
    dim: int = -1,
    descending: bool = False,
) -> Tuple[Index, Tensor]:
    r"""Sorts an :class:`Index` (optionally stably).

    If it is already known to be sorted in ascending order and an ascending
    sort is requested, :obj:`tensor` is returned as-is together with an
    identity permutation.
    """
    if not descending and tensor.is_sorted:
        identity = torch.arange(tensor._data.numel(),
                                device=tensor._data.device)
        return tensor, identity

    values, perm = aten.sort.stable(tensor._data, stable=stable, dim=dim,
                                    descending=descending)
    # Only an ascending sort establishes the `is_sorted` invariant:
    out = Index(values, dim_size=tensor._dim_size, is_sorted=not descending)
    return out, perm


@implements(aten.cat.default)
def _cat(
    tensors: List[Union[Index, Tensor]],
    dim: int = 0,
) -> Union[Index, Tensor]:
    r"""Concatenates tensors, returning an :class:`Index` only if every
    input is an :class:`Index`.

    The output :obj:`dim_size` is the maximum over all inputs, or
    :obj:`None` if any input has no registered size. The per-input number
    of entries, sizes and sort flags are recorded as
    :class:`CatMetadata`.
    """
    data_list = pytree.tree_map_only(Index, lambda x: x._data, tensors)
    data = aten.cat.default(data_list, dim=dim)

    # Short-circuits on the first non-`Index` input:
    if not all(isinstance(tensor, Index) for tensor in tensors):
        return data

    out = Index(data)

    nnz_list = [t.numel() for t in tensors]
    dim_size_list = [t.dim_size for t in tensors]  # type: ignore
    is_sorted_list = [t.is_sorted for t in tensors]  # type: ignore

    # All inputs index the same dimension, so the largest size wins. The
    # total size is only known if it is known for every input:
    total_dim_size: Optional[int] = None
    if all(dim_size is not None for dim_size in dim_size_list):
        total_dim_size = max([0] + dim_size_list)  # type: ignore

    out._dim_size = total_dim_size

    out._cat_metadata = CatMetadata(
        nnz=nnz_list,
        dim_size=dim_size_list,
        is_sorted=is_sorted_list,
    )

    return out


@implements(aten.flip.default)
def _flip(
    input: Index,
    dims: Union[List[int], Tuple[int, ...]],
) -> Index:
    r"""Reverses an :class:`Index`. The result keeps :obj:`dim_size` but is
    no longer marked as sorted.
    """
    flipped = aten.flip.default(input._data, dims)
    return Index(flipped, dim_size=input.dim_size)


@implements(aten.index_select.default)
def _index_select(
    input: Union[Index, Tensor],
    dim: int,
    index: Union[Index, Tensor],
) -> Union[Index, Tensor]:
    r"""Gathers entries via :meth:`torch.index_select`.

    The result is an :class:`Index` (inheriting :obj:`dim_size`) only if
    :obj:`input` is one; the gathered order is not assumed to be sorted.
    """
    raw_input = input._data if isinstance(input, Index) else input
    raw_index = index._data if isinstance(index, Index) else index
    selected = aten.index_select.default(raw_input, dim, raw_index)

    if not isinstance(input, Index):
        return selected

    return Index(selected, dim_size=input.dim_size)


@implements(aten.slice.Tensor)
def _slice(
    input: Index,
    dim: int,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
) -> Index:
    r"""Slices an :class:`Index`, keeping :obj:`dim_size` and, for positive
    steps, :obj:`is_sorted`.

    Slices covering the full index are short-circuited into a shallow copy,
    which also keeps all caches. A negative :obj:`start` counts from the
    end, so it only covers the full index when :obj:`start <= -size`.
    """
    size = input.size(dim)
    covers_front = start is None or start == 0 or start <= -size
    covers_back = end is None or end >= size
    if covers_front and covers_back and step == 1:
        return input._shallow_copy()  # No-op.

    data = aten.slice.Tensor(input._data, dim, start, end, step)

    if step != 1:
        data = data.contiguous()

    out = Index(data)
    out._dim_size = input.dim_size
    # NOTE We could potentially maintain the `indptr` attribute here,
    # but it is not really clear if this is worth it. The most important
    # information `is_sorted` needs to be maintained though:
    if step >= 0:
        out._is_sorted = input.is_sorted

    return out


@implements(aten.index.Tensor)
def _index(
    input: Union[Index, Tensor],
    indices: List[Optional[Union[Tensor, Index]]],
) -> Union[Index, Tensor]:
    r"""Implements advanced indexing (:obj:`input[indices]`).

    For a one-dimensional result of indexing an :class:`Index`, the output is
    an :class:`Index` that inherits :obj:`dim_size`. Boolean (or
    :obj:`uint8`) masks preserve the relative order and hence
    :obj:`is_sorted`, while integer gathers do not.
    """
    if not isinstance(input, Index):
        plain_indices = pytree.tree_map_only(Index, lambda x: x._data,
                                             indices)
        return aten.index.Tensor(input, plain_indices)

    data = aten.index.Tensor(input._data, indices)

    if data.dim() != 1:
        return data

    assert len(indices) == 1
    index = indices[0]
    assert index is not None

    # 1. `index[mask]` keeps order; 2. `index[index]` may reorder:
    is_mask = index.dtype in (torch.bool, torch.uint8)
    return Index(
        data,
        dim_size=input.dim_size,
        is_sorted=is_mask and input.is_sorted,
    )


@implements(aten.add.Tensor)
def _add(
    input: Union[int, Tensor, Index],
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Union[Index, Tensor]:
    r"""Implements :obj:`input + alpha * other` for :class:`Index` operands.

    The result is an :class:`Index` whenever it is one-dimensional and holds
    an index data type. Metadata is propagated where it can be derived:

    * :obj:`Index + scalar`: :obj:`dim_size` is shifted by
      :obj:`alpha * scalar` and :obj:`is_sorted` is preserved.
    * :obj:`scalar + Index`: for positive :obj:`alpha`, :obj:`dim_size` is
      bounded by :obj:`scalar + alpha * dim_size` and :obj:`is_sorted` is
      preserved.
    * :obj:`Index + Index`: for positive :obj:`alpha`, :obj:`dim_size` is
      bounded by the (scaled) sum of both sizes.
    """
    data = aten.add.Tensor(
        input._data if isinstance(input, Index) else input,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if data.dtype not in INDEX_DTYPES:
        return data
    if data.dim() != 1:
        return data

    out = Index(data)

    # Treat a single-element tensor paired with an `Index` as a scalar
    # offset. Only the operand *opposite* to an `Index` is converted, so that
    # the `Index` operand itself is never collapsed into an `int`, and
    # `int()` is never called on empty tensors (which raises):
    if (isinstance(input, Index) and isinstance(other, Tensor)
            and other.numel() == 1):
        other = int(other)
    elif (isinstance(other, Index) and isinstance(input, Tensor)
          and input.numel() == 1):
        input = int(input)

    if isinstance(input, Index) and isinstance(other, int):
        # Shifting by a constant keeps order and shifts the bound exactly:
        if input.dim_size is not None:
            out._dim_size = input.dim_size + alpha * other
        out._is_sorted = input.is_sorted

    elif isinstance(input, int) and isinstance(other, Index):
        # Scaling by `alpha <= 0` flips/collapses order and bound:
        if alpha > 0:
            if other.dim_size is not None:
                out._dim_size = input + alpha * other.dim_size
            out._is_sorted = other.is_sorted

    elif isinstance(input, Index) and isinstance(other, Index):
        if (alpha > 0 and input.dim_size is not None
                and other.dim_size is not None):
            out._dim_size = input.dim_size + alpha * other.dim_size

    return out


@implements(aten.add_.Tensor)
def add_(
    input: Index,
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Index:
    r"""In-place version of :obj:`input + alpha * other` for :class:`Index`.

    Metadata is cleared before the update and re-derived afterwards. Adding
    a scalar shifts :obj:`dim_size` and preserves :obj:`is_sorted`. Adding
    another :class:`Index` with positive :obj:`alpha` bounds
    :obj:`dim_size` by the (scaled) sum of both sizes.
    """
    dim_size = input.dim_size
    is_sorted = input.is_sorted
    input._clear_metadata()

    aten.add_.Tensor(
        input._data,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    # Single-element tensors act as scalar offsets (empty tensors cannot be
    # converted via `int()` and are left untouched):
    if isinstance(other, Tensor) and other.numel() == 1:
        other = int(other)

    if isinstance(other, int):
        if dim_size is not None:
            input._dim_size = dim_size + alpha * other
        input._is_sorted = is_sorted

    elif isinstance(other, Index):
        # The summed bound is only valid if `other` is not negated/zeroed:
        if alpha > 0 and dim_size is not None and other.dim_size is not None:
            input._dim_size = dim_size + alpha * other.dim_size

    return input


@implements(aten.sub.Tensor)
def _sub(
    input: Union[int, Tensor, Index],
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Union[Index, Tensor]:
    r"""Implements :obj:`input - alpha * other` for :class:`Index` operands.

    The result is an :class:`Index` whenever it is one-dimensional and holds
    an index data type. Metadata is only propagated for
    :obj:`Index - scalar`, which shifts :obj:`dim_size` by
    :obj:`-alpha * scalar` and preserves :obj:`is_sorted`.
    """
    data = aten.sub.Tensor(
        input._data if isinstance(input, Index) else input,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if data.dtype not in INDEX_DTYPES:
        return data
    if data.dim() != 1:
        return data

    out = Index(data)

    if not isinstance(input, Index):
        return out

    # Single-element tensors act as scalar offsets (empty tensors cannot be
    # converted via `int()` and are left untouched):
    if isinstance(other, Tensor) and other.numel() == 1:
        other = int(other)

    if isinstance(other, int):
        if input.dim_size is not None:
            out._dim_size = input.dim_size - alpha * other
        out._is_sorted = input.is_sorted

    return out


@implements(aten.sub_.Tensor)
def sub_(
    input: Index,
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Index:
    r"""In-place version of :obj:`input - alpha * other` for :class:`Index`.

    Metadata is cleared before the update and only re-derived when
    subtracting a scalar. That shifts :obj:`dim_size` by
    :obj:`-alpha * scalar` and preserves :obj:`is_sorted`.
    """
    dim_size = input.dim_size
    is_sorted = input.is_sorted
    input._clear_metadata()

    aten.sub_.Tensor(
        input._data,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    # Single-element tensors act as scalar offsets (empty tensors cannot be
    # converted via `int()` and are left untouched):
    if isinstance(other, Tensor) and other.numel() == 1:
        other = int(other)

    if isinstance(other, int):
        if dim_size is not None:
            input._dim_size = dim_size - alpha * other
        input._is_sorted = is_sorted

    return input