File: _common.py

package info (click to toggle)
python-sparse 0.17.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,816 kB
  • sloc: python: 11,223; sh: 54; javascript: 10; makefile: 8
file content (65 lines) | stat: -rw-r--r-- 1,681 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import ctypes
import functools
import weakref
from collections.abc import Iterable

import mlir_finch.runtime as rt

import numpy as np

from ._core import libc
from ._dtypes import DType, asdtype


def fn_cache(f, maxsize: int | None = None):
    return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))


def get_nd_memref_descr(rank: int, dtype: DType) -> ctypes.Structure:
    return _get_nd_memref_descr(int(rank), asdtype(dtype))


@fn_cache
def _get_nd_memref_descr(rank: int, dtype: DType) -> ctypes.Structure:
    return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())


def numpy_to_ranked_memref(arr: np.ndarray) -> ctypes.Structure:
    memref = rt.get_ranked_memref_descriptor(arr)
    memref_descr = get_nd_memref_descr(arr.ndim, asdtype(arr.dtype))
    # Required due to ctypes type checks
    return memref_descr(
        allocated=memref.allocated,
        aligned=memref.aligned,
        offset=memref.offset,
        shape=memref.shape,
        strides=memref.strides,
    )


def ranked_memref_to_numpy(ref: ctypes.Structure) -> np.ndarray:
    return rt.ranked_memref_to_numpy([ref])


def free_memref(obj: ctypes.Structure) -> None:
    libc.free(ctypes.cast(obj.allocated, ctypes.c_void_p))


def _hold_ref(owner, obj):
    ptr = ctypes.py_object(obj)
    ctypes.pythonapi.Py_IncRef(ptr)

    def finalizer(ptr):
        ctypes.pythonapi.Py_DecRef(ptr)

    weakref.finalize(owner, finalizer, ptr)


def as_shape(x) -> tuple[int]:
    if not isinstance(x, Iterable):
        x = (x,)

    if not all(isinstance(xi, int) for xi in x):
        raise TypeError("Shape must be an `int` or tuple of `int`s.")

    return tuple(int(xi) for xi in x)