1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273
|
# Copyright (C) 2024 Garth N. Wells
#
# This file is part of DOLFINx (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utility functions for calling PETSc C functions from Numba functions."""
from __future__ import annotations
import ctypes as _ctypes
import os
import pathlib
import warnings
import numpy as np
# Names exported by ``from <module> import *``: the three helper namespaces
# defined below. ``get_petsc_lib`` is not listed here.
__all__ = ["cffi_utils", "numba_utils", "ctypes_utils"]
def get_petsc_lib() -> pathlib.Path:
    """Find the full path of the PETSc shared library.

    The library is looked for in ``<PETSC_DIR>/<PETSC_ARCH>/lib`` (both
    as reported by petsc4py) under the names ``libpetsc_<scalar>.so``
    (``<scalar>`` is ``real`` or ``complex``), ``libpetsc.so`` and
    ``libpetsc.dylib``. Exactly one of these must exist.

    Returns:
        Full path to the PETSc shared library.

    Raises:
        RuntimeError: If the PETSc library cannot be found or if more
            than one library is found.
    """
    import petsc4py as _petsc4py

    petsc_dir = _petsc4py.get_config()["PETSC_DIR"]
    petsc_arch = _petsc4py.lib.getPathArchPETSc()[1]

    # Determine whether PETSc was built with real or complex scalars; this
    # selects the scalar-specific library name candidate below.
    try:
        petsc_scalar = _petsc4py.PETSc.ScalarType
    except AttributeError:
        # petsc4py.PETSc is not available (e.g. not imported yet), so fall
        # back to inferring the scalar type from the PETSc directory name.
        scalar_type = "complex" if "complex" in petsc_dir else "real"
    else:
        scalar_type = "complex" if np.issubdtype(petsc_scalar, np.complexfloating) else "real"

    lib_dir = os.path.join(petsc_dir, petsc_arch, "lib")
    candidate_paths = [
        os.path.join(lib_dir, f"libpetsc_{scalar_type}.so"),
        os.path.join(lib_dir, "libpetsc.so"),
        os.path.join(lib_dir, "libpetsc.dylib"),
    ]
    exists_paths = [path for path in candidate_paths if os.path.exists(path)]

    if not exists_paths:
        raise RuntimeError(
            f"Could not find a PETSc shared library. Candidate paths: {candidate_paths}"
        )
    if len(exists_paths) > 1:
        # Refuse to guess which library should be loaded.
        raise RuntimeError(f"More than one PETSc shared library found. Paths: {exists_paths}")
    return pathlib.Path(exists_paths[0])
class numba_utils:
    """Utility attributes for working with Numba and PETSc.

    These attributes are convenience functions for calling PETSc C
    functions from within Numba functions. The PETSc ``Mat`` argument is
    passed as an integer handle (``uintp``).

    Note:
        `Numba <https://numba.pydata.org/>`_ must be available
        to use these utilities.

    Examples:
        A typical use of these utility functions is::

            import numpy as np
            import numpy.typing as npt
            def set_vals(A: int,
                         m: int, rows: npt.NDArray[PETSc.IntType],
                         n: int, cols: npt.NDArray[PETSc.IntType],
                         data: npt.NDArray[PETSc.ScalarType], mode: int):
                MatSetValuesLocal(A, m, rows.ctypes, n, cols.ctypes, data.ctypes, mode)
    """

    # Best-effort: if petsc4py, llvmlite or Numba cannot be imported, the
    # class is left without the PETSc function attributes. A RuntimeError
    # raised by get_petsc_lib() is not caught here.
    try:
        import petsc4py.PETSc as _PETSc

        import llvmlite as _llvmlite
        import llvmlite.binding as _llvmlite_binding  # explicit: not imported by `llvmlite`
        import numba as _numba

        # Load the PETSc shared library so its symbols are available to the
        # LLVM JIT when Numba compiles calls to the external functions below.
        _llvmlite_binding.load_library_permanently(str(get_petsc_lib()))

        # Numba types matching PETSc's integer, scalar and real types.
        _int = _numba.from_dtype(_PETSc.IntType)  # type: ignore
        _scalar = _numba.from_dtype(_PETSc.ScalarType)  # type: ignore
        _real = _numba.from_dtype(_PETSc.RealType)  # type: ignore
        _int_ptr = _numba.core.types.CPointer(_int)
        _scalar_ptr = _numba.core.types.CPointer(_scalar)

        # Signature shared by MatSetValuesLocal and MatSetValuesBlockedLocal:
        # int f(Mat mat, nrow, const irow*, ncol, const icol*, const y*, int addv)
        _MatSetValues_sig = _numba.core.typing.signature(
            _numba.core.types.intc,
            _numba.core.types.uintp,
            _int,
            _int_ptr,
            _int,
            _int_ptr,
            _scalar_ptr,
            _numba.core.types.intc,
        )

        MatSetValuesLocal = _numba.core.types.ExternalFunction(
            "MatSetValuesLocal", _MatSetValues_sig
        )
        """See PETSc `MatSetValuesLocal
        <https://petsc.org/release/manualpages/Mat/MatSetValuesLocal>`_
        documentation."""

        MatSetValuesBlockedLocal = _numba.core.types.ExternalFunction(
            "MatSetValuesBlockedLocal", _MatSetValues_sig
        )
        """See PETSc `MatSetValuesBlockedLocal
        <https://petsc.org/release/manualpages/Mat/MatSetValuesBlockedLocal>`_
        documentation."""
    except ImportError:
        pass
