## Finch backend for `sparse`

<a href="https://colab.research.google.com/github/pydata/sparse/blob/main/examples/sparse_finch.ipynb" target="_blank">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" />
</a> — click the badge to download and run this notebook.

In [None]:
#!pip install 'sparse[finch]==0.16.0a9' scipy
#!export SPARSE_BACKEND=Finch

import os

# Select the Finch backend through the environment; `sparse` itself is only
# imported in a later cell.
os.environ["SPARSE_BACKEND"] = "Finch"

# CI_MODE is parsed as an integer flag ("0" when unset); any non-zero value
# turns it on.
CI_MODE = int(os.environ.get("CI_MODE", "0")) != 0

In [None]:
import importlib
import time

import sparse

import matplotlib.pyplot as plt
import networkx as nx

import numpy as np
import scipy.sparse as sps
import scipy.sparse.linalg as splin

In [None]:
# `asarray` offers a no-copy constructor for NumPy and scipy.sparse inputs
tns = sparse.asarray(np.zeros((10, 10)))

# `random` creates a random COO tensor
s1 = sparse.random((100, 10), density=0.01)
# passing `format=` to `asarray` rewrites a tensor to a new format (here: CSF)
s2 = sparse.asarray(sparse.random((100, 100, 10), density=0.01), format="csf")

# contract axes (0, 1) of s1 with axes (0, 2) of s2
result = sparse.tensordot(s1, s2, axes=([0, 1], [0, 2]))

# sum of the element-wise squares of the contraction result
total = sparse.sum(result * result)
print(total)

### Example: least squares - closed form

In [None]:
# Random sparse response vector and design matrix; X is stored in CSC format.
y = sparse.random((100, 1), density=0.08)
X = sparse.asarray(sparse.random((100, 5), density=0.08), format="csc")

# Build X^T X through the lazy API and evaluate it with `compute`.
X_lazy = sparse.lazy(X)
X_X = sparse.compute(sparse.permute_dims(X_lazy, (1, 0)) @ X_lazy)

# move back from dense to CSC format
X_X = sparse.asarray(X_X, format="csc")

# dispatching to scipy.sparse.sparray
inverted = splin.inv(X_X)

# Closed form: b_hat = (X^T X)^-1 X^T y
X_T = sparse.permute_dims(X, (1, 0))
b_hat = (inverted @ X_T) @ y

print(b_hat.todense())

## Benchmark plots

In [None]:
# Seeded generator shared by the benchmark cells, and the number of timed
# repetitions used by `benchmark`.
rng = np.random.default_rng(seed=0)
ITERS = 1

In [None]:
# Figure styling applied to the benchmark plots.
plt.style.use("seaborn-v0_8")
plt.rcParams.update({"figure.dpi": 400, "figure.figsize": [8, 4]})

In [None]:
def benchmark(func, info, args, iters=None) -> float:
    """Return the mean wall-clock time, in seconds, of one ``func(*args)`` call.

    Parameters
    ----------
    func : callable
        Function under test. Its return value is discarded.
    info : str
        Label for the measurement. It is not used by the function itself and
        is kept so that existing call sites keep working.
    args : sequence
        Positional arguments unpacked into every call of ``func``.
    iters : int, optional
        Number of timed calls. When omitted, the module-level ``ITERS`` is used.

    Returns
    -------
    float
        Total elapsed time divided by the number of calls.

    Raises
    ------
    ValueError
        If the number of calls is smaller than one.
    """
    n_calls = ITERS if iters is None else iters
    if n_calls < 1:
        raise ValueError(f"number of benchmark iterations must be >= 1, got {n_calls}")

    # perf_counter is monotonic and high-resolution, whereas time.time() can
    # jump when the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(n_calls):
        func(*args)
    elapsed = time.perf_counter() - start
    return elapsed / n_calls

## MTTKRP

In [None]:
print("MTTKRP Example:\n")

# Switch to the Numba backend and reload the module (presumably so the new
# backend setting is picked up).
os.environ[sparse._ENV_VAR_NAME] = "Numba"
importlib.reload(sparse)

