File: casting.py

package info (click to toggle)
pypy3 7.3.19%2Bdfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 212,236 kB
  • sloc: python: 2,098,316; ansic: 540,565; sh: 21,462; asm: 14,419; cpp: 4,451; makefile: 4,209; objc: 761; xml: 530; exp: 499; javascript: 314; pascal: 244; lisp: 45; csh: 12; awk: 4
file content (455 lines) | stat: -rw-r--r-- 16,274 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
"""Functions and helpers for converting between dtypes"""

from rpython.rlib import jit, objectmodel
from rpython.rlib.signature import signature, types as ann
from pypy.interpreter.gateway import unwrap_spec
from pypy.interpreter.error import OperationError, oefmt

from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
from pypy.module.micronumpy import constants as NPY
from .types import (
    BaseType, Bool, ULong, Long, Float64, Complex64,
    StringType, UnicodeType, VoidType, ObjectType,
    int_types, float_types, complex_types, number_types, all_types)
from .descriptor import (
    W_Dtype, get_dtype_cache, as_dtype, is_scalar_w, variable_dtype,
    new_string_dtype, new_unicode_dtype, num2dtype)

@jit.unroll_safe
def result_type(space, __args__):
    """Implement app-level ``numpy.result_type(*arrays_and_dtypes)``.

    Each positional argument goes into one of two buckets:
    * arrays: W_NDimArray instances, plus app-level scalars boxed via
      as_scalar() and wrapped with W_NDimArray.from_scalar();
    * dtypes: everything else, converted through as_dtype().
    Both buckets are handed to find_result_type().

    Raises an app-level TypeError if keyword arguments are given and a
    ValueError if no argument is given at all.
    """
    args_w, kw_w = __args__.unpack()
    if kw_w:
        raise oefmt(space.w_TypeError,
            "result_type() takes no keyword arguments")
    if not args_w:
        raise oefmt(space.w_ValueError,
            "at least one array or dtype is required")
    collected_arrays_w = []
    collected_dtypes_w = []
    for w_item in args_w:
        if isinstance(w_item, W_NDimArray):
            collected_arrays_w.append(w_item)
            continue
        if is_scalar_w(space, w_item):
            w_boxed = as_scalar(space, w_item)
            collected_arrays_w.append(W_NDimArray.from_scalar(space, w_boxed))
            continue
        collected_dtypes_w.append(as_dtype(space, w_item))
    return find_result_type(space, collected_arrays_w, collected_dtypes_w)

# Let the JIT trace into this function only when both input lists pass
# loop_unrolling_heuristic() (presumably: short enough to unroll).
@jit.look_inside_iff(lambda space, arrays_w, dtypes_w:
    jit.loop_unrolling_heuristic(arrays_w, len(arrays_w)) and
    jit.loop_unrolling_heuristic(dtypes_w, len(dtypes_w)))
