File: ops.py

package info (click to toggle)
pytorch 2.9.1+dfsg-1~exp2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 180,096 kB
  • sloc: python: 1,473,255; cpp: 942,030; ansic: 79,796; asm: 7,754; javascript: 2,502; java: 1,962; sh: 1,809; makefile: 628; xml: 8
file content (24 lines) | stat: -rw-r--r-- 606 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from torch import Tensor


lib = torch.library._scoped_library("python_agnostic", "FRAGMENT")
lib.define("ultra_norm(Tensor[] inputs) -> Tensor")


def ultra_norm(inputs: list[Tensor]) -> Tensor:
    """
    Return the "ultra" L2 norm of a list of tensors: the L2 norm of the
    per-tensor norms.

    This is a thin Python entry point that forwards ``inputs`` unchanged to
    the ``python_agnostic::ultra_norm`` custom operator (default overload).
    No validation happens here; the assumptions below are the caller's
    responsibility.

    Assumes:
    - ``inputs`` is non-empty
    - every tensor in ``inputs`` is on the same device and has the same dtype

    Args:
        inputs: list of tensors to reduce

    Returns:
        A scalar (0-dim) tensor of shape ``()``.
    """
    overload = torch.ops.python_agnostic.ultra_norm.default
    return overload(inputs)