1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
|
{-# LANGUAGE BangPatterns, FlexibleContexts #-}
-- |
-- Module : Statistics.Transform
-- Copyright : (c) 2011 Bryan O'Sullivan
-- License : BSD3
--
-- Maintainer : bos@serpentine.com
-- Stability : experimental
-- Portability : portable
--
-- Fourier-related transformations of mathematical functions.
--
-- These functions are written for simplicity and correctness, not
-- speed. If you need a fast FFT implementation for your application,
-- you should strongly consider using a library of FFTW bindings
-- instead.
module Statistics.Transform
(
-- * Type synonyms
CD
-- * Discrete cosine transform
, dct
, dct_
, idct
, idct_
-- * Fast Fourier transform
, fft
, ifft
) where
import Control.Monad (when)
import Control.Monad.ST (ST)
import Data.Bits (shiftL, shiftR)
import Data.Complex (Complex(..), conjugate, realPart)
import Numeric.SpecFunctions (log2)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector as V
-- | Complex number with 'Double' real and imaginary components; the
-- element type used by the FFT routines in this module.
type CD = Complex Double
-- | Discrete cosine transform (DCT-II).
--
-- The result uses the unnormalized convention
--
-- > y_k = 2 * sum [ x_j * cos (pi * k * (2*j + 1) / (2*n)) | j <- [0 .. n-1] ]
--
-- where @n@ is the length of the input. The length must be a power of
-- two; otherwise this function calls 'error'.
dct :: (G.Vector v CD, G.Vector v Double, G.Vector v Int) => v Double -> v Double
dct = dctWorker . G.map (:+0)
{-# INLINABLE dct #-}
{-# SPECIALIZE dct :: U.Vector Double -> U.Vector Double #-}
{-# SPECIALIZE dct :: V.Vector Double -> V.Vector Double #-}
-- | Discrete cosine transform (DCT-II) of the real parts of a complex
-- vector; the imaginary parts are ignored.
--
-- Otherwise identical to 'dct', including the requirement that the
-- length be a power of two (otherwise this function calls 'error').
dct_ :: (G.Vector v CD, G.Vector v Double, G.Vector v Int) => v CD -> v Double
dct_ = dctWorker . G.map (\(re :+ _) -> re :+ 0)
{-# INLINABLE dct_ #-}
{-# SPECIALIZE dct_ :: U.Vector CD -> U.Vector Double #-}
{-# SPECIALIZE dct_ :: V.Vector CD -> V.Vector Double #-}
-- Shared implementation of 'dct' and 'dct_'. The DCT-II is computed
-- with one complex FFT of the same length:
--
-- 1. The input is reordered (even positions ascending, then odd
--    positions descending).
-- 2. It is transformed with 'fft'.
-- 3. Each bin is multiplied by a weight before the real part is taken.
dctWorker :: (G.Vector v CD, G.Vector v Double, G.Vector v Int) => v CD -> v Double
{-# INLINE dctWorker #-}
dctWorker xs
  -- The reordering below yields an empty index vector when the length
  -- is 1, so that case is answered directly as 2 * real part.
  | len == 1    = G.map ((2*) . realPart) xs
  | vectorOK xs = G.map realPart (G.zipWith (*) weights (fft shuffled))
  | otherwise   = error "Statistics.Transform.dct: bad vector length"
  where
    len = G.length xs
    n   = fi len
    -- Even positions in ascending order followed by odd positions in
    -- descending order.
    order    = G.enumFromThenTo 0 2 (len - 2) G.++ G.enumFromThenTo (len - 1) (len - 3) 1
    shuffled = G.backpermute xs order
    -- Bin 0 is weighted by 2; bin k >= 1 by 2 * exp (-i*pi*k/(2n)).
    weights  = G.cons 2 (G.generate (len - 1) twiddle)
    twiddle k = 2 * exp ((0 :+ (-1)) * fi (k + 1) * pi / (2 * n))
-- | Inverse discrete cosine transform (DCT-III). It is the inverse of
-- 'dct' only up to a scale factor of twice the vector length:
--
-- > (idct . dct) x == G.map (* (2 * n)) x   -- n = fromIntegral (G.length x)
--
-- (equality holds up to floating-point rounding). The length must be a
-- power of two; otherwise this function calls 'error'.
idct :: (G.Vector v CD, G.Vector v Double) => v Double -> v Double
idct = idctWorker . G.map (:+0)
{-# INLINABLE idct #-}
{-# SPECIALIZE idct :: U.Vector Double -> U.Vector Double #-}
{-# SPECIALIZE idct :: V.Vector Double -> V.Vector Double #-}
-- | Inverse discrete cosine transform (DCT-III) of the real parts of a
-- complex vector; the imaginary parts are ignored.
--
-- Otherwise identical to 'idct', including the requirement that the
-- length be a power of two (otherwise this function calls 'error').
idct_ :: (G.Vector v CD, G.Vector v Double) => v CD -> v Double
idct_ = idctWorker . G.map (\(re :+ _) -> re :+ 0)
{-# INLINABLE idct_ #-}
{-# SPECIALIZE idct_ :: U.Vector CD -> U.Vector Double #-}
{-# SPECIALIZE idct_ :: V.Vector CD -> V.Vector Double #-}
-- Shared implementation of 'idct' and 'idct_'. The DCT-III is computed
-- with one inverse FFT:
--
-- 1. Each input element is multiplied by a weight.
-- 2. The result is inverse-transformed with 'ifft'.
-- 3. The real parts are de-interleaved back into natural order.
idctWorker :: (G.Vector v CD, G.Vector v Double) => v CD -> v Double
{-# INLINE idctWorker #-}
idctWorker xs
  | vectorOK xs = G.generate len interleave
  | otherwise   = error "Statistics.Transform.idct: bad vector length"
  where
    -- Output position 2m is taken from vals[m] and position 2m+1 from
    -- vals[len-1-m]. For 0 <= z < len both indices are in range, which
    -- is why unsafeIndex is used.
    interleave z | even z    = vals `G.unsafeIndex` halve z
                 | otherwise = vals `G.unsafeIndex` (len - halve z - 1)
    vals = G.map realPart . ifft $ G.zipWith (*) weights xs
    -- Element 0 is weighted by n; element k >= 1 by 2n * exp (i*pi*k/(2n)).
    weights
      = G.cons n
      $ G.generate (len - 1) $ \x -> 2 * n * exp ((0:+1) * fi (x+1) * pi/(2*n))
      where n = fi len
    len = G.length xs
-- | Inverse fast Fourier transform.
--
-- Computed as @conjugate . fft . conjugate@ followed by division by the
-- vector length, so that @ifft . fft@ is the identity (up to rounding).
-- The length must be a power of two; otherwise this function calls
-- 'error'.
ifft :: G.Vector v CD => v CD -> v CD
ifft xs
  | vectorOK xs = G.map ((/ n) . conjugate) . fft . G.map conjugate $ xs
  | otherwise   = error "Statistics.Transform.ifft: bad vector length"
  where
    -- Normalization divisor, computed once rather than per element.
    n = fi (G.length xs)
{-# INLINABLE ifft #-}
{-# SPECIALIZE ifft :: U.Vector CD -> U.Vector CD #-}
{-# SPECIALIZE ifft :: V.Vector CD -> V.Vector CD #-}
-- | Radix-2 decimation-in-time fast Fourier transform.
--
-- Uses the unnormalized forward convention
--
-- > X_k = sum [ x_j * exp (-2*pi*i*j*k/n) | j <- [0 .. n-1] ]
--
-- The length must be a power of two; otherwise this function calls
-- 'error'. The input vector is not modified.
fft :: G.Vector v CD => v CD -> v CD
fft v
  -- G.modify runs the in-place transform on a copy of v (or in place
  -- when that is safe, e.g. for a freshly built vector).
  | vectorOK v = G.modify mfft v
  | otherwise  = error "Statistics.Transform.fft: bad vector length"
{-# INLINABLE fft #-}
{-# SPECIALIZE fft :: U.Vector CD -> U.Vector CD #-}
{-# SPECIALIZE fft :: V.Vector CD -> V.Vector CD #-}
-- In-place radix-2 decimation-in-time FFT of a mutable vector.
--
-- The vector length must be a power of two. This is NOT checked here,
-- so callers must validate it first (as 'fft' does with 'vectorOK').
mfft :: (M.MVector v CD) => v s CD -> ST s ()
{-# INLINE mfft #-}
mfft vec = bitReverse 0 0
  where
    -- Phase 1: permute the elements into bit-reversed index order.
    -- @i@ counts upwards while @j@ holds the bit reversal of @i@.
    -- Swapping only when i < j exchanges each pair exactly once.
    bitReverse i j
      | i == len-1 = stage 0 1
      | otherwise  = do
          when (i < j) $ M.swap vec i j
          -- Add one to j in bit-reversed order: clear set bits from the
          -- most significant end until a clear bit is found, then set it.
          let inner k l | k <= l    = inner (k `shiftR` 1) (l-k)
                        | otherwise = bitReverse (i+1) (l+k)
          inner (len `shiftR` 1) j
    -- Phase 2: log2 len butterfly stages. Stage @l@ combines pairs of
    -- sub-transforms of size l1 = 2^l into transforms of size l2 = 2*l1.
    stage l !l1
      | l == m    = return ()
      | otherwise = do
          let !l2 = l1 `shiftL` 1
              -- Angle increment between successive twiddle factors.
              !e = -2 * pi / fromIntegral l2
              -- Handle twiddle index j (angle a, accumulated as j steps
              -- of e) in every block of size l2 of this stage.
              flight j !a
                | j == l1   = stage (l+1) l2
                | otherwise = do
                    -- The twiddle factor depends only on the angle, so
                    -- cos/sin are computed once per flight rather than
                    -- once per butterfly.
                    let !c = cos a
                        !s = sin a
                        butterfly i
                          | i >= len  = flight (j+1) (a+e)
                          | otherwise = do
                              let i1 = i + l1
                              xi1 :+ yi1 <- M.read vec i1
                              -- d = (c + i*s) * vec[i1]
                              let d = (c*xi1 - s*yi1) :+ (s*xi1 + c*yi1)
                              ci <- M.read vec i
                              M.write vec i1 (ci - d)
                              M.write vec i (ci + d)
                              butterfly (i+l2)
                    butterfly j
          flight 0 0
    len = M.length vec
    m = log2 len
----------------------------------------------------------------
-- Helpers
----------------------------------------------------------------
-- | Convert an 'Int' (a length or an index in this module) to a complex
-- number with zero imaginary part.
fi :: Int -> CD
fi k = fromIntegral k :+ 0
-- | Divide by two using an arithmetic right shift (sign-extending, so
-- negative values round towards negative infinity).
halve :: Int -> Int
halve n = shiftR n 1
-- | Check that the vector length is a positive power of two, which is
-- the precondition of the radix-2 FFT used throughout this module.
--
-- An empty vector is rejected explicitly instead of passing 0 to
-- 'log2' (the logarithm of 0 is undefined). Callers can then report
-- their own "bad vector length" error.
vectorOK :: G.Vector v a => v a -> Bool
{-# INLINE vectorOK #-}
vectorOK v = n > 0 && (1 `shiftL` log2 n) == n
  where
    n = G.length v
|