class ctypes_utils:
    """Utility attributes for working with ctypes and PETSc.

    These attributes are convenience functions for calling PETSc C
    functions, typically from within Numba functions.

    Note:
        ctypes has no complex number types, so the scalar data argument
        of these functions is declared as ``void*``. Pass a pointer to a
        buffer of ``PETSc.ScalarType`` values.

    Examples:
        A typical use of these utility functions is::

            import numpy as np
            import numpy.typing as npt
            def set_vals(A: int,
                         m: int, rows: npt.NDArray[PETSc.IntType],
                         n: int, cols: npt.NDArray[PETSc.IntType],
                         data: npt.NDArray[PETSc.ScalarType], mode: int):
                MatSetValuesLocal(A, m, rows.ctypes, n, cols.ctypes, data.ctypes, mode)
    """

    # Best-effort: if petsc4py cannot be imported, the class is left without
    # the PETSc function attributes. A RuntimeError raised by
    # get_petsc_lib() is not caught here.
    try:
        import petsc4py.PETSc as _PETSc

        _lib_ctypes = _ctypes.cdll.LoadLibrary(str(get_petsc_lib()))

        # ctypes integer type matching PETSc's integer type.
        _int = np.ctypeslib.as_ctypes_type(_PETSc.IntType)  # type: ignore

        # Argument types shared by MatSetValuesLocal and
        # MatSetValuesBlockedLocal: (mat, nrow, irow*, ncol, icol*, y, addv).
        # ctypes does not have complex types, hence void* for scalar data y.
        _MatSetValues_argtypes = (
            _ctypes.c_void_p,
            _int,
            _ctypes.POINTER(_int),
            _int,
            _ctypes.POINTER(_int),
            _ctypes.c_void_p,
            _ctypes.c_int,
        )

        MatSetValuesLocal = _lib_ctypes.MatSetValuesLocal
        """See PETSc `MatSetValuesLocal
        <https://petsc.org/release/manualpages/Mat/MatSetValuesLocal>`_
        documentation."""
        # A fresh list per function so the two argtypes are never aliased.
        MatSetValuesLocal.argtypes = list(_MatSetValues_argtypes)
        # Explicit int return (the ctypes default), as declared for these
        # functions in the Numba and CFFI signatures.
        MatSetValuesLocal.restype = _ctypes.c_int

        MatSetValuesBlockedLocal = _lib_ctypes.MatSetValuesBlockedLocal
        """See PETSc `MatSetValuesBlockedLocal
        <https://petsc.org/release/manualpages/Mat/MatSetValuesBlockedLocal>`_
        documentation."""
        MatSetValuesBlockedLocal.argtypes = list(_MatSetValues_argtypes)
        MatSetValuesBlockedLocal.restype = _ctypes.c_int
    except ImportError:
        pass
