File: Util.hs

package info (click to toggle)
haskell-tls 2.1.8-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,056 kB
  • sloc: haskell: 15,695; makefile: 3
file content (116 lines) | stat: -rw-r--r-- 3,316 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
{-# LANGUAGE ScopedTypeVariables #-}

-- | Miscellaneous helpers:
--
-- * bounds-checked 'ByteString' slicing and partitioning;
-- * a strict boolean conjunction;
-- * an exception handler that lets asynchronous exceptions propagate;
-- * a short-circuiting monadic traversal over 'Either';
-- * chunking of 'ByteString's;
-- * save/restore of 'MVar' contents.
module Network.TLS.Util (
    sub,
    takelast,
    partition3,
    partition6,
    (&&!),
    fmapEither,
    catchException,
    forEitherM,
    mapChunks_,
    getChunks,
    Saved,
    saveMVar,
    restoreMVar,
) where

import qualified Data.ByteString as B
import Network.TLS.Imports

import Control.Concurrent.MVar
import Control.Exception (SomeAsyncException (..))
import qualified Control.Exception as E

-- | Extract the @len@ bytes of @b@ that start at byte index @offset@.
--
-- Returns 'Nothing' when:
--
-- * @offset@ or @len@ is negative, or
-- * the requested range does not lie entirely inside @b@.
--
-- A zero-length slice at the very end is allowed:
-- @sub b (B.length b) 0 == Just ""@.
sub :: ByteString -> Int -> Int -> Maybe ByteString
sub b offset len
    | offset < 0 || len < 0 = Nothing
    -- Equivalent to @offset + len > B.length b@, but written so it cannot
    -- overflow: with @len >= 0@ the subtraction stays in range, whereas the
    -- sum could wrap to a negative value and wrongly pass the check.
    | offset > B.length b - len = Nothing
    | otherwise = Just $ B.take len $ B.drop offset b

-- | Return the last @i@ bytes of @b@.
--
-- Returns 'Nothing' when @i@ is negative or larger than the length of @b@.
-- @takelast 0 b@ is @Just ""@.
takelast :: Int -> ByteString -> Maybe ByteString
takelast i b
    | i < 0 || i > B.length b = Nothing
    -- Dropping everything but the last @i@ bytes leaves exactly @i@ bytes.
    | otherwise = Just $ B.drop (B.length b - i) b

-- | Split @bytes@ into exactly three consecutive pieces of sizes @d1@, @d2@
-- and @d3@.
--
-- Returns 'Nothing' if any size is negative, or if the sizes do not add up
-- to the exact length of the input.
partition3
    :: ByteString -> (Int, Int, Int) -> Maybe (ByteString, ByteString, ByteString)
partition3 bytes (d1, d2, d3)
    | any (< 0) sizes = Nothing
    -- Sum in 'Integer' so that huge sizes cannot wrap around in 'Int' and
    -- spuriously match the input length.
    | sum (map toInteger sizes) /= toInteger (B.length bytes) = Nothing
    | otherwise = Just (p1, p2, p3)
  where
    sizes = [d1, d2, d3]
    (p1, r1) = B.splitAt d1 bytes
    (p2, r2) = B.splitAt d2 r1
    p3 = B.take d3 r2

-- | Split @bytes@ into six consecutive pieces of sizes @d1@ .. @d6@.
--
-- Returns 'Nothing' if any size is negative, or if the input is shorter than
-- the sum of the sizes. Bytes beyond the sum of the sizes are allowed and
-- silently ignored.
partition6
    :: ByteString
    -> (Int, Int, Int, Int, Int, Int)
    -> Maybe (ByteString, ByteString, ByteString, ByteString, ByteString, ByteString)
partition6 bytes (d1, d2, d3, d4, d5, d6)
    -- A negative size would make 'B.splitAt' silently yield an empty piece.
    | any (< 0) sizes = Nothing
    -- Sum in 'Integer' so that huge sizes cannot wrap around in 'Int' and
    -- pass the length check.
    | sum (map toInteger sizes) > toInteger (B.length bytes) = Nothing
    | otherwise = Just (p1, p2, p3, p4, p5, p6)
  where
    sizes = [d1, d2, d3, d4, d5, d6]
    (p1, r1) = B.splitAt d1 bytes
    (p2, r2) = B.splitAt d2 r1
    (p3, r3) = B.splitAt d3 r2
    (p4, r4) = B.splitAt d4 r3
    (p5, r5) = B.splitAt d5 r4
    p6 = B.take d6 r5

-- | Strict conjunction. Like '&&', but the second operand is always
-- evaluated, even when the first one is 'False'.
--
-- NOTE(review): presumably used where short-circuiting is undesirable
-- (e.g. to avoid data-dependent evaluation); confirm at call sites.
(&&!) :: Bool -> Bool -> Bool
a &&! b = case a of
    True -> b
    -- Force the second operand before answering, instead of short-circuiting.
    False -> b `seq` False

-- | Apply a function to the 'Right' value of an 'Either', leaving a 'Left'
-- untouched. This is the same as 'fmap' for 'Either'.
fmapEither :: (a -> b) -> Either l a -> Either l b
fmapEither f = either Left (Right . f)

-- | Run an action and pass any synchronous exception it throws to the
-- handler.
--
-- Asynchronous exceptions (anything wrapped in 'SomeAsyncException') are
-- never given to the handler. 'E.catchJust' rethrows them unchanged.
catchException :: IO a -> (E.SomeException -> IO a) -> IO a
catchException action handler = E.catchJust onlySync action handler
  where
    -- Returning 'Nothing' makes 'E.catchJust' rethrow the exception.
    onlySync :: E.SomeException -> Maybe E.SomeException
    onlySync e = maybe (Just e) (const Nothing) (asAsync e)

    asAsync :: E.SomeException -> Maybe SomeAsyncException
    asAsync = E.fromException

-- | Run a monadic, possibly-failing action on each list element in order.
--
-- Stops at the first 'Left' and returns it; actions for the remaining
-- elements are not run. If every action succeeds, the 'Right' results are
-- collected in order.
forEitherM :: Monad m => [a] -> (a -> m (Either l b)) -> m (Either l [b])
forEitherM xs f = go xs
  where
    go [] = return (Right [])
    go (y : ys) = do
        r <- f y
        case r of
            Left err -> return (Left err)
            Right v -> fmap (v :) <$> go ys

-- | Split the input with 'getChunks' and run the action on every chunk, in
-- order, discarding the action's results.
mapChunks_
    :: Monad m
    => Maybe Int
    -- ^ maximum chunk size, as understood by 'getChunks'
    -> (B.ByteString -> m a)
    -- ^ action run on each chunk
    -> B.ByteString
    -> m ()
mapChunks_ len f bs = mapM_ f (getChunks len bs)

-- | Split a 'ByteString' into consecutive chunks of at most @len@ bytes.
--
-- * @'Nothing'@ returns the input as a single chunk.
-- * With @'Just' len@, every chunk but the last has exactly @len@ bytes.
-- * An empty input always yields a single empty chunk.
-- * A non-positive @len@ is treated like 'Nothing'. Previously it produced an
--   infinite list of empty chunks, which made 'mapChunks_' loop forever.
getChunks :: Maybe Int -> B.ByteString -> [B.ByteString]
getChunks (Just len)
    | len > 0 = go
  where
    go bs
        | B.length bs > len =
            let (chunk, remain) = B.splitAt len bs
             in chunk : go remain
        | otherwise = [bs]
-- Reached for 'Nothing' and for a non-positive length (guard above failed).
getChunks _ = (: [])

-- | An opaque newtype wrapper that prevents callers from poking inside
-- content that has been saved.
--
-- The module exports the type 'Saved' but not its constructor. Outside this
-- module, a 'Saved' value can therefore only be obtained from 'saveMVar' or
-- 'restoreMVar' and only consumed by 'restoreMVar'.
newtype Saved a = Saved a

-- | Snapshot the current content of an 'MVar' so it can be put back later
-- with 'restoreMVar'.
--
-- Uses 'readMVar', so it blocks while the 'MVar' is empty and leaves the
-- 'MVar' full afterwards.
saveMVar :: MVar a -> IO (Saved a)
saveMVar = fmap Saved . readMVar

-- | Put a previously saved value back into an 'MVar'. Returns the content
-- it replaced, wrapped as 'Saved' so it can be restored in turn.
--
-- Uses 'swapMVar'. Per the "Control.Concurrent.MVar" documentation, the swap
-- is only atomic if no other thread is also putting into this 'MVar'.
restoreMVar :: MVar a -> Saved a -> IO (Saved a)
restoreMVar ref (Saved val) = fmap Saved (swapMVar ref val)