# One benchmark case per (I_, K_, L_) triple; J_ and DENSITY are the same in
# every case.
configs = [
    {"I_": dim_i, "J_": 25, "K_": dim_k, "L_": dim_l, "DENSITY": 0.001}
    for dim_i, dim_k, dim_l in (
        (100, 100, 10),
        (100, 100, 100),
        (1000, 100, 100),
        (1000, 1000, 100),
        (1000, 1000, 1000),
    )
]
# x-axis values for the plot: I_ * K_ * L_ per case, i.e. the total number of
# elements of the sparse tensor (not its number of stored nonzeros).
nonzeros = [cfg["I_"] * cfg["K_"] * cfg["L_"] for cfg in configs]

# In CI mode only the first (smallest) case is run.
if CI_MODE:
    configs, nonzeros = configs[:1], nonzeros[:1]

finch_times, numba_times, finch_galley_times = [], [], []

for config in configs:
    # Inputs: a sparse tensor of shape (I_, K_, L_) and two dense matrices of
    # shapes (L_, J_) and (K_, J_), all drawn from the shared `rng`.
    B_shape = (config["I_"], config["K_"], config["L_"])
    B_sps = sparse.random(B_shape, density=config["DENSITY"], random_state=rng)
    D_sps = rng.random((config["L_"], config["J_"]))
    C_sps = rng.random((config["K_"], config["J_"]))

    # ======= Finch =======
    os.environ[sparse._ENV_VAR_NAME] = "Finch"
    importlib.reload(sparse)

    # Re-create the operands under the Finch backend: B in CSF format, D and C
    # from Fortran-ordered copies of the dense matrices.
    B = sparse.asarray(B_sps.todense(), format="csf")
    D = sparse.asarray(np.array(D_sps, order="F"))
    C = sparse.asarray(np.array(C_sps, order="F"))

    @sparse.compiled(opt=sparse.DefaultScheduler())
    def mttkrp_finch(B, D, C):
        # result[i, j] = sum over k, l of B[i, k, l] * D[l, j] * C[k, j]
        return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))

    # Compile (first call happens outside the timed region)
    result_finch = mttkrp_finch(B, D, C)
    # Benchmark
    time_finch = benchmark(mttkrp_finch, info="Finch", args=[B, D, C])

    # ======= Finch Galley =======
    os.environ[sparse._ENV_VAR_NAME] = "Finch"
    importlib.reload(sparse)

    B = sparse.asarray(B_sps.todense(), format="csf")
    D = sparse.asarray(np.array(D_sps, order="F"))
    C = sparse.asarray(np.array(C_sps, order="F"))

    # Same kernel, scheduled by Galley; the tag is derived from the tensor shape.
    @sparse.compiled(opt=sparse.GalleyScheduler(), tag=sum(B_shape))
    def mttkrp_finch_galley(B, D, C):
        return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))

    # Compile (first call happens outside the timed region)
    result_finch_galley = mttkrp_finch_galley(B, D, C)
    # Benchmark
    time_finch_galley = benchmark(mttkrp_finch_galley, info="Finch Galley", args=[B, D, C])

    # ======= Numba =======
    os.environ[sparse._ENV_VAR_NAME] = "Numba"
    importlib.reload(sparse)

    # Under the Numba backend B is converted to GCXS and the dense NumPy
    # matrices are used directly.
    B = sparse.asarray(B_sps, format="gcxs")
    D = D_sps
    C = C_sps

    def mttkrp_numba(B, D, C):
        return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))

    # Compile (first call happens outside the timed region)
    result_numba = mttkrp_numba(B, D, C)
    # Benchmark
    time_numba = benchmark(mttkrp_numba, info="Numba", args=[B, D, C])

    # Both Finch variants must agree with the Numba reference; previously the
    # Galley result was computed but never verified.
    np.testing.assert_allclose(result_finch.todense(), result_numba.todense())
    np.testing.assert_allclose(result_finch_galley.todense(), result_numba.todense())

    finch_times.append(time_finch)
    numba_times.append(time_numba)
    finch_galley_times.append(time_finch_galley)

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1)

# One line per backend, plotted against the per-case element counts.
for label, times in (
    ("Finch", finch_times),
    ("Numba", numba_times),
    ("Finch - Galley", finch_galley_times),
):
    ax.plot(nonzeros, times, "o-", label=label)

