File: Run.hs

package info (click to toggle)
haskell-tls 2.1.8-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,056 kB
  • sloc: haskell: 15,695; makefile: 3
file content (279 lines) | stat: -rw-r--r-- 8,453 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module Run (
    runTLS,
    runTLSSimple,
    runTLSPredicate,
    runTLSSimple13,
    runTLS0RTT,
    runTLSSimpleKeyUpdate,
    runTLSCapture13,
    runTLSSuccess,
    runTLSFailure,
    expectMaybe,
) where

import Control.Concurrent
import Control.Concurrent.Async
import qualified Control.Exception as E
import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.IORef
import Network.TLS
import System.Timeout
import Test.Hspec
import Test.QuickCheck

import API
import Arbitrary
import PipeChan

-- | Client side of a test run: reads the payloads to send from the input
-- channel and drives its TLS 'Context'.
type ClientWithInput = Chan ByteString -> Context -> IO ()

-- | Historical misspelling of 'ClientWithInput', kept so existing
-- signatures keep compiling. Prefer 'ClientWithInput' in new code.
type ClinetWithInput = ClientWithInput

-- | Server side of a test run: drives its TLS 'Context' and writes the
-- payloads it received, as one list, to the output channel.
type ServerWithOutput = Context -> Chan [ByteString] -> IO ()

----------------------------------------------------------------

-- | Single-payload variant of 'runTLSN': exactly one random payload is
-- queued for the client and must come out of the server.
runTLS
    :: (ClientParams, ServerParams)
    -> ClinetWithInput
    -> ServerWithOutput
    -> IO ()
runTLS params tlsClient tlsServer = runTLSN 1 params tlsClient tlsServer

-- | Run a client/server pair over an in-memory pipe and check the data flow.
--
-- @n@ random payloads (built with @someWords8 256@) are queued on the
-- client's input channel. After both sides have returned, the list the
-- server wrote to its output channel must equal those payloads. An
-- exception on either side is printed together with that side's
-- 'Supported' settings and then rethrown unchanged.
runTLSN
    :: Int
    -> (ClientParams, ServerParams)
    -> ClinetWithInput
    -> ServerWithOutput
    -> IO ()
runTLSN n params tlsClient tlsServer = do
    inputChan <- newChan
    outputChan <- newChan
    -- generate some data to send
    ds <- replicateM n $ B.pack <$> generate (someWords8 256)
    forM_ ds $ writeChan inputChan
    -- run client and server
    withPairContext params $ \(cCtx, sCtx) ->
        concurrently_ (server sCtx outputChan) (client inputChan cCtx)
    -- Read the result. Both sides have already returned at this point, so
    -- the server's output is normally queued already; the timeout only
    -- catches a server that finished without writing anything.
    mDs <- timeout resultTimeoutMicros $ readChan outputChan
    expectMaybe "timeout" ds mDs
  where
    -- 'timeout' counts microseconds: 1000000 us is 1 second.
    resultTimeoutMicros :: Int
    resultTimeoutMicros = 1000000
    server sCtx outputChan =
        E.catch
            (tlsServer sCtx outputChan)
            (printAndRaise "S: " (serverSupported $ snd params))
    client inputChan cCtx =
        E.catch
            (tlsClient inputChan cCtx)
            (printAndRaise "C: " (clientSupported $ fst params))
    -- Log the failure with the side tag and its 'Supported' settings, then
    -- rethrow the original exception so the test still fails.
    printAndRaise :: String -> Supported -> E.SomeException -> IO ()
    printAndRaise s supported e = do
        putStrLn $
            s
                ++ " exception: "
                ++ show e
                ++ ", supported: "
                ++ show supported
        E.throwIO e

----------------------------------------------------------------

-- | Handshake on both sides and exchange one payload, accepting whatever
-- connection information the handshake produced.
runTLSSimple :: (ClientParams, ServerParams) -> IO ()
runTLSSimple params = runTLSPredicate params $ \_ -> True

-- | Like a plain successful run, but right after the handshake each side
-- fetches its connection information and fails (via 'fail') when @p@
-- rejects it.
runTLSPredicate
    :: (ClientParams, ServerParams) -> (Maybe Information -> Bool) -> IO ()
runTLSPredicate params p =
    runTLSSuccess params handshakeAndCheck handshakeAndCheck
  where
    -- Client and server perform exactly the same post-handshake check.
    handshakeAndCheck ctx = do
        handshake ctx
        info <- contextGetInformation ctx
        if p info
            then return ()
            else fail $ "unexpected information: " ++ show info

----------------------------------------------------------------