def find_result_type(space, arrays_w, dtypes_w):
    """Return the dtype that results from combining all inputs.

    ``arrays_w`` is a list of W_NDimArray (including 0-d arrays wrapping
    scalars) and ``dtypes_w`` a list of W_Dtype.

    If _use_min_scalar() returns True, every numeric scalar array
    contributes the dtype named by the first entry of its min_dtype()
    result instead of its own dtype. A "small unsigned" flag is then
    threaded through _promote_types_su(). Otherwise all dtypes are simply
    folded together with promote_types().

    Returns None when both lists are empty.
    """
    # equivalent to PyArray_ResultType
    # Fast paths: a single array or a single dtype is its own result.
    if len(arrays_w) == 1 and not dtypes_w:
        return arrays_w[0].get_dtype()
    elif not arrays_w and len(dtypes_w) == 1:
        return dtypes_w[0]
    result = None
    if not _use_min_scalar(arrays_w, dtypes_w):
        # Plain type-based promotion, arrays first, then dtypes.
        for w_array in arrays_w:
            if result is None:
                result = w_array.get_dtype()
            else:
                result = promote_types(space, result, w_array.get_dtype())
        for dtype in dtypes_w:
            if result is None:
                result = dtype
            else:
                result = promote_types(space, result, dtype)
    else:
        # Value-based promotion: numeric scalars are shrunk to the
        # smallest dtype reported by min_dtype().
        small_unsigned = False
        for w_array in arrays_w:
            dtype = w_array.get_dtype()
            small_unsigned_scalar = False
            if w_array.is_scalar() and dtype.is_number():
                num, alt_num = w_array.get_scalar_value().min_dtype()
                # min_dtype() reports two candidate type numbers; when
                # they differ the scalar is flagged "small unsigned" so
                # _promote_types_su() may switch its signedness.
                small_unsigned_scalar = (num != alt_num)
                dtype = num2dtype(space, num)
            if result is None:
                result = dtype
                small_unsigned = small_unsigned_scalar
            else:
                result, small_unsigned = _promote_types_su(
                    space, result, dtype,
                    small_unsigned, small_unsigned_scalar)
        # Explicit dtypes never carry the small-unsigned flag.
        for dtype in dtypes_w:
            if result is None:
                result = dtype
                small_unsigned = False
            else:
                result, small_unsigned = _promote_types_su(
                    space, result, dtype,
                    small_unsigned, False)
    return result

# Coarse kind classes used by _use_min_scalar():
#   0 = bool, 1 = integer (signed or unsigned), 2 = float/complex,
#   3 = string/unicode/void/object.
# objectmodel.dict_to_switch() turns the dict into a function that is
# called as simple_kind_ordering(kind). Kinds not listed here are
# presumably rejected with KeyError -- TODO confirm in rlib.objectmodel.
simple_kind_ordering = objectmodel.dict_to_switch({
    Bool.kind: 0, ULong.kind: 1, Long.kind: 1,
    Float64.kind: 2, Complex64.kind: 2,
    NPY.STRINGLTR: 3, NPY.STRINGLTR2: 3,
    UnicodeType.kind: 3, VoidType.kind: 3, ObjectType.kind: 3})

# this is safe to unroll since it'll only be seen if we look inside
# the find_result_type
@jit.unroll_safe
def _use_min_scalar(arrays_w, dtypes_w):
    """Helper for find_result_type()

    Decide whether value-based ("minimal scalar") typing applies.
    All of the following must hold:
    * there is at least one array;
    * not every input is a scalar (explicit dtypes count as non-scalars);
    * the highest coarse kind (see simple_kind_ordering) among the
      non-scalar inputs is >= the highest kind among the scalars.
    """
    if not arrays_w:
        return False
    saw_non_scalar = False
    scalar_rank = 0
    array_rank = 0
    for w_array in arrays_w:
        scalar = w_array.is_scalar()
        rank = simple_kind_ordering(w_array.get_dtype().kind)
        if scalar:
            if rank > scalar_rank:
                scalar_rank = rank
        else:
            saw_non_scalar = True
            if rank > array_rank:
                array_rank = rank
    for dtype in dtypes_w:
        saw_non_scalar = True
        rank = simple_kind_ordering(dtype.kind)
        if rank > array_rank:
            array_rank = rank
    return saw_non_scalar and array_rank >= scalar_rank


