1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
|
# Optimized inner loop of load_svmlight_file.
#
# Authors: Mathieu Blondel <mathieu@mblondel.org>
# Lars Buitinck
# Olivier Grisel <olivier.grisel@ensta.org>
# License: BSD 3 clause
import array
from cpython cimport array
cimport cython
from libc.string cimport strchr
import numpy as np
# Byte separators used while tokenizing a line (the file is read as bytes).
cdef bytes COMMA = u','.encode('ascii')
cdef bytes COLON = u':'.encode('ascii')


def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
                        bint query_id, long long offset, long long length):
    """Parse SVMlight/LibSVM formatted lines from ``f`` into flat buffers.

    Parameters
    ----------
    f : binary file-like object
        Iterated line by line; every line must be a ``bytes`` object.
        ``seek``/``readline`` are used when ``offset > 0`` and ``tell`` is
        used when ``length > 0``.
    dtype : object
        ``np.float32`` is honoured as-is; any other value is replaced by
        ``np.float64``.
    multilabel : bool
        If true, the first token of a line is a comma separated list of
        labels, stored as a sorted tuple. A first token that contains ``:``
        is treated as a feature, i.e. the sample has no label.
    zero_based : bool
        If false, a feature index equal to 0 is rejected.
    query_id : bool
        If true, the value of a leading ``qid:<n>`` token is collected.
        The token is consumed in both cases.
    offset : int
        If positive, start reading at this byte position and discard the
        line found there (it may be truncated and belongs to another call).
    length : int
        If positive, stop after the first line that ends past
        ``offset + length``.

    Returns
    -------
    tuple
        ``(dtype, data, indices, indptr, labels, query)`` where ``data``,
        ``indices`` and ``indptr`` are ``array.array`` buffers laid out row
        by row, ``labels`` is an ``array.array`` of doubles (or a list of
        tuples when ``multilabel``) and ``query`` is an int64 ndarray.

    Raises
    ------
    ValueError
        If a feature index is negative, is 0 while ``zero_based`` is false,
        or if the indices of a line are not strictly increasing.
    """
    cdef array.array data, indices, indptr
    cdef bytes line
    cdef char *hash_ptr
    cdef char *line_cstr
    # ``indices`` stores signed 64-bit values ("q"), so parse indices with the
    # same width instead of overflowing a C int on very wide data sets.
    cdef long long idx, prev_idx
    cdef Py_ssize_t i
    cdef bytes qid_prefix = b'qid'
    cdef Py_ssize_t n_features
    cdef long long offset_max = offset + length if length > 0 else -1

    # Special-case float32 but use float64 for everything else;
    # the Python code will do further conversions.
    if dtype == np.float32:
        data = array.array("f")
    else:
        dtype = np.float64
        data = array.array("d")

    indices = array.array("q")
    indptr = array.array("q", [0])
    query = np.arange(0, dtype=np.int64)

    if multilabel:
        labels = []
    else:
        labels = array.array("d")

    if offset > 0:
        f.seek(offset)
        # drop the current line that might be truncated and is to be
        # fetched by another call
        f.readline()

    for line in f:
        # skip comments
        line_cstr = line
        hash_ptr = strchr(line_cstr, 35)  # ASCII value of '#' is 35
        if hash_ptr != NULL:
            line = line[:hash_ptr - line_cstr]

        line_parts = line.split()
        if len(line_parts) == 0:
            # A blank or comment-only line may be the one that crosses the
            # end of this chunk. The call that starts at ``offset_max`` will
            # discard the rest of that line and read the following row, so
            # this call must stop here or that row would be loaded twice.
            if offset_max != -1 and f.tell() > offset_max:
                break
            continue

        target, features = line_parts[0], line_parts[1:]
        if multilabel:
            if COLON in target:
                # No label token: the first token is already a feature.
                target, features = [], line_parts[0:]
            else:
                target = [float(y) for y in target.split(COMMA)]
            target.sort()
            labels.append(tuple(target))
        else:
            array.resize_smart(labels, len(labels) + 1)
            labels[len(labels) - 1] = float(target)

        prev_idx = -1
        n_features = len(features)
        if n_features and features[0].startswith(qid_prefix):
            _, value = features[0].split(COLON, 1)
            if query_id:
                query.resize(len(query) + 1)
                query[len(query) - 1] = np.int64(value)
            features.pop(0)
            n_features -= 1

        for i in range(0, n_features):
            idx_s, value = features[i].split(COLON, 1)
            idx = int(idx_s)
            if idx < 0 or not zero_based and idx == 0:
                raise ValueError(
                    "Invalid index %d in SVMlight/LibSVM data file." % idx)
            if idx <= prev_idx:
                raise ValueError("Feature indices in SVMlight/LibSVM data "
                                 "file should be sorted and unique.")

            array.resize_smart(indices, len(indices) + 1)
            indices[len(indices) - 1] = idx

            array.resize_smart(data, len(data) + 1)
            data[len(data) - 1] = float(value)

            prev_idx = idx

        # increment index pointer array size
        array.resize_smart(indptr, len(indptr) + 1)
        indptr[len(indptr) - 1] = len(data)

        if offset_max != -1 and f.tell() > offset_max:
            # Stop here and let another call deal with the following.
            break

    return (dtype, data, indices, indptr, labels, query)
