1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
|
/*
* This file is part of the micropython-ulab project,
*
* https://github.com/v923z/micropython-ulab
*
* The MIT License (MIT)
*
* Copyright (c) 2020-2021 Zoltán Vörös
*/
#include <string.h>
#include "py/runtime.h"
#include "ulab.h"
#include "ndarray.h"
#include "ulab_tools.h"
// The following five functions return a float from a void type
// The value in question is supposed to be located at the head of the pointer
mp_float_t ndarray_get_float_uint8(void *data) {
// Returns a float value from an uint8_t type
return (mp_float_t)(*(uint8_t *)data);
}
mp_float_t ndarray_get_float_int8(void *data) {
// Returns a float value from an int8_t type
return (mp_float_t)(*(int8_t *)data);
}
mp_float_t ndarray_get_float_uint16(void *data) {
// Returns a float value from an uint16_t type
return (mp_float_t)(*(uint16_t *)data);
}
mp_float_t ndarray_get_float_int16(void *data) {
// Returns a float value from an int16_t type
return (mp_float_t)(*(int16_t *)data);
}
mp_float_t ndarray_get_float_float(void *data) {
// Returns a float value from an mp_float_t type
return *((mp_float_t *)data);
}
// returns a single function pointer, depending on the dtype
void *ndarray_get_float_function(uint8_t dtype) {
if(dtype == NDARRAY_UINT8) {
return ndarray_get_float_uint8;
} else if(dtype == NDARRAY_INT8) {
return ndarray_get_float_int8;
} else if(dtype == NDARRAY_UINT16) {
return ndarray_get_float_uint16;
} else if(dtype == NDARRAY_INT16) {
return ndarray_get_float_int16;
} else {
return ndarray_get_float_float;
}
}
mp_float_t ndarray_get_float_index(void *data, uint8_t dtype, size_t index) {
// returns a single float value from an array located at index
if(dtype == NDARRAY_UINT8) {
return (mp_float_t)((uint8_t *)data)[index];
} else if(dtype == NDARRAY_INT8) {
return (mp_float_t)((int8_t *)data)[index];
} else if(dtype == NDARRAY_UINT16) {
return (mp_float_t)((uint16_t *)data)[index];
} else if(dtype == NDARRAY_INT16) {
return (mp_float_t)((int16_t *)data)[index];
} else {
return (mp_float_t)((mp_float_t *)data)[index];
}
}
mp_float_t ndarray_get_float_value(void *data, uint8_t dtype) {
// Returns a float value from an arbitrary data type
// The value in question is supposed to be located at the head of the pointer
if(dtype == NDARRAY_UINT8) {
return (mp_float_t)(*(uint8_t *)data);
} else if(dtype == NDARRAY_INT8) {
return (mp_float_t)(*(int8_t *)data);
} else if(dtype == NDARRAY_UINT16) {
return (mp_float_t)(*(uint16_t *)data);
} else if(dtype == NDARRAY_INT16) {
return (mp_float_t)(*(int16_t *)data);
} else {
return *((mp_float_t *)data);
}
}
#if NDARRAY_BINARY_USES_FUN_POINTER | ULAB_NUMPY_HAS_WHERE
uint8_t ndarray_upcast_dtype(uint8_t ldtype, uint8_t rdtype) {
// returns a single character that corresponds to the broadcasting rules
// - if one of the operarands is a float, the result is always float
// - operation on identical types preserves type
//
// uint8 + int8 => int16
// uint8 + int16 => int16
// uint8 + uint16 => uint16
// int8 + int16 => int16
// int8 + uint16 => uint16
// uint16 + int16 => float
if(ldtype == rdtype) {
// if the two dtypes are equal, the result is also of that type
return ldtype;
} else if(((ldtype == NDARRAY_UINT8) && (rdtype == NDARRAY_INT8)) ||
((ldtype == NDARRAY_INT8) && (rdtype == NDARRAY_UINT8)) ||
((ldtype == NDARRAY_UINT8) && (rdtype == NDARRAY_INT16)) ||
((ldtype == NDARRAY_INT16) && (rdtype == NDARRAY_UINT8)) ||
((ldtype == NDARRAY_INT8) && (rdtype == NDARRAY_INT16)) ||
((ldtype == NDARRAY_INT16) && (rdtype == NDARRAY_INT8))) {
return NDARRAY_INT16;
} else if(((ldtype == NDARRAY_UINT8) && (rdtype == NDARRAY_UINT16)) ||
((ldtype == NDARRAY_UINT16) && (rdtype == NDARRAY_UINT8)) ||
((ldtype == NDARRAY_INT8) && (rdtype == NDARRAY_UINT16)) ||
((ldtype == NDARRAY_UINT16) && (rdtype == NDARRAY_INT8))) {
return NDARRAY_UINT16;
}
return NDARRAY_FLOAT;
}
// The following five functions are the inverse of the ndarray_get_... functions,
// and write a floating point datum into a void pointer
void ndarray_set_float_uint8(void *data, mp_float_t datum) {
*((uint8_t *)data) = (uint8_t)datum;
}
void ndarray_set_float_int8(void *data, mp_float_t datum) {
*((int8_t *)data) = (int8_t)datum;
}
void ndarray_set_float_uint16(void *data, mp_float_t datum) {
*((uint16_t *)data) = (uint16_t)datum;
}
void ndarray_set_float_int16(void *data, mp_float_t datum) {
*((int16_t *)data) = (int16_t)datum;
}
void ndarray_set_float_float(void *data, mp_float_t datum) {
*((mp_float_t *)data) = datum;
}
// returns a single function pointer, depending on the dtype
void *ndarray_set_float_function(uint8_t dtype) {
if(dtype == NDARRAY_UINT8) {
return ndarray_set_float_uint8;
} else if(dtype == NDARRAY_INT8) {
return ndarray_set_float_int8;
} else if(dtype == NDARRAY_UINT16) {
return ndarray_set_float_uint16;
} else if(dtype == NDARRAY_INT16) {
return ndarray_set_float_int16;
} else {
return ndarray_set_float_float;
}
}
#endif /* NDARRAY_BINARY_USES_FUN_POINTER */
shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
// TODO: replace numerical_reduce_axes with this function, wherever applicable
// This function should be used, whenever a tensor is contracted;
// The shape and strides at `axis` are moved to the zeroth position,
// everything else is aligned to the right
if(!mp_obj_is_int(axis) & (axis != mp_const_none)) {
mp_raise_TypeError(translate("axis must be None, or an integer"));
}
shape_strides _shape_strides;
size_t *shape = m_new(size_t, ULAB_MAX_DIMS + 1);
_shape_strides.shape = shape;
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS + 1);
_shape_strides.strides = strides;
_shape_strides.increment = 0;
// this is the contracted dimension (won't be overwritten for axis == None)
_shape_strides.ndim = 0;
memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);
if(axis == mp_const_none) {
return _shape_strides;
}
uint8_t index = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten)
if(axis != mp_const_none) { // i.e., axis is an integer
int8_t ax = mp_obj_get_int(axis);
if(ax < 0) ax += ndarray->ndim;
if((ax < 0) || (ax > ndarray->ndim - 1)) {
mp_raise_ValueError(translate("index out of range"));
}
index = ULAB_MAX_DIMS - ndarray->ndim + ax;
_shape_strides.ndim = ndarray->ndim - 1;
}
// move the value stored at index to the leftmost position, and align everything else to the right
_shape_strides.shape[0] = ndarray->shape[index];
_shape_strides.strides[0] = ndarray->strides[index];
for(uint8_t i = 0; i < index; i++) {
// entries to the right of index must be shifted by one position to the left
_shape_strides.shape[i + 1] = ndarray->shape[i];
_shape_strides.strides[i + 1] = ndarray->strides[i];
}
if(_shape_strides.ndim != 0) {
_shape_strides.increment = 1;
}
return _shape_strides;
}
#if ULAB_MAX_DIMS > 1
ndarray_obj_t *tools_object_is_square(mp_obj_t obj) {
// Returns an ndarray, if the object is a square ndarray,
// raises the appropriate exception otherwise
if(!mp_obj_is_type(obj, &ulab_ndarray_type)) {
mp_raise_TypeError(translate("size is defined for ndarrays only"));
}
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(obj);
if((ndarray->shape[ULAB_MAX_DIMS - 1] != ndarray->shape[ULAB_MAX_DIMS - 2]) || (ndarray->ndim != 2)) {
mp_raise_ValueError(translate("input must be square matrix"));
}
return ndarray;
}
#endif
|