@unwrap_spec(casting='text')
def can_cast(space, w_from, w_totype, casting='safe'):
    """Implement app-level ``numpy.can_cast(from_, to, casting='safe')``.

    ``w_totype`` is converted with as_dtype(); None is rejected with an
    app-level TypeError.

    ``w_from`` may be an array, an app-level scalar, or anything else
    as_dtype() accepts. A scalar is boxed via as_scalar() and wrapped in
    an array so that can_cast_array()'s value-based rules apply.

    Returns a wrapped bool.
    """
    try:
        target = as_dtype(space, w_totype, allow_None=False)
    except TypeError:
        raise oefmt(space.w_TypeError,
            "did not understand one of the types; 'None' not accepted")

    if isinstance(w_from, W_NDimArray):
        w_source = w_from
    elif is_scalar_w(space, w_from):
        w_source = W_NDimArray.from_scalar(space, as_scalar(space, w_from))
    else:
        # Neither an array nor a scalar: compare the two dtypes directly.
        try:
            origin = as_dtype(space, w_from, allow_None=False)
        except TypeError:
            raise oefmt(space.w_TypeError,
                "did not understand one of the types; 'None' not accepted")
        return space.newbool(can_cast_type(space, origin, target, casting))
    return space.newbool(can_cast_array(space, w_source, target, casting))

# Kind ranks used by can_cast_type() for 'same_kind' casting between
# different type numbers, when can_cast_to() has not already allowed the
# cast. The cast is accepted when the target kind's rank is >= the
# origin kind's rank:
#   bool < unsigned < signed < float < complex < string < unicode
#   < void < object
# Kinds missing from this dict make such a cast return False.
kind_ordering = {
    Bool.kind: 0, ULong.kind: 1, Long.kind: 2,
    Float64.kind: 4, Complex64.kind: 5,
    NPY.STRINGLTR: 6, NPY.STRINGLTR2: 6,
    UnicodeType.kind: 7, VoidType.kind: 8, ObjectType.kind: 9}

def can_cast_type(space, origin, target, casting):
    """Return whether dtype ``origin`` can be cast to dtype ``target``.

    ``casting`` is the casting rule passed through from can_cast()
    ('no', 'equiv', 'safe', 'same_kind' or 'unsafe').
    """
    # equivalent to PyArray_CanCastTypeTo
    if origin == target:
        return True
    if casting == 'unsafe':
        return True
    elif casting == 'no':
        # Exact equality as defined by W_Dtype.eq().
        return origin.eq(space, target)
    if origin.num == target.num:
        # Same type number: only the element size (or, for records, the
        # fields) can differ.
        if origin.is_record():
            return (target.is_record() and
                    can_cast_record(space, origin, target, casting))
        else:
            if casting == 'equiv':
                return origin.elsize == target.elsize
            elif casting == 'safe':
                # A smaller element size fits into a larger one.
                return origin.elsize <= target.elsize
            else:
                # 'same_kind' accepts any element size.
                return True

    elif casting == 'same_kind':
        if can_cast_to(origin, target):
            return True
        # Otherwise allow casts towards a kind of equal or higher rank.
        if origin.kind in kind_ordering and target.kind in kind_ordering:
            return kind_ordering[origin.kind] <= kind_ordering[target.kind]
        return False
    elif casting == 'safe':
        return can_cast_to(origin, target)
    else:  # 'equiv'
        # The type numbers differ on this path, so this evaluates to False.
        return origin.num == target.num and origin.elsize == target.elsize

def can_cast_record(space, origin, target, casting):
    """Return whether record dtype ``origin`` can be cast to ``target``.

    Both dtypes must have ``fields`` mappings with the same number of
    entries. Every field name of ``origin`` must exist in ``target``, and
    each field dtype must be castable under ``casting``; the field
    dtype is the second item of each ``fields`` entry. Identical objects
    are always castable.
    """
    if origin is target:
        return True
    src_fields = origin.fields
    dst_fields = target.fields
    if src_fields is None or dst_fields is None:
        return False
    if len(src_fields) != len(dst_fields):
        return False
    for field_name, src_entry in src_fields.iteritems():
        if field_name not in dst_fields:
            return False
        dst_field = dst_fields[field_name][1]
        if not can_cast_type(space, src_entry[1], dst_field, casting):
            return False
    return True