# Three fused types are defined so that the functions below can be
# specialized for all the combinations of argument types they accept:
# - int_or_float: element type of the X values (dense 2-D rows in
#   get_dense_row_string, 1-D data buffer in get_sparse_row_string);
# - double_or_longlong: element type of the x_vals scratch buffer of
#   get_dense_row_string;
# - int_or_longlong: element type of the query_id buffer of
#   _dump_svmlight_file.
ctypedef fused int_or_float:
cython.integral
cython.floating
signed long long
ctypedef fused double_or_longlong:
double
signed long long
ctypedef fused int_or_longlong:
cython.integral
signed long long
def get_dense_row_string(
    int_or_float[:,:] X,
    Py_ssize_t[:] x_inds,
    double_or_longlong[:] x_vals,
    Py_ssize_t row,
    str value_pattern,
    bint one_based,
):
    """Format the non-zero entries of one dense row as a single string.

    The column index and the value of every entry of ``X[row]`` that is not
    equal to 0 are first packed at the front of the caller-provided buffers
    ``x_inds`` and ``x_vals`` (which are overwritten), then each pair is
    rendered with ``value_pattern``. ``one_based`` is added to the column
    index before formatting. The rendered pairs are joined with one space.
    """
    cdef:
        Py_ssize_t n_cols = X.shape[1]
        Py_ssize_t n_kept = 0
        Py_ssize_t col
        Py_ssize_t pos
        int_or_float entry

    # Pass 1: compact the non-zero entries into the scratch buffers.
    for col in range(n_cols):
        entry = X[row, col]
        if entry != 0:
            x_inds[n_kept] = col
            x_vals[n_kept] = <double_or_longlong> entry
            n_kept += 1

    # Pass 2: render the compacted (index, value) pairs.
    pieces = []
    for pos in range(n_kept):
        pieces.append(value_pattern % (x_inds[pos] + one_based, x_vals[pos]))
    return " ".join(pieces)
def get_sparse_row_string(
    int_or_float[:] X_data,
    int[:] X_indptr,
    int[:] X_indices,
    Py_ssize_t row,
    str value_pattern,
    bint one_based,
):
    """Format the stored entries of one sparse row as a single string.

    The entries of ``row`` are those at positions
    ``X_indptr[row]`` (inclusive) to ``X_indptr[row + 1]`` (exclusive) of
    ``X_indices`` and ``X_data``. Each (index, value) pair is rendered with
    ``value_pattern`` after ``one_based`` has been added to the index, and
    the rendered pairs are joined with one space.
    """
    cdef:
        Py_ssize_t begin = X_indptr[row]
        Py_ssize_t stop = X_indptr[row+1]
        Py_ssize_t pos

    pieces = []
    for pos in range(begin, stop):
        pieces.append(value_pattern % (X_indices[pos] + one_based, X_data[pos]))
    return " ".join(pieces)