ax.grid(True)
# Log-log axes: the cases span several orders of magnitude.
ax.set(
    xlabel="no. of elements",
    ylabel="time (sec)",
    title="MTTKRP",
    xscale="log",
    yscale="log",
)
ax.legend(loc="best", numpoints=1)

plt.show()

## SDDMM

In [None]:
print("SDDMM Example:\n")

# Matrix side lengths from 5000 to 30000 in steps of 5000, all with the same
# density.
size_n = list(range(5000, 30001, 5000))
configs = [{"LEN": side, "DENSITY": 0.00001} for side in size_n]

# In CI mode only the first (smallest) case is run.
if CI_MODE:
    configs, size_n = configs[:1], size_n[:1]

finch_times, numba_times, scipy_times, finch_galley_times = [], [], [], []

# Time s * (a @ b) for every configuration with the Finch (default and Galley
# schedulers), Numba and SciPy implementations, and collect the mean timings.
for config in configs:
    LEN = config["LEN"]
    DENSITY = config["DENSITY"]

    # Two dense LEN x LEN operands and one sparse LEN x LEN COO matrix, all
    # drawn from the shared `rng`.
    a_sps = rng.random((LEN, LEN))
    b_sps = rng.random((LEN, LEN))
    s_sps = sps.random(LEN, LEN, format="coo", density=DENSITY, random_state=rng)
    s_sps.sum_duplicates()

    # ======= Finch =======
    print("finch")
    os.environ[sparse._ENV_VAR_NAME] = "Finch"
    importlib.reload(sparse)

    # Re-create the operands after the backend switch.
    s = sparse.asarray(s_sps)
    a = sparse.asarray(a_sps)
    b = sparse.asarray(b_sps)

    @sparse.compiled(opt=sparse.DefaultScheduler())
    def sddmm_finch(s, a, b):
        # element-wise product of the sparse matrix with the dense product a @ b
        return s * (a @ b)

    # Compile
    result_finch = sddmm_finch(s, a, b)
    # Benchmark
    time_finch = benchmark(sddmm_finch, info="Finch", args=[s, a, b])

    # ======= Finch Galley =======
    print("finch galley")
    os.environ[sparse._ENV_VAR_NAME] = "Finch"
    importlib.reload(sparse)

    s = sparse.asarray(s_sps)
    a = sparse.asarray(a_sps)
    b = sparse.asarray(b_sps)

    # Same kernel, scheduled by Galley and tagged with the matrix size.
    @sparse.compiled(opt=sparse.GalleyScheduler(), tag=LEN)
    def sddmm_finch_galley(s, a, b):
        return s * (a @ b)

    # Compile
    result_finch_galley = sddmm_finch_galley(s, a, b)
    # Benchmark
    time_finch_galley = benchmark(sddmm_finch_galley, info="Finch Galley", args=[s, a, b])

    # ======= Numba =======
    print("numba")
    os.environ[sparse._ENV_VAR_NAME] = "Numba"
    importlib.reload(sparse)

    # Only the sparse operand is wrapped; the dense NumPy arrays are used as-is.
    s = sparse.asarray(s_sps)
    a = a_sps
    b = b_sps

    def sddmm_numba(s, a, b):
        return s * (a @ b)

    # Compile
    result_numba = sddmm_numba(s, a, b)
    # Benchmark
    time_numba = benchmark(sddmm_numba, info="Numba", args=[s, a, b])

    # ======= SciPy =======
    print("scipy")

    def sddmm_scipy(s, a, b):
        # `multiply` is SciPy's element-wise product for sparse matrices
        return s.multiply(a @ b)

    # SciPy runs on the original scipy.sparse matrix, converted to CSR.
    s = s_sps.asformat("csr")
    a = a_sps
    b = b_sps

    result_scipy = sddmm_scipy(s, a, b)
    # Benchmark
    time_scipy = benchmark(sddmm_scipy, info="SciPy", args=[s, a, b])

    # NOTE(review): the four results are not compared against each other here.
    finch_times.append(time_finch)
    numba_times.append(time_numba)
    scipy_times.append(time_scipy)
    finch_galley_times.append(time_finch_galley)

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1)

