File: _cwt.pyx

package info (click to toggle)
pywavelets 1.4.1-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 13,680 kB
  • sloc: python: 8,849; ansic: 5,134; makefile: 93
file content (125 lines) | stat: -rw-r--r-- 5,940 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#cython: boundscheck=False, wraparound=False
from . cimport common
from . cimport c_wt
from .common cimport pywt_index_t, MODE
from ._pywt cimport _check_dtype

cimport numpy as np
import numpy as np

np.import_array()


cpdef cwt_psi_single(data_t[::1] data, ContinuousWavelet wavelet, size_t output_len):
    """Evaluate a continuous wavelet function at the points in ``data``.

    Dispatches on ``wavelet.short_family_name`` to the matching C routine in
    ``c_wt``: ``double_*`` for float64 input, ``float_*`` for float32 input.

    Parameters
    ----------
    data : data_t[::1]
        1-D C-contiguous memoryview (float32 and float64 are handled here)
        holding the points at which the wavelet is evaluated.
    wavelet : ContinuousWavelet
        Wavelet whose family selects the routine. Family-specific attributes
        read here are ``family_number`` (gaus, cgau), ``bandwidth_frequency``
        and ``center_frequency`` (shan, fbsp, cmor), and ``fbsp_order`` (fbsp).
    output_len : size_t
        Length of each returned array. Must be at least 1 and at least
        ``data.size``, because the C routine writes ``data.size`` values.
        Entries beyond ``data.size`` are left as zero.

    Returns
    -------
    ndarray or tuple of ndarray
        A single array for the ``gaus``, ``mexh`` and ``morl`` families.
        A pair ``(psi_r, psi_i)`` (real and imaginary parts) for the
        ``cgau``, ``shan``, ``fbsp`` and ``cmor`` families. Output arrays
        use the same floating-point precision as ``data``.

    Raises
    ------
    RuntimeError
        If ``output_len`` is 0 or smaller than ``data.size``.
    ValueError
        If the wavelet family is not one of the families listed above.
    """
    cdef np.ndarray psi, psi_r, psi_i
    cdef size_t data_size = data.size
    cdef int family_number = 0
    cdef double bandwidth_frequency
    cdef double center_frequency
    cdef int fbsp_order
    # Read the family name once rather than once per comparison below.
    cdef object family = wavelet.short_family_name

    if output_len < 1:
        raise RuntimeError("Invalid output length.")
    if output_len < data_size:
        # Each C routine writes ``data_size`` values into its output buffer(s);
        # a shorter buffer would be overrun inside the nogil section.
        raise RuntimeError(
            "Invalid output length: output_len ({}) is smaller than the "
            "number of input points ({}).".format(output_len, data_size))

    # np.zeros (not np.empty): when output_len > data_size, the tail of each
    # output array is never written by the C routine and must stay defined.
    if data_t is np.float64_t:
        if family == "gaus":
            psi = np.zeros(output_len, np.float64)
            family_number = wavelet.family_number
            with nogil:
                c_wt.double_gaus(&data[0], <double *>psi.data, data_size, family_number)
            return psi
        elif family == "mexh":
            psi = np.zeros(output_len, np.float64)
            with nogil:
                c_wt.double_mexh(&data[0], <double *>psi.data, data_size)
            return psi
        elif family == "morl":
            psi = np.zeros(output_len, np.float64)
            with nogil:
                c_wt.double_morl(&data[0], <double *>psi.data, data_size)
            return psi
        elif family == "cgau":
            psi_r = np.zeros(output_len, np.float64)
            psi_i = np.zeros(output_len, np.float64)
            family_number = wavelet.family_number
            with nogil:
                c_wt.double_cgau(&data[0], <double *>psi_r.data, <double *>psi_i.data, data_size, family_number)
            return (psi_r, psi_i)
        elif family == "shan":
            psi_r = np.zeros(output_len, np.float64)
            psi_i = np.zeros(output_len, np.float64)
            bandwidth_frequency = wavelet.bandwidth_frequency
            center_frequency = wavelet.center_frequency
            with nogil:
                c_wt.double_shan(&data[0], <double *>psi_r.data, <double *>psi_i.data, data_size, bandwidth_frequency, center_frequency)
            return (psi_r, psi_i)
        elif family == "fbsp":
            psi_r = np.zeros(output_len, np.float64)
            psi_i = np.zeros(output_len, np.float64)
            fbsp_order = wavelet.fbsp_order
            bandwidth_frequency = wavelet.bandwidth_frequency
            center_frequency = wavelet.center_frequency
            with nogil:
                c_wt.double_fbsp(&data[0], <double *>psi_r.data, <double *>psi_i.data, data_size, fbsp_order, bandwidth_frequency, center_frequency)
            return (psi_r, psi_i)
        elif family == "cmor":
            psi_r = np.zeros(output_len, np.float64)
            psi_i = np.zeros(output_len, np.float64)
            bandwidth_frequency = wavelet.bandwidth_frequency
            center_frequency = wavelet.center_frequency
            with nogil:
                c_wt.double_cmor(&data[0], <double *>psi_r.data, <double *>psi_i.data, data_size, bandwidth_frequency, center_frequency)
            return (psi_r, psi_i)

    elif data_t is np.float32_t:
        if family == "gaus":
            psi = np.zeros(output_len, np.float32)
            family_number = wavelet.family_number
            with nogil:
                c_wt.float_gaus(&data[0], <float *>psi.data, data_size, family_number)
            return psi
        elif family == "mexh":
            psi = np.zeros(output_len, np.float32)
            with nogil:
                c_wt.float_mexh(&data[0], <float *>psi.data, data_size)
            return psi
        elif family == "morl":
            psi = np.zeros(output_len, np.float32)
            with nogil:
                c_wt.float_morl(&data[0], <float *>psi.data, data_size)
            return psi
        elif family == "cgau":
            psi_r = np.zeros(output_len, np.float32)
            psi_i = np.zeros(output_len, np.float32)
            family_number = wavelet.family_number
            with nogil:
                c_wt.float_cgau(&data[0], <float *>psi_r.data, <float *>psi_i.data, data_size, family_number)
            return (psi_r, psi_i)
        elif family == "shan":
            psi_r = np.zeros(output_len, np.float32)
            psi_i = np.zeros(output_len, np.float32)
            bandwidth_frequency = wavelet.bandwidth_frequency
            center_frequency = wavelet.center_frequency
            with nogil:
                c_wt.float_shan(&data[0], <float *>psi_r.data, <float *>psi_i.data, data_size, bandwidth_frequency, center_frequency)
            return (psi_r, psi_i)
        elif family == "fbsp":
            psi_r = np.zeros(output_len, np.float32)
            psi_i = np.zeros(output_len, np.float32)
            fbsp_order = wavelet.fbsp_order
            bandwidth_frequency = wavelet.bandwidth_frequency
            center_frequency = wavelet.center_frequency
            with nogil:
                c_wt.float_fbsp(&data[0], <float *>psi_r.data, <float *>psi_i.data, data_size, fbsp_order, bandwidth_frequency, center_frequency)
            return (psi_r, psi_i)
        elif family == "cmor":
            psi_r = np.zeros(output_len, np.float32)
            psi_i = np.zeros(output_len, np.float32)
            bandwidth_frequency = wavelet.bandwidth_frequency
            center_frequency = wavelet.center_frequency
            with nogil:
                c_wt.float_cmor(&data[0], <float *>psi_r.data, <float *>psi_i.data, data_size, bandwidth_frequency, center_frequency)
            return (psi_r, psi_i)

    # Previously this fell through and silently returned None.
    raise ValueError(
        "Unsupported continuous wavelet family: {!r}".format(family))