class cffi_utils:
    """Utility attributes for working with CFFI (ABI mode) and Numba.

    Registers CFFI's ``float _Complex`` and ``double _Complex`` types with
    Numba (as ``complex64`` and ``complex128``), so that CFFI functions
    taking complex arguments can be called from Numba functions.

    If PETSc is available, CFFI convenience functions for calling PETSc C
    functions are also created. These are typically called from within Numba
    functions.

    Note:
        `CFFI <https://cffi.readthedocs.io/>`_ and `Numba
        <https://numba.pydata.org/>`_ must be available to use these utilities.

    Examples:
        A typical use of these utility functions is::

            import numpy as np
            import numpy.typing as npt
            def set_vals(A: int,
                         m: int, rows: npt.NDArray[PETSc.IntType],
                         n: int, cols: npt.NDArray[PETSc.IntType],
                         data: npt.NDArray[PETSc.ScalarType], mode: int):
                MatSetValuesLocal(A, m, ffi.from_buffer(rows), n, ffi.from_buffer(cols),
                                  ffi.from_buffer(data), mode)
    """

    import cffi as _cffi

    _ffi = _cffi.FFI()

    # Register the C99 complex types with Numba's CFFI type support.
    try:
        import numba as _numba
        import numba.core.typing.cffi_utils as _cffi_support

        _cffi_support.register_type(_ffi.typeof("float _Complex"), _numba.types.complex64)
        _cffi_support.register_type(_ffi.typeof("double _Complex"), _numba.types.complex128)
    except KeyError:
        # NOTE(review): KeyError is tolerated here and registration is
        # skipped; confirm which Numba/CFFI versions raise it.
        pass
    except ImportError:
        warnings.warn(
            "Could not import numba, so cffi/numba complex types were not registered.",
            ImportWarning,
        )

    # Declare and bind the PETSc functions (ABI mode). A RuntimeError raised
    # by get_petsc_lib() is not caught here.
    try:
        from petsc4py import PETSc as _PETSc

        _lib_cffi = _ffi.dlopen(str(get_petsc_lib()))

        # Map NumPy scalar types to the C type names used in the cdef below.
        _CTYPES = {
            np.int32: "int32_t",
            np.int64: "int64_t",
            np.float32: "float",
            np.float64: "double",
            np.complex64: "float _Complex",
            np.complex128: "double _Complex",
            np.longlong: "long long",
        }
        _c_int_t = _CTYPES[_PETSc.IntType]  # type: ignore
        _c_scalar_t = _CTYPES[_PETSc.ScalarType]  # type: ignore
        _ffi.cdef(
            f"""
            int MatSetValuesLocal(void* mat, {_c_int_t} nrow, const {_c_int_t}* irow,
                                  {_c_int_t} ncol, const {_c_int_t}* icol,
                                  const {_c_scalar_t}* y, int addv);
            int MatSetValuesBlockedLocal(void* mat, {_c_int_t} nrow, const {_c_int_t}* irow,
                                         {_c_int_t} ncol, const {_c_int_t}* icol,
                                         const {_c_scalar_t}* y, int addv);
            """
        )

        MatSetValuesLocal = _lib_cffi.MatSetValuesLocal
        """See PETSc `MatSetValuesLocal
        <https://petsc.org/release/manualpages/Mat/MatSetValuesLocal>`_
        documentation."""

        MatSetValuesBlockedLocal = _lib_cffi.MatSetValuesBlockedLocal
        """See PETSc `MatSetValuesBlockedLocal
        <https://petsc.org/release/manualpages/Mat/MatSetValuesBlockedLocal>`_
        documentation."""
    except KeyError:
        # Only the _CTYPES lookups raise KeyError: PETSc's IntType or
        # ScalarType has no C type mapping, so the functions cannot be
        # declared. Say so instead of leaving the attributes silently missing.
        warnings.warn(
            "PETSc IntType or ScalarType is not supported by cffi_utils, "
            "so cffi/PETSc ABI mode interface was not created.",
            ImportWarning,
        )
    except ImportError:
        warnings.warn(
            "Could not import petsc4py, so cffi/PETSc ABI mode interface was not created.",
            ImportWarning,
        )
|