def _dump_svmlight_file(
    X,
    y,
    f,
    bint multilabel,
    bint one_based,
    int_or_longlong[:] query_id,
    bint X_is_sp,
    bint y_is_sp,
):
    """Write the rows of ``X`` with their labels to ``f`` in SVMlight format.

    One line ``<labels>[ qid:<id>] <index>:<value> ...`` is written per row
    of ``X``, encoded as UTF-8 bytes.

    Parameters
    ----------
    X : array-like with ``dtype`` and 2-D ``shape``
        Feature matrix. When ``X_is_sp`` is true its ``data``, ``indptr`` and
        ``indices`` attributes are read row by row; otherwise it is indexed
        as a dense 2-D buffer.
    y : array-like with ``dtype`` and 2-D ``shape``
        Labels. When ``y_is_sp`` is true its ``data``, ``indptr`` and
        ``indices`` attributes are read row by row; otherwise it is indexed
        as ``y[i, j]``.
    f : file-like object
        Must accept ``bytes`` in ``write``.
    multilabel : bool
        If true, the label field is the comma separated list of the column
        indices of the non-zero entries of row ``i`` of ``y``. Otherwise it
        is the single value of row ``i``.
    one_based : bool
        If true, feature indices are written starting at 1.
    query_id : 1-D integer buffer
        If not empty, ``query_id[i]`` is written as ``qid:<id>`` on line i.
    X_is_sp, y_is_sp : bool
        Whether ``X`` / ``y`` are sparse (see above).
    """
    cdef bint X_is_integral
    cdef bint query_id_is_not_empty = query_id.size > 0
    X_is_integral = X.dtype.kind == "i"
    if X_is_integral:
        value_pattern = "%d:%d"
    else:
        value_pattern = "%d:%.16g"
    if y.dtype.kind == "i":
        label_pattern = "%d"
    else:
        label_pattern = "%.16g"

    line_pattern = "%s"
    if query_id_is_not_empty:
        line_pattern += " qid:%d"
    line_pattern += " %s\n"

    cdef:
        Py_ssize_t num_labels = y.shape[1]
        Py_ssize_t x_len = X.shape[0]
        Py_ssize_t row_length = X.shape[1]
        Py_ssize_t i
        Py_ssize_t j
        Py_ssize_t col_start
        Py_ssize_t col_end
        # Scratch buffers reused for every dense row.
        Py_ssize_t[:] x_inds = np.empty(row_length, dtype=np.intp)
        signed long long[:] x_vals_int
        double[:] x_vals_float

    if not X_is_sp:
        if X_is_integral:
            x_vals_int = np.zeros(row_length, dtype=np.longlong)
        else:
            x_vals_float = np.zeros(row_length, dtype=np.float64)

    for i in range(x_len):
        if not X_is_sp:
            if X_is_integral:
                s = get_dense_row_string(X, x_inds, x_vals_int, i, value_pattern, one_based)
            else:
                s = get_dense_row_string(X, x_inds, x_vals_float, i, value_pattern, one_based)
        else:
            s = get_sparse_row_string(X.data, X.indptr, X.indices, i, value_pattern, one_based)

        if multilabel:
            if y_is_sp:
                col_start = y.indptr[i]
                col_end = y.indptr[i+1]
                labels_str = ','.join(tuple(label_pattern % y.indices[j] for j in range(col_start, col_end) if y.data[j] != 0))
            else:
                labels_str = ','.join(label_pattern % j for j in range(num_labels) if y[i,j] != 0)
        else:
            if y_is_sp:
                # ``y.data`` only holds the stored entries, so position i of
                # ``data`` is row i only when every row stores its value.
                # Locate the row through ``indptr``; a row without a stored
                # entry has the implicit label 0.
                col_start = y.indptr[i]
                col_end = y.indptr[i+1]
                if col_end > col_start:
                    labels_str = label_pattern % y.data[col_start]
                else:
                    labels_str = label_pattern % 0
            else:
                labels_str = label_pattern % y[i,0]

        if query_id_is_not_empty:
            feat = (labels_str, query_id[i], s)
        else:
            feat = (labels_str, s)

        f.write((line_pattern % feat).encode("utf-8"))
|