def can_cast_array(space, w_from, target, casting):
    """Return whether array ``w_from`` can be cast to dtype ``target``.

    Equivalent to PyArray_CanCastArrayTo. If ``w_from.is_scalar()``, the
    array is judged by its value via can_cast_scalar(); otherwise it is
    judged only by its dtype via can_cast_type().
    """
    origin = w_from.get_dtype()
    if not w_from.is_scalar():
        return can_cast_type(space, origin, target, casting)
    scalar_value = w_from.get_scalar_value()
    return can_cast_scalar(space, origin, scalar_value, target, casting)

def can_cast_scalar(space, from_type, value, target, casting):
    """Return whether scalar ``value`` of dtype ``from_type`` can be cast
    to dtype ``target`` under the rule ``casting``.

    Equivalent to CNumPy's can_cast_scalar_to. Non-numeric scalars, and
    the strict 'no' / 'equiv' rules, fall back to the type-based
    can_cast_type(). Otherwise the dtype named by value.min_dtype() is
    used: the second entry for unsigned targets, the first entry for all
    other targets.
    """
    if from_type == target or casting == 'unsafe':
        return True
    use_type_rules = not from_type.is_number() or casting in ('no', 'equiv')
    if use_type_rules:
        return can_cast_type(space, from_type, target, casting)
    # min_dtype() inspects the value itself, so bring it into native
    # byte order first.
    if from_type.is_native():
        scalar = value
    else:
        scalar = value.descr_byteswap(space)
    smallest, smallest_alt = scalar.min_dtype()
    if target.is_unsigned():
        chosen = smallest_alt
    else:
        chosen = smallest
    return can_cast_type(space, num2dtype(space, chosen), target, casting)

def as_scalar(space, w_obj):
    """Coerce the app-level object ``w_obj`` into a box of the dtype that
    scalar2dtype() selects for it."""
    return scalar2dtype(space, w_obj).coerce(space, w_obj)

def min_scalar_type(space, w_a):
    """Implement app-level ``numpy.min_scalar_type(a)``.

    ``w_a`` is first converted with convert_to_array(). For a numeric
    scalar, return the dtype named by the first entry of its min_dtype()
    result. For anything else, return the array's dtype unchanged.
    """
    w_array = convert_to_array(space, w_a)
    dtype = w_array.get_dtype()
    if not (w_array.is_scalar() and dtype.is_number()):
        return dtype
    smallest_num = w_array.get_scalar_value().min_dtype()[0]
    return num2dtype(space, smallest_num)

def w_promote_types(space, w_type1, w_type2):
    """Implement app-level ``numpy.promote_types(type1, type2)``.

    Both arguments are converted with as_dtype(allow_None=False) and the
    result of promote_types() is returned.

    as_dtype() signals a None argument with an interp-level TypeError.
    That error is turned into an app-level TypeError here, the same way
    can_cast() does, instead of escaping to the interpreter as an
    internal error.
    """
    try:
        dt1 = as_dtype(space, w_type1, allow_None=False)
        dt2 = as_dtype(space, w_type2, allow_None=False)
    except TypeError:
        raise oefmt(space.w_TypeError,
            "did not understand one of the types; 'None' not accepted")
    return promote_types(space, dt1, dt2)

def find_binop_result_dtype(space, dt1, dt2):
    """Return the common dtype of a binary operation's operand dtypes.

    Either operand may be None. In that case the other one is returned
    unchanged (None when both are None); otherwise promote_types()
    decides.
    """
    if dt1 is not None and dt2 is not None:
        return promote_types(space, dt1, dt2)
    if dt2 is None:
        return dt1
    return dt2

def _flexible_rank(dt):
    """Rank used by promote_types() to order a dtype pair: 3 for void,
    2 for unicode, 1 for string and 0 for every other dtype."""
    if dt.num == NPY.VOID:
        return 3
    if dt.is_unicode():
        return 2
    if dt.is_str():
        return 1
    return 0

