1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222
|
import functools
import six
from six.moves import reduce
from heapq import nsmallest
from operator import itemgetter, mul
import numpy
from .dtypes import dtype_to_ctype, _fill_dtype_registry
from .gpuarray import GpuArray
_fill_dtype_registry()
def as_argument(obj, name):
    """
    Wrap `obj` in the argument descriptor matching its kind.

    A `GpuArray` yields an `ArrayArg` carrying the array's dtype; any
    other object is converted with `numpy.asarray` and yields a
    `ScalarArg` carrying the resulting dtype.  `name` is passed through
    to the descriptor unchanged.
    """
    if not isinstance(obj, GpuArray):
        scalar_dtype = numpy.asarray(obj).dtype
        return ScalarArg(scalar_dtype, name)
    return ArrayArg(obj.dtype, name)
class Argument(object):
    """
    Base description of a kernel argument: a dtype plus a name.

    Instances are hashable and compare equal when they have exactly the
    same class, dtype and name.
    """

    def __init__(self, dtype, name):
        """Store the argument's `dtype` and `name`."""
        self.dtype = dtype
        self.name = name

    def ctype(self):
        """Return the result of `dtype_to_ctype` for this argument's dtype."""
        return dtype_to_ctype(self.dtype)

    def __hash__(self):
        # Combines the same three components that __eq__ compares, so
        # equal arguments always hash equal.
        return hash(type(self)) ^ hash(self.dtype) ^ hash(self.name)

    def __eq__(self, other):
        # Exact class match (not isinstance): two different Argument
        # subclasses with the same dtype and name must stay distinct.
        return (type(self) is type(other) and
                self.dtype == other.dtype and
                self.name == other.name)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, two
        # equal arguments would still compare as different with !=.
        return not self == other
class ArrayArg(Argument):
    """Argument descriptor for array (`GpuArray`) kernel arguments."""

    def decltype(self):
        """Return the declaration type: a `GLOBAL_MEM` pointer to the ctype."""
        element_type = self.ctype()
        return "GLOBAL_MEM {} *".format(element_type)

    def expr(self):
        """Return the expression indexing this argument with `i`."""
        arg_name = self.name
        return "{}[i]".format(arg_name)

    def isarray(self):
        """Array arguments always report True."""
        return True

    def spec(self):
        """Return the `GpuArray` class as this argument's spec."""
        return GpuArray
class ScalarArg(Argument):
    """Argument descriptor for scalar (non-array) kernel arguments."""

    def decltype(self):
        """Return the declaration type, which is just the ctype."""
        scalar_type = self.ctype()
        return scalar_type

    def expr(self):
        """Return the expression for this argument: its bare name."""
        arg_name = self.name
        return arg_name

    def isarray(self):
        """Scalar arguments always report False."""
        return False

    def spec(self):
        """Return the argument's dtype as its spec."""
        scalar_dtype = self.dtype
        return scalar_dtype
