1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
|
from __future__ import annotations
# Public names of this typing module: the re-exported array and device
# classes plus the ``Dtype`` annotation alias defined further down.
__all__ = ["ndarray", "Device", "Dtype"]
import sys
from typing import (
Union,
TYPE_CHECKING,
)
from cupy import (
ndarray,
dtype,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
float32,
float64,
)
from cupy.cuda.device import Device
# ``Dtype``: annotation alias for dtypes -- ``dtype`` parametrized over the
# signed/unsigned integer and float32/float64 scalar types listed here.
# The Union is private; only ``Dtype`` is meant to be used by callers.
_DtypeScalar = Union[
    int8,
    int16,
    int32,
    int64,
    uint8,
    uint16,
    uint32,
    uint64,
    float32,
    float64,
]

if TYPE_CHECKING:
    # Static type checkers treat TYPE_CHECKING as True, so they only ever see
    # the parametrized alias.  Keeping the runtime fallback in the ``else``
    # branch stops them from flagging two incompatible assignments to Dtype.
    Dtype = dtype[_DtypeScalar]
else:
    # Subscripting ``dtype`` at runtime needs Python >= 3.9 *and* a NumPy whose
    # ``dtype`` implements ``__class_getitem__`` (NumPy >= 1.22; cupy
    # re-exports NumPy's ``dtype``).  A Python-version check alone misses the
    # NumPy requirement and would crash the import, so attempt the subscript
    # and fall back to the unparametrized class if it raises TypeError.
    try:
        Dtype = dtype[_DtypeScalar]
    except TypeError:
        Dtype = dtype
|