def promote_types(space, dt1, dt2):
    """Return the smallest dtype to which both input dtypes can be safely cast

    Equivalent to PyArray_PromoteTypes. Pairs covered by promotion_table
    are answered from the table. Otherwise a flexible dtype (string,
    unicode or void) is normally involved, and the result follows the
    element-size rules below.

    Raises an app-level TypeError ("invalid type promotion") when no
    common dtype exists.
    """
    num = promotion_table[dt1.num][dt2.num]
    if num != -1:
        return num2dtype(space, num)

    # Put the more flexible dtype in dt2 (void > unicode > string >
    # anything else). Ordering by dtype num is not enough: in numpy's
    # type-number enum, NPY.HALF and the datetime types come after
    # NPY.STRING, NPY.UNICODE and NPY.VOID.
    if _flexible_rank(dt1) > _flexible_rank(dt2):
        dt1, dt2 = dt2, dt1

    if dt2.is_str():
        if dt1.is_str():
            if dt1.elsize > dt2.elsize:
                return dt1
            else:
                return dt2
        elif dt1.is_bool() or dt1.is_number():
            # The string must be long enough for the number's text form.
            dt1_size = dt1.itemtype.strlen
            if dt1_size > dt2.elsize:
                return new_string_dtype(space, dt1_size)
            else:
                return dt2
    elif dt2.is_unicode():
        if dt1.is_unicode():
            if dt1.elsize > dt2.elsize:
                return dt1
            else:
                return dt2
        elif dt1.is_str():
            # Unicode uses 4 bytes per character; new_unicode_dtype()
            # takes a character count.
            if dt2.elsize >= 4 * dt1.elsize:
                return dt2
            else:
                return new_unicode_dtype(space, dt1.elsize)
        elif dt1.is_bool() or dt1.is_number():
            dt1_size = dt1.itemtype.strlen
            if 4 * dt1_size > dt2.elsize:
                return new_unicode_dtype(space, dt1_size)
            else:
                return dt2
    elif can_cast_type(space, dt1, dt2, casting='equiv'):
        # dt2 is void, or neither dtype is flexible: only equivalent
        # dtypes have a common type here.
        return dt1
    raise oefmt(space.w_TypeError, "invalid type promotion")

def _promote_types_su(space, dt1, dt2, su1, su2):
    """Like promote_types(), but handles the small_unsigned flag as well

    ``su1`` / ``su2`` say whether ``dt1`` / ``dt2`` were flagged "small
    unsigned" by find_result_type(), i.e. min_dtype() returned two
    different candidate type numbers.

    A flagged dtype is switched to its unsigned variant when the other
    operand is bool or unsigned, and to its signed variant otherwise.
    Only one operand is adjusted, and ``su1`` wins.

    Returns a ``(dtype, small_unsigned)`` tuple. The result's flag is
    derived from the element sizes and signedness of the (adjusted)
    operands.
    """
    if su1:
        if dt2.is_bool() or dt2.is_unsigned():
            dt1 = dt1.as_unsigned(space)
        else:
            dt1 = dt1.as_signed(space)
    elif su2:
        if dt1.is_bool() or dt1.is_unsigned():
            dt2 = dt2.as_unsigned(space)
        else:
            dt2 = dt2.as_signed(space)
    # The larger operand decides whether the result stays "small
    # unsigned"; a smaller signed operand clears the flag.
    if dt1.elsize < dt2.elsize:
        su = su2 and (su1 or not dt1.is_signed())
    elif dt1.elsize == dt2.elsize:
        su = su1 and su2
    else:
        su = su1 and (su2 or not dt2.is_signed())
    return promote_types(space, dt1, dt2), su