# One line per implementation, plotted against the matrix side length.
for label, times in (
    ("Finch", finch_times),
    ("Numba", numba_times),
    ("SciPy", scipy_times),
    ("Finch Galley", finch_galley_times),
):
    ax.plot(size_n, times, "o-", label=label)

ax.grid(True)
ax.set(xlabel="size N", ylabel="time (sec)", title="SDDMM")
ax.legend(loc="best", numpoints=1)

plt.show()

## Counting Triangles

In [None]:
print("Counting Triangles Example:\n")

# Graph sizes from 10000 to 50000 nodes in steps of 5000, all with the same
# edge probability.
size_n = list(range(10000, 50001, 5000))
configs = [{"LEN": n_nodes, "DENSITY": 0.001} for n_nodes in size_n]

# In CI mode only the first (smallest) case is run.
if CI_MODE:
    configs, size_n = configs[:1], size_n[:1]

finch_times, finch_galley_times, networkx_times, scipy_times = [], [], [], []

# Time triangle counting for every configuration with Finch (default and
# Galley schedulers), SciPy and NetworkX, and collect the mean timings.
for config in configs:
    LEN = config["LEN"]
    DENSITY = config["DENSITY"]

    # Random graph with LEN nodes and edge probability DENSITY, plus its
    # adjacency matrix as a SciPy sparse array.
    G = nx.gnp_random_graph(n=LEN, p=DENSITY)
    a_sps = nx.to_scipy_sparse_array(G)

    # ======= Finch =======
    print("finch")
    os.environ[sparse._ENV_VAR_NAME] = "Finch"
    importlib.reload(sparse)

    a = sparse.asarray(a_sps)

    @sparse.compiled(opt=sparse.DefaultScheduler())
    def ct_finch(a):
        # `a @ a * a` parses as (a @ a) * a; the summed result is divided by 6
        # to obtain the triangle count from the adjacency matrix.
        return sparse.sum(a @ a * a) / sparse.asarray(6)

    # Compile (first call happens outside the timed region)
    result_finch = ct_finch(a)
    # Benchmark
    time_finch = benchmark(ct_finch, info="Finch", args=[a])

    # ======= Finch Galley =======
    print("finch galley")
    os.environ[sparse._ENV_VAR_NAME] = "Finch"
    importlib.reload(sparse)

    a = sparse.asarray(a_sps)

    # Same kernel, scheduled by Galley and tagged with the graph size.
    @sparse.compiled(opt=sparse.GalleyScheduler(), tag=LEN)
    def ct_finch_galley(a):
        return sparse.sum(a @ a * a) / sparse.asarray(6)

    # Compile (first call happens outside the timed region)
    result_finch_galley = ct_finch_galley(a)
    # Benchmark
    time_finch_galley = benchmark(ct_finch_galley, info="Finch Galley", args=[a])

    # ======= SciPy =======
    print("scipy")

    def ct_scipy(a):
        return (a @ a * a).sum() / 6

    # SciPy works directly on the scipy.sparse adjacency array.
    a = a_sps

    # Benchmark
    time_scipy = benchmark(ct_scipy, info="SciPy", args=[a])

    # ======= NetworkX =======
    print("networkx")

    def ct_networkx(a):
        # nx.triangles returns a per-node count; the sum is divided by 3.
        return sum(nx.triangles(a).values()) / 3

    # NetworkX works on the graph object itself.
    a = G

    # Benchmark (the label used to read "SciPy", a copy-paste slip)
    time_networkx = benchmark(ct_networkx, info="NetworkX", args=[a])

    finch_times.append(time_finch)
    finch_galley_times.append(time_finch_galley)
    networkx_times.append(time_networkx)
    scipy_times.append(time_scipy)

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1)

# One line per implementation, plotted against the number of graph nodes.
for label, times in (
    ("Finch", finch_times),
    ("NetworkX", networkx_times),
    ("SciPy", scipy_times),
    ("Finch Galley", finch_galley_times),
):
    ax.plot(size_n, times, "o-", label=label)

ax.grid(True)
ax.set(xlabel="size N", ylabel="time (sec)", title="Counting Triangles")
ax.legend(loc="best", numpoints=1)

plt.show()