1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
|
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE MultiWayIf #-}
import Unsafe.Coerce (unsafeCoerce)
import Data.Monoid ((<>))
import qualified Data.Vector.Storable as V
import qualified Data.Vector.Storable.Mutable as VM
import Foreign.C.Types
import Foreign.ForeignPtr (newForeignPtr_)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)
import qualified Language.C.Inline as C
import qualified Language.C.Inline.Unsafe as CU
import System.IO.Unsafe (unsafePerformIO)
import Control.Monad (forM_)
import System.IO (withFile, hPutStrLn, IOMode(..))
-- inline-c setup. 'C.vecCtx' provides the @$vec-ptr:@ anti-quoter used to
-- hand a storable vector's buffer to C, and 'C.funCtx' provides the @$fun:@
-- anti-quoter used to pass a Haskell callback to C as a function pointer.
-- 'C.baseCtx' supplies the basic C type mappings (int, double, ...).
C.context (C.baseCtx <> C.vecCtx <> C.funCtx)
-- GSL headers for the quasi-quotes below: error codes (GSL_SUCCESS,
-- GSL_EMAXITER, GSL_ENOPROG) and the odeiv2 ODE system/driver API.
-- NOTE(review): nothing visible in this file uses gsl_matrix.h; confirm
-- whether it is still needed.
C.include "<gsl/gsl_errno.h>"
C.include "<gsl/gsl_matrix.h>"
C.include "<gsl/gsl_odeiv2.h>"
-- | Solves a system of ODEs from @x0@ to @xend@ with GSL's @gsl_odeiv2@
-- driver, using the @rk8pd@ stepper and fixed tolerances.
--
-- Every 'V.Vector' involved must be of the same size: the initial state
-- and every vector returned by the ODE function. If the ODE function
-- returns a vector of the wrong length, the integration is aborted and a
-- 'Left' is returned instead of crashing inside the C callback.
--
-- An empty initial state has nothing to integrate and is returned as is.
{-# NOINLINE solveOdeC #-}
solveOdeC
  :: (CDouble -> V.Vector CDouble -> V.Vector CDouble)
  -- ^ ODE to solve: given the independent variable and the current
  -- state, returns the derivative of the state
  -> CDouble
  -- ^ Start
  -> V.Vector CDouble
  -- ^ Solution at start point
  -> CDouble
  -- ^ End
  -> Either String (V.Vector CDouble)
  -- ^ Solution at end point, or error.
solveOdeC fun x0 f0 xend
  -- A zero-dimensional system has no state to evolve; skip GSL entirely.
  | V.null f0 = Right f0
  | otherwise = unsafePerformIO $ do
      let dim = V.length f0
      let dim_c = fromIntegral dim -- This is in CInt
      -- Convert the function to something of the right type to C.
      let funIO x y f _ptr = do
            -- Convert the pointer we get from C (y) to a vector, and then
            -- apply the user-supplied function.
            fImm <- fun x <$> vectorFromC dim y
            if V.length fImm /= dim
              then
                -- An exception escaping this callback cannot propagate
                -- back through GSL's C frames (GHC treats it as uncaught),
                -- so report the mismatch to GSL instead. GSL_EBADFUNC
                -- makes the odeiv2 driver stop and return that code.
                -- Unsafe since the function will be called many times.
                [CU.exp| int{ GSL_EBADFUNC } |]
              else do
                -- Fill in the provided pointer with the resulting vector.
                vectorToC fImm dim f
                -- Unsafe since the function will be called many times.
                [CU.exp| int{ GSL_SUCCESS } |]
      -- Create a mutable vector from the initial solution. This will be
      -- passed to the ODE solving function provided by GSL, and will
      -- contain the final solution.
      fMut <- V.thaw f0
      -- Safe (non-CU) call: the block calls back into Haskell via funIO.
      res <- [C.block| int {
          gsl_odeiv2_system sys = {
            $fun:(int (* funIO) (double t, const double y[], double dydt[], void * params)),
            // The ODE to solve, converted to function pointer using the `fun`
            // anti-quoter
            NULL,          // We don't provide a Jacobian
            $(int dim_c),  // The dimension
            NULL           // We don't need the parameter pointer
          };
          // Create the driver, using some sensible values for the stepping
          // function and the tolerances
          gsl_odeiv2_driver *d = gsl_odeiv2_driver_alloc_y_new (
            &sys, gsl_odeiv2_step_rk8pd, 1e-6, 1e-6, 0.0);
          // Finally, apply the driver.
          int status = gsl_odeiv2_driver_apply(
            d, &$(double x0), $(double xend), $vec-ptr:(double *fMut));
          // Free the driver
          gsl_odeiv2_driver_free(d);
          return status;
        } |]
      -- Check the error code
      maxSteps <- [C.exp| int{ GSL_EMAXITER } |]
      smallStep <- [C.exp| int{ GSL_ENOPROG } |]
      badFun <- [C.exp| int{ GSL_EBADFUNC } |]
      good <- [C.exp| int{ GSL_SUCCESS } |]
      if | res == good -> Right <$> V.freeze fMut
         | res == maxSteps -> return $ Left "Too many steps"
         | res == smallStep -> return $ Left "Step size dropped below minimum allowed size"
         | res == badFun -> return $ Left "ODE function returned a vector of the wrong size"
         | otherwise -> return $ Left $ "Unknown error code " ++ show res
-- | 'Double'-based wrapper around 'solveOdeC'. Every 'V.Vector' involved
-- must be of the same size.
--
-- 'CDouble' is a newtype over 'Double'. Scalars are therefore converted
-- with its constructor, and vectors with 'V.unsafeCast', which reinterprets
-- the same buffer in O(1). This replaces blanket 'unsafeCoerce' calls on a
-- function and an 'Either', which the type checker cannot verify at all.
solveOde
  :: (Double -> V.Vector Double -> V.Vector Double)
  -- ^ ODE to Solve
  -> Double
  -- ^ Start
  -> V.Vector Double
  -- ^ Solution at start point
  -> Double
  -- ^ End
  -> Either String (V.Vector Double)
  -- ^ Solution at end point, or error.
solveOde fun x0 f0 xend =
  fromC <$> solveOdeC funC (CDouble x0) (toC f0) (CDouble xend)
  where
    -- Both casts rely on CDouble and Double sharing size and alignment,
    -- which holds because CDouble wraps Double.
    toC :: V.Vector Double -> V.Vector CDouble
    toC = V.unsafeCast

    fromC :: V.Vector CDouble -> V.Vector Double
    fromC = V.unsafeCast

    -- The user's ODE, lifted to the C-typed interface of 'solveOdeC'.
    funC :: CDouble -> V.Vector CDouble -> V.Vector CDouble
    funC (CDouble t) ys = toC (fun t (fromC ys))
-- | Integrates the Lorenz system (sigma = 10, R = 28, b = 8/3) from the
-- starting point to the end point using 'solveOde'.
lorenz
  :: Double
  -- ^ Starting point
  -> V.Vector Double
  -- ^ Solution at starting point
  -> Double
  -- ^ End point
  -> Either String (V.Vector Double)
  -- ^ Solution at the end point, or the solver's error message
lorenz = solveOde rhs
  where
    sigma, rho, beta :: Double
    sigma = 10.0
    rho = 28.0
    beta = 8.0 / 3.0

    -- Right-hand side of the Lorenz equations. The independent variable is
    -- ignored, and the state is read by index, so it must hold at least
    -- three components.
    rhs :: Double -> V.Vector Double -> V.Vector Double
    rhs _t st = V.fromList [dx, dy, dz]
      where
        x = st V.! 0
        y = st V.! 1
        z = st V.! 2
        dx = sigma * (y - x)
        dy = rho * x - y - x * z
        dz = negate (beta * z) + x * y
-- | Integrates the Lorenz system from t = 0 with initial state
-- (10, 1, 1), sampling every 0.01 until t exceeds 40. Each sample's first
-- and third state components are written to @lorenz.csv@ as one line.
--
-- If the solver reports an error, the samples produced so far stay in the
-- file and the program fails with the solver's message. Previously this
-- case crashed with an irrefutable-pattern error.
main :: IO ()
main = withFile "lorenz.csv" WriteMode $ \h ->
  forM_ (trajectory 0 (V.fromList [10.0, 1.0, 1.0])) $ \step ->
    case step of
      Left err -> fail $ "lorenz: integration failed: " ++ err
      Right (_x, f) -> hPutStrLn h $ show (f V.! 0) ++ ", " ++ show (f V.! 2)
  where
    -- Lazily produces the successive samples. A solver failure is emitted
    -- as a final 'Left' so the consumer can report it.
    trajectory
      :: Double -> V.Vector Double -> [Either String (Double, V.Vector Double)]
    trajectory x f
      | x > 40 = [Right (x, f)]
      | otherwise = Right (x, f) : rest
      where
        x' = x + 0.01
        rest = case lorenz x f x' of
          Left err -> [Left err]
          Right f' -> trajectory x' f'
-- Utils
-- | Copies @len@ elements starting at @ptr@ into a fresh immutable vector.
-- The pointer is wrapped without a finalizer, so the C side keeps ownership
-- of the memory. 'V.freeze' copies, so the result does not alias it.
vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
vectorFromC len ptr =
  newForeignPtr_ ptr >>= \fptr -> V.freeze (VM.unsafeFromForeignPtr0 fptr len)
-- | Writes the contents of @vec@ into the @len@-element C buffer at @ptr@.
-- The buffer is wrapped without a finalizer, so ownership stays with the
-- caller. 'V.copy' requires @len@ to equal @V.length vec@.
vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO ()
vectorToC vec len ptr = do
  target <- flip VM.unsafeFromForeignPtr0 len <$> newForeignPtr_ ptr
  V.copy target vec
|