def scalar2dtype(space, w_obj):
    """Return the dtype used for the app-level scalar ``w_obj``.

    Mapping:
    * boxes (W_GenericBox) -> their own dtype;
    * bool -> bool;
    * int -> int64, or uint64 for positive values too large for
      space.int_w();
    * float -> float64;
    * complex -> complex128;
    * bytes -> 'S<len>';
    * unicode -> unicode of the same length;
    * anything else -> object.
    """
    from .boxes import W_GenericBox
    # Look the per-space dtype cache up once instead of once per dtype.
    cache = get_dtype_cache(space)
    if isinstance(w_obj, W_GenericBox):
        return w_obj.get_dtype(space)

    if space.isinstance_w(w_obj, space.w_bool):
        return cache.w_booldtype
    elif space.isinstance_w(w_obj, space.w_int):
        try:
            space.int_w(w_obj)
        except OperationError as e:
            if e.match(space, space.w_OverflowError):
                # Too large for int_w(): non-positive values stay int64,
                # positive ones are given uint64.
                if space.is_true(space.le(w_obj, space.newint(0))):
                    return cache.w_int64dtype
                return cache.w_uint64dtype
            raise
        return cache.w_int64dtype
    elif space.isinstance_w(w_obj, space.w_float):
        return cache.w_float64dtype
    elif space.isinstance_w(w_obj, space.w_complex):
        return cache.w_complex128dtype
    elif space.isinstance_w(w_obj, space.w_bytes):
        return variable_dtype(space, 'S%d' % space.len_w(w_obj))
    elif space.isinstance_w(w_obj, space.w_unicode):
        return new_unicode_dtype(space, space.len_w(w_obj))
    return cache.w_objectdtype

@signature(ann.instance(W_Dtype), ann.instance(W_Dtype), returns=ann.bool())
def can_cast_to(dt1, dt2):
    """Return whether dtype `dt1` can be cast safely to `dt2`

    Equivalent to PyArray_CanCastTo. The item-type answer comes from
    can_cast_itemtype(). When that allows the cast, string and unicode
    targets must additionally be long enough:
    * string -> string and unicode -> unicode: target at least as long;
    * string -> unicode: 4 bytes of target per source byte;
    * integer -> string/unicode: room for ``itemtype.strlen``
      characters, unless the target's elsize is 0.
    """
    if not can_cast_itemtype(dt1.itemtype, dt2.itemtype):
        return False
    if dt1.num == NPY.STRING:
        if dt2.num == NPY.STRING:
            return dt1.elsize <= dt2.elsize
        elif dt2.num == NPY.UNICODE:
            return dt1.elsize * 4 <= dt2.elsize
    elif dt1.num == NPY.UNICODE and dt2.num == NPY.UNICODE:
        return dt1.elsize <= dt2.elsize
    elif dt2.num in (NPY.STRING, NPY.UNICODE):
        if dt2.elsize == 0:
            return True
        if dt1.is_int():
            if dt2.num == NPY.STRING:
                char_size = 1
            else:  # NPY.UNICODE
                char_size = 4
            return dt2.elsize >= dt1.itemtype.strlen * char_size
    return True


@signature(ann.instance(BaseType), ann.instance(BaseType), returns=ann.bool())
def can_cast_itemtype(tp1, tp2):
    """Return whether item type ``tp1`` can be cast safely to ``tp2``.

    Equivalent to PyArray_CanCastSafely; the answer is read from the
    precomputed casting_table.
    """
    allowed_from_tp1 = casting_table[tp1.num]
    return allowed_from_tp1[tp2.num]

#_________________________


# casting_table[a][b] is True when type number `a` can be cast safely to
# type number `b`. All entries start as False and are switched on with
# enable_cast().
casting_table = [[False for _ in range(NPY.NTYPES)]
                 for _ in range(NPY.NTYPES)]

def enable_cast(type1, type2):
    """Record in casting_table that ``type1`` casts safely to ``type2``."""
    row = casting_table[type1.num]
    row[type2.num] = True

def _can_cast(type1, type2):
    """NOT_RPYTHON: operates on BaseType subclasses"""
    row = casting_table[type1.num]
    return row[type2.num]

# Populate casting_table with the "safe" casts between item types.

