1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
|
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE Rank2Types #-}
module Data.Conduit.Cereal.Internal
( ConduitErrorHandler
, SinkErrorHandler
, SinkTerminationHandler
, mkConduitGet
, mkSinkGet
) where
import Control.Monad (forever, when)
import qualified Data.ByteString as BS
import qualified Data.Conduit as C
import Data.Serialize hiding (get, put)
-- | What should we do if the 'Get' fails?
--
-- Used by 'mkConduitGet': the handler receives the failure message reported
-- by the 'Get' and returns the conduit that carries on in its place.
type ConduitErrorHandler m o
  = String
  -> C.Conduit BS.ByteString m o

-- | What should a sink do if the 'Get' fails?
--
-- Used by 'mkSinkGet': the handler receives the failure message reported by
-- the 'Get' and returns the consumer whose result becomes the sink's result.
type SinkErrorHandler m r
  = String
  -> C.Consumer BS.ByteString m r

-- | What should we do if the stream is done before the 'Get' is done?
--
-- Used by 'mkSinkGet': the handler receives the pending continuation of the
-- 'Get' (the function that accepts further input) and returns the consumer
-- that finishes the sink.
type SinkTerminationHandler m r
  = (BS.ByteString -> Result r)
  -> C.Consumer BS.ByteString m r
-- | Construct a conduitGet with the specified 'ConduitErrorHandler'.
--
-- The resulting conduit runs @get@ repeatedly over the incoming chunks and
-- yields every decoded value.
--
-- * When the input ends part-way through a value, the bytes fed to that
--   unfinished parse are pushed back with 'C.leftover' and the conduit stops.
-- * When @get@ fails, every byte fed to the failing parse is pushed back
--   (including input left over from the previously decoded value) before
--   control passes to @errorHandler@.
-- * If @get@ succeeds on the initial empty feed, i.e. without consuming any
--   input, that value is yielded forever.
mkConduitGet :: Monad m
             => ConduitErrorHandler m o
             -> Get o
             -> C.Conduit BS.ByteString m o
mkConduitGet errorHandler get = consume True (runGetPartial get) [] BS.empty
  where
    -- Push back the chunks fed to the current parse (kept newest-first),
    -- skipping the call entirely when they contain no bytes so that no
    -- zero-length leftover is emitted.
    pushBack chunks =
      let bytes = BS.concat (reverse chunks)
      in when (not $ BS.null bytes) (C.leftover bytes)

    -- Feed @s@ to the continuation @f@, awaiting a new chunk first when @s@
    -- is empty.  At end of input the partially-parsed bytes are returned.
    pull f b s
      | BS.null s = C.await >>= maybe (pushBack b) (pull f b)
      | otherwise = consume False f b s

    -- @initial@ is True only for the very first (empty) feed; @b@ holds the
    -- chunks already fed to @f@, newest first.
    consume initial f b s =
      case f s of
        Fail msg _ -> do
          -- 'consumed' always includes @s@, so the first chunk of a parse
          -- that starts right after a decoded value is not lost on failure.
          pushBack consumed
          errorHandler msg
        Partial p -> pull p consumed BS.empty
        Done a s'
          -- this only works because the Get will either _always_ consume no
          -- input, or _never_ consume no input.
          | initial -> forever $ C.yield a
          | otherwise -> C.yield a >> pull (runGetPartial get) [] s'
      where
        consumed = s : b
-- | Construct a sinkGet with the specified 'SinkErrorHandler' and
-- 'SinkTerminationHandler'.
--
-- The resulting sink runs @get@ once over the incoming chunks.
--
-- * On success, any unparsed remainder of the last chunk fed is pushed back
--   with 'C.leftover' and the decoded value is returned.
-- * On failure, every byte fed to @get@ is pushed back before
--   @errorHandler@ runs.
-- * If the input ends before @get@ finishes, the bytes fed so far are pushed
--   back and @terminationHandler@ receives the pending continuation.
--
-- Zero-length leftovers are never emitted.
mkSinkGet :: Monad m
          => SinkErrorHandler m r
          -> SinkTerminationHandler m r
          -> Get r
          -> C.Consumer BS.ByteString m r
mkSinkGet errorHandler terminationHandler get = consume (runGetPartial get) [] BS.empty
  where
    -- Push back the chunks fed so far (kept newest-first), skipping the
    -- call entirely when they contain no bytes.
    pushBack chunks =
      let bytes = BS.concat (reverse chunks)
      in when (not $ BS.null bytes) (C.leftover bytes)

    -- Feed @s@ to the continuation @f@, awaiting a new chunk first when @s@
    -- is empty.
    -- NOTE(review): the consumed bytes are pushed back *before*
    -- terminationHandler runs.  If a handler resumes @f@ and the Get then
    -- succeeds, those bytes stay pending even though the Get consumed them.
    -- Confirm that the handlers in use account for this.
    pull f b s
      | BS.null s = C.await >>= maybe (pushBack b >> terminationHandler f) (pull f b)
      | otherwise = consume f b s

    -- @b@ holds the chunks already fed to @f@, newest first.
    consume f b s =
      case f s of
        Fail msg _ -> do
          pushBack consumed
          errorHandler msg
        Partial p -> pull p consumed BS.empty
        Done r s' -> pushBack [s'] >> return r
      where
        consumed = s : b
|