File: mlir_pytaco_io.py

package info (click to toggle)
llvm-toolchain-17 1%3A17.0.6-22
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,799,624 kB
  • sloc: cpp: 6,428,607; ansic: 1,383,196; asm: 793,408; python: 223,504; objc: 75,364; f90: 60,502; lisp: 33,869; pascal: 15,282; sh: 9,684; perl: 7,453; ml: 4,937; awk: 3,523; makefile: 2,889; javascript: 2,149; xml: 888; fortran: 619; cs: 573
file content (82 lines) | stat: -rw-r--r-- 2,688 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Experimental MLIR-PyTACO with sparse tensor support.

See http://tensor-compiler.org/ for TACO tensor compiler.

This module implements the PyTACO API for writing a tensor to a file or reading
a tensor from a file.

See the following links for Matrix Market Exchange (.mtx) format and FROSTT
(.tns) format:
  https://math.nist.gov/MatrixMarket/formats.html
  http://frostt.io/tensors/file-formats.html
"""

from typing import List, TextIO

from . import mlir_pytaco

# Define the type aliases so that we can write the implementation here as if
# it were part of mlir_pytaco.py.
Tensor = mlir_pytaco.Tensor
Format = mlir_pytaco.Format
DType = mlir_pytaco.DType
Type = mlir_pytaco.Type

# Constants used in the implementation.
# Matrix Market Exchange suffix; accepted by read() only.
_MTX_FILENAME_SUFFIX = ".mtx"
# FROSTT suffix; accepted by both read() and write().
_TNS_FILENAME_SUFFIX = ".tns"


def read(filename: str, fmt: Format, dtype: DType = DType(Type.FLOAT32)) -> Tensor:
    """Inputs a tensor from a given file.

    The name suffix of the file specifies the format of the input tensor:
    .mtx (Matrix Market Exchange) or .tns (FROSTT). Only the suffix is checked
    here; reading and parsing the file is delegated to Tensor.from_file.

    Args:
      filename: A string input filename ending with .mtx or .tns.
      fmt: The storage format of the tensor.
      dtype: The data type, default to float32.

    Returns:
      The Tensor returned by Tensor.from_file(filename, fmt, dtype).

    Raises:
      ValueError: If filename is not a string or doesn't end with .mtx or .tns.
        fmt is not validated in this function; presumably Tensor.from_file
        rejects an fmt that is not a Format instance or not a sparse format --
        confirm against mlir_pytaco.
    """
    # The isinstance test short-circuits, so endswith only runs on a str.
    # str.endswith accepts a tuple, which tests both suffixes in one call.
    if not isinstance(filename, str) or not filename.endswith(
        (_MTX_FILENAME_SUFFIX, _TNS_FILENAME_SUFFIX)
    ):
        raise ValueError(
            "Expected string filename ends with "
            f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: "
            f"{filename}."
        )

    return Tensor.from_file(filename, fmt, dtype)


def write(filename: str, tensor: Tensor) -> None:
    """Writes a tensor to a file whose name suffix selects the output format.

    Only the FROSTT .tns format is currently accepted for output; the actual
    writing is delegated to tensor.to_file.

    Args:
      filename: A string output filename; it must end with .tns.
      tensor: The Tensor to write.

    Raises:
      ValueError: If filename is not a string ending with .tns, or if tensor
        is not a Tensor. The filename is validated before the tensor.
    """
    has_tns_suffix = isinstance(filename, str) and filename.endswith(
        _TNS_FILENAME_SUFFIX
    )
    if not has_tns_suffix:
        message = (
            f"Expected string filename ends with {_TNS_FILENAME_SUFFIX}: "
            f"{filename}."
        )
        raise ValueError(message)

    if isinstance(tensor, Tensor):
        tensor.to_file(filename)
        return
    raise ValueError(f"Expected a Tensor object: {tensor}.")