# Every type casts to itself, to object and to void. Bool casts to every
# type except datetime. String casts to unicode.
for tp in all_types:
    enable_cast(tp, tp)
    if tp.num != NPY.DATETIME:
        enable_cast(Bool, tp)
    enable_cast(tp, ObjectType)
    enable_cast(tp, VoidType)
enable_cast(StringType, UnicodeType)
#enable_cast(Bool, TimeDelta)

# Every number casts to string and unicode at this level. can_cast_to()
# additionally checks the target length for integer sources.
for tp in number_types:
    enable_cast(tp, StringType)
    enable_cast(tp, UnicodeType)

# Integer -> integer:
# * signed sources only to signed targets of at least the same size;
# * unsigned sources to signed targets of strictly larger size, or to
#   unsigned targets of at least the same size.
for tp1 in int_types:
    for tp2 in int_types:
        if tp1.signed:
            if tp2.signed and tp1.basesize() <= tp2.basesize():
                enable_cast(tp1, tp2)
        else:
            if tp2.signed and tp1.basesize() < tp2.basesize():
                enable_cast(tp1, tp2)
            elif not tp2.signed and tp1.basesize() <= tp2.basesize():
                enable_cast(tp1, tp2)
# Integer -> float/complex: integers with basesize < 8 need a strictly
# larger target basesize; larger integers accept a target of at least
# the same basesize. How basesize() is defined for complex types is not
# visible here -- NOTE(review): confirm it against types.py.
for tp1 in int_types:
    for tp2 in float_types + complex_types:
        size1 = tp1.basesize()
        size2 = tp2.basesize()
        if (size1 < 8 and size2 > size1) or (size1 >= 8 and size2 >= size1):
            enable_cast(tp1, tp2)
# Float -> float/complex and complex -> complex: target basesize must be
# at least the source basesize.
for tp1 in float_types:
    for tp2 in float_types + complex_types:
        if tp1.basesize() <= tp2.basesize():
            enable_cast(tp1, tp2)
for tp1 in complex_types:
    for tp2 in complex_types:
        if tp1.basesize() <= tp2.basesize():
            enable_cast(tp1, tp2)

# promotion_table[a][b] holds the type number that type numbers `a` and
# `b` promote to, or -1 when there is no entry. In that case
# promote_types() falls back to its flexible-type rules.
promotion_table = [[-1 for _ in range(NPY.NTYPES)]
                   for _ in range(NPY.NTYPES)]
def promotes(tp1, tp2, tp3):
    """Record that ``tp1`` and ``tp2`` promote to ``tp3`` (None records -1)."""
    if tp3 is not None:
        result_num = tp3.num
    else:
        result_num = -1
    promotion_table[tp1.num][tp2.num] = result_num


# Populate promotion_table.

# Any type combined with object promotes to object.
for tp in all_types:
    promotes(tp, ObjectType, ObjectType)
    promotes(ObjectType, tp, ObjectType)

# Bool/number pairs: if one type casts safely to the other, the pair
# promotes to the latter. Otherwise search number_types for the least
# type (under safe casting) that both cast to safely. None (stored as -1)
# is recorded when there is no such type.
for tp1 in [Bool] + number_types:
    for tp2 in [Bool] + number_types:
        if tp1 is tp2:
            promotes(tp1, tp1, tp1)
        elif _can_cast(tp1, tp2):
            promotes(tp1, tp2, tp2)
        elif _can_cast(tp2, tp1):
            promotes(tp1, tp2, tp1)
        else:
            # Brute-force search for the least upper bound
            result = None
            for tp3 in number_types:
                if _can_cast(tp1, tp3) and _can_cast(tp2, tp3):
                    if result is None:
                        result = tp3
                    # tp3 is strictly "smaller" than the current candidate
                    elif _can_cast(tp3, result) and not _can_cast(result, tp3):
                        result = tp3
            promotes(tp1, tp2, result)