-- | Successful run in which both sides, after the handshake, require the
-- reported TLS 1.3 handshake mode to be present and equal to @mode@.
runTLSSimple13
    :: (ClientParams, ServerParams)
    -> HandshakeMode13
    -> IO ()
runTLSSimple13 params mode =
    runTLSSuccess params (checkedHandshake "C: ") (checkedHandshake "S: ")
  where
    -- @side@ prefixes the failure message so it names the failing peer.
    checkedHandshake side ctx = do
        handshake ctx
        info <- contextGetInformation ctx
        let mmode = info >>= infoTLS13HandshakeMode
        expectMaybe (side ++ "mode should be Just") mode mmode

-- | After the handshake the client sends @earlyData@. The server must
-- receive it split into pieces of at most 16384 bytes, then echoes it back.
-- Both sides finally require the TLS 1.3 handshake mode to equal @mode@.
--
-- NOTE(review): whether the payload really travels as 0-RTT early data
-- depends on the supplied params -- confirm against the callers.
runTLS0RTT
    :: (ClientParams, ServerParams)
    -> HandshakeMode13
    -> ByteString
    -> IO ()
runTLS0RTT params mode earlyData =
    withPairContext params $ \(cCtx, sCtx) ->
        concurrently_ (tlsServer sCtx) (tlsClient cCtx)
  where
    payload = L.fromStrict earlyData
    tlsClient ctx = do
        handshake ctx
        sendData ctx payload
        _ <- recvData ctx
        bye ctx
        checkMode "C: mode should be Just" ctx
    tlsServer ctx = do
        handshake ctx
        let expected = chunkLengths (B.length earlyData)
        received <- replicateM (length expected) (recvData ctx)
        (map B.length received, B.concat received)
            `shouldBe` (expected, earlyData)
        sendData ctx payload
        bye ctx
        checkMode "S: mode should be Just" ctx
    -- Require the negotiated TLS 1.3 handshake mode to be reported and
    -- equal to @mode@; @tag@ is the failure message when it is missing.
    checkMode tag ctx = do
        info <- contextGetInformation ctx
        expectMaybe tag mode (info >>= infoTLS13HandshakeMode)
    -- Sizes of the pieces the server expects: as many full 16384-byte
    -- pieces as fit, plus a non-empty remainder (presumably mirroring the
    -- 2^14-byte TLS record limit). Non-positive lengths give no pieces.
    chunkLengths :: Int -> [Int]
    chunkLengths len =
        let (full, rest) = len `quotRem` 16384
         in replicate full 16384 ++ [rest | rest > 0]

-- | Fail with the message @tag@ when no value is present; otherwise require
-- the present value to equal the expected one.
expectMaybe :: (Show a, Eq a) => String -> a -> Maybe a -> Expectation
expectMaybe tag expected = maybe (expectationFailure tag) (`shouldBe` expected)

-- | Successful run that records the 'Handshake13' messages passed to each
-- side's TLS 1.3 handshake receive hook, in the order the hook saw them.
--
-- Returns @(serverReceived, clientReceived)@.
runTLSCapture13
    :: (ClientParams, ServerParams) -> IO ([Handshake13], [Handshake13])
runTLSCapture13 params = do
    sRef <- newIORef []
    cRef <- newIORef []
    runTLSSuccess params (recordingHandshake cRef) (recordingHandshake sRef)
    sReceived <- readIORef sRef
    cReceived <- readIORef cRef
    -- The hook prepends, so reverse to restore arrival order.
    return (reverse sReceived, reverse cReceived)
  where
    -- Client and server do the same thing, each with its own ref. The hook
    -- is installed before the handshake so it sees every message from it.
    recordingHandshake ref ctx = do
        installHook ctx ref
        handshake ctx
    -- Record each message and pass it through unchanged. The strict
    -- 'modifyIORef'' avoids accumulating a chain of unevaluated thunks.
    installHook ctx ref =
        let recv hs = modifyIORef' ref (hs :) >> return hs
         in contextHookSetHandshake13Recv ctx recv

-- | Exchange three payloads while each side requests a key update (randomly
-- 'OneWay' or 'TwoWay') part-way through the stream. The client updates
-- after its second send, the server after its first receive; the server
-- must still report all three payloads.
runTLSSimpleKeyUpdate :: (ClientParams, ServerParams) -> IO ()
runTLSSimpleKeyUpdate params = runTLSN 3 params tlsClient tlsServer
  where
    randomRequest = generate $ elements [OneWay, TwoWay]
    tlsClient queue ctx = do
        let sendNext = readChan queue >>= \d -> sendData ctx (L.fromChunks [d])
        handshake ctx
        sendNext
        sendNext
        randomRequest >>= void . updateKey ctx
        sendNext
        checkCtxFinished ctx
        bye ctx
    tlsServer ctx queue = do
        handshake ctx
        firstChunk <- recvData ctx
        randomRequest >>= void . updateKey ctx
        laterChunks <- replicateM 2 (recvData ctx)
        writeChan queue (firstChunk : laterChunks)
        checkCtxFinished ctx
        bye ctx