def check_args(args, collapse=False, broadcast=False):
    """
    Check that the array arguments are compatible and describe their layout.

    Every `GpuArray` in `args` must have the same number of dimensions
    and, unless `broadcast` is True, the same shape.  Arguments that are
    not `GpuArray` instances are not checked and contribute `None`
    entries to the returned strides and offsets.

    If `collapse` is True (or None, kept for backward compatibility),
    dimensions of size 1 are dropped and adjacent dimensions whose
    strides line up in every array are merged into one.  If `collapse`
    is False the dimensions are left as they are.

    If `broadcast` is True, a dimension that has size 1 in some arrays
    but not in others is repeated to match: the common shape takes the
    larger size and the size-1 arrays get a stride of 0 on that
    dimension.  If `broadcast` is False no broadcasting takes place.

    Returns a tuple `(n, nd, dims, strs, offsets)` with the element
    count, the number of dimensions, the common shape, and per-argument
    strides and offsets (`None` for non-array arguments).

    Raises `ValueError` when the arrays do not match and `TypeError`
    when `args` contains no array at all.
    """
    # For compatibility with old collapse=None option
    if collapse is None:
        collapse = True

    strides = []
    offsets = []
    dims = None
    for arg in args:
        if not isinstance(arg, GpuArray):
            strides.append(None)
            offsets.append(None)
            continue
        strides.append(arg.strides)
        offsets.append(arg.offset)
        if dims is None:
            # The first array fixes the reference size, rank and shape.
            n, nd, dims = arg.size, arg.ndim, arg.shape
            continue
        if arg.ndim != nd:
            raise ValueError("Array order differs")
        if not broadcast and arg.shape != dims:
            raise ValueError("Array shape differs")

    if dims is None:
        raise TypeError("No arrays in kernel arguments, "
                        "something is wrong")

    editable = broadcast or collapse
    full_shape = dims
    if editable:
        # Work on lists so shape and strides can be modified in place.
        dims = list(dims)
        strides = [None if s is None else list(s) for s in strides]

    if broadcast:
        if 1 in dims:
            # Grow size-1 entries of the reference shape to the size
            # found in the other arrays, scaling the element count.
            for pos, ary in enumerate(args):
                if strides[pos] is None:
                    continue
                for axis, extent in enumerate(ary.shape):
                    if dims[axis] != extent and dims[axis] == 1:
                        dims[axis] = extent
                        n *= extent
            full_shape = tuple(dims)

        # Zero the stride on every dimension an array has to repeat.
        for pos, ary in enumerate(args):
            if strides[pos] is None:
                continue
            shape = ary.shape
            if full_shape != shape:
                for axis, extent in enumerate(shape):
                    if dims[axis] != extent:
                        # Might want to add a per-dimension enable mechanism
                        if extent != 1:
                            raise ValueError("Array shape differs")
                        strides[pos][axis] = 0

    if collapse and nd > 1:
        # Drop size-1 dimensions, always keeping at least one dimension.
        for axis in reversed(range(nd)):
            if nd > 1 and dims[axis] == 1:
                del dims[axis]
                for s in strides:
                    if s is not None:
                        del s[axis]
                nd -= 1

        # Merge a dimension into the previous one when, for every array,
        # stepping over it entirely equals one step of the previous one.
        for axis in range(nd - 1, 0, -1):
            mergeable = all(
                s is None or s[axis] * dims[axis] == s[axis - 1]
                for s in strides
            )
            if mergeable:
                dims[axis - 1] *= dims[axis]
                del dims[axis]
                for s in strides:
                    if s is not None:
                        s[axis - 1] = s[axis]
                        del s[axis]
                nd -= 1

    if editable:
        # Convert the working lists back to tuples.
        dims = tuple(dims)
        strides = [None if s is None else tuple(s) for s in strides]

    return n, nd, dims, tuple(strides), tuple(offsets)
def lru_cache(maxsize=20):
    """
    Return a decorator that memoizes a function with LRU eviction.

    The decorated function is called with positional arguments only;
    the argument tuple is the cache key, so every argument must be
    hashable.  When the cache grows past `maxsize` entries, the
    `maxsize // 10` least recently used entries (at least one) are
    dropped.

    The wrapper exposes:

    - `hits` / `misses`: lookup counters.
    - `maxsize`: the current size limit (may be changed at runtime).
    - `clear()`: empty the cache and reset the counters.
    - `get(*key)`: return the cached result for `key` without calling
      the function; raises `KeyError` if it is not cached.
    """
    def decorating_function(user_function):
        cache = {}
        last_use = {}
        time = [0]  # workaround for Python 2, which doesn't have nonlocal

        @functools.wraps(user_function)
        def wrapper(*key):
            time[0] += 1
            if key in cache:
                result = cache[key]
                wrapper.hits += 1
            else:
                # Called outside any exception handler so errors raised by
                # user_function are not chained to an internal KeyError.
                result = user_function(*key)
                cache[key] = result
                wrapper.misses += 1

                # purge least recently used cache entries
                if len(cache) > wrapper.maxsize:
                    # Evict at least one entry: maxsize // 10 is 0 for
                    # maxsize < 10, which would let the cache grow
                    # without bound.
                    purge_count = max(1, wrapper.maxsize // 10)
                    # The new key is not in last_use yet, so it can never
                    # be picked for eviction here.
                    for key0, _ in nsmallest(purge_count,
                                             last_use.items(),
                                             key=itemgetter(1)):
                        del cache[key0], last_use[key0]

            last_use[key] = time[0]
            return result

        def clear():
            """Empty the cache and reset the counters and the clock."""
            cache.clear()
            last_use.clear()
            wrapper.hits = wrapper.misses = 0
            time[0] = 0

        @functools.wraps(user_function)
        def get(*key):
            # A missing key raises KeyError before any state is touched.
            result = cache[key]
            time[0] += 1
            last_use[key] = time[0]
            wrapper.hits += 1
            return result

        wrapper.hits = wrapper.misses = 0
        wrapper.maxsize = maxsize
        wrapper.clear = clear
        wrapper.get = get
        return wrapper
    return decorating_function
def prod(iterable):
    """Return the product of the items in `iterable` (1 if it is empty)."""
    product = 1
    for factor in iterable:
        # Plain multiplication, not `*=`, so no operand is modified in place.
        product = product * factor
    return product
|