----------------------------------------------------------------

-- | Run the given per-side handshake actions, then have the client send one
-- payload taken from its input channel. The server reports the result of a
-- single 'recvData' on its output channel. Afterwards both sides run
-- 'checkCtxFinished' and send 'bye'.
runTLSSuccess
    :: (ClientParams, ServerParams)
    -> (Context -> IO ())
    -> (Context -> IO ())
    -> IO ()
runTLSSuccess params hsClient hsServer = runTLS params clientSide serverSide
  where
    finish ctx = checkCtxFinished ctx >> bye ctx
    clientSide queue ctx = do
        hsClient ctx
        payload <- readChan queue
        sendData ctx $ L.fromChunks [payload]
        finish ctx
    serverSide ctx queue = do
        hsServer ctx
        received <- recvData ctx
        writeChan queue [received]
        finish ctx

-- | Run both per-side actions concurrently over a fresh context pair and
-- require each of them to throw a 'TLSException'.
runTLSFailure
    :: (ClientParams, ServerParams)
    -> (Context -> IO c)
    -> (Context -> IO s)
    -> IO ()
runTLSFailure params hsClient hsServer =
    withPairContext params $ \(cCtx, sCtx) ->
        concurrently_ (mustFail hsServer sCtx) (mustFail hsClient cCtx)
  where
    -- Polymorphic in the action's result so it serves both sides.
    mustFail :: (Context -> IO a) -> Context -> Expectation
    mustFail action ctx = action ctx `shouldThrow` anyTLSException

-- | Selector that matches every 'TLSException', whatever its constructor.
anyTLSException :: Selector TLSException
anyTLSException _ = True

----------------------------------------------------------------

-- | Switch for packet tracing. When 'True', the contexts created by
-- 'newPairContext' print the packets they send and receive with 'putStrLn'
-- (prefixed by "client: "/"server: " and ">>"/"<<"); when 'False' they use
-- 'defaultLogging'.
debug :: Bool
debug = False

-- | Run @body@ with a connected (client, server) context pair. The pipe's
-- two threads are killed when @body@ finishes, whether normally or by
-- exception.
withPairContext
    :: (ClientParams, ServerParams) -> ((Context, Context) -> IO ()) -> IO ()
withPairContext params body =
    E.bracket (newPairContext params) stopPipe (body . snd)
  where
    stopPipe ((tidA, tidB), _) = do
        killThread tidA
        killThread tidB

-- | Create a client and a server 'Context' connected through an in-memory
-- pipe. The backends' flush and close actions do nothing.
--
-- Returns the pipe's thread ids, which the caller must kill, alongside the
-- contexts. If building the contexts throws, the caller never receives the
-- thread ids, so they are killed here before the exception propagates.
newPairContext
    :: (ClientParams, ServerParams)
    -> IO ((ThreadId, ThreadId), (Context, Context))
newPairContext (cParams, sParams) = do
    pipe <- newPipe
    tids@(t1, t2) <- runPipe pipe
    -- Without this cleanup, a failure here (e.g. inside the acquire step of
    -- 'E.bracket' in 'withPairContext') would leak both pipe threads.
    ctxs <- newContexts pipe `E.onException` (killThread t1 >> killThread t2)
    return (tids, ctxs)
  where
    -- Build both contexts on top of the pipe and attach the logging hooks.
    newContexts pipe = do
        let noFlush = return ()
            noClose = return ()
            cBackend = Backend noFlush noClose (writePipeC pipe) (readPipeC pipe)
            sBackend = Backend noFlush noClose (writePipeS pipe) (readPipeS pipe)
        cCtx' <- contextNew cBackend cParams
        sCtx' <- contextNew sBackend sParams

        contextHookSetLogging cCtx' (logging "client: ")
        contextHookSetLogging sCtx' (logging "server: ")

        return (cCtx', sCtx')
    -- Packet tracing to stdout when 'debug' is on; otherwise default logging.
    logging pre =
        if debug
            then
                defaultLogging
                    { loggingPacketSent = putStrLn . ((pre ++ ">> ") ++)
                    , loggingPacketRecv = putStrLn . ((pre ++ "<< ") ++)
                    }
            else defaultLogging