1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
|
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module Run (
runTLS,
runTLSSimple,
runTLSPredicate,
runTLSSimple13,
runTLS0RTT,
runTLSSimpleKeyUpdate,
runTLSCapture13,
runTLSSuccess,
runTLSFailure,
expectMaybe,
) where
import Control.Concurrent
import Control.Concurrent.Async
import qualified Control.Exception as E
import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.IORef
import Network.TLS
import System.Timeout
import Test.Hspec
import Test.QuickCheck
import API
import Arbitrary
import PipeChan
-- | A client action under test. It is given the channel holding the
-- payloads queued for it to send, and its TLS 'Context'.
-- NOTE(review): "Clinet" is a misspelling of "Client"; the name is kept
-- because other signatures in this module refer to it.
type ClinetWithInput = Chan ByteString -> Context -> IO ()
-- | A server action under test. It is given its TLS 'Context' and a channel
-- on which it is expected to publish, as a single list, the payloads it
-- received (the runner compares that list with what was queued).
type ServerWithOutput = Context -> Chan [ByteString] -> IO ()
----------------------------------------------------------------
-- | Drive a client/server pair over an in-memory pipe with exactly one
-- random payload; a thin specialisation of 'runTLSN'.
runTLS
    :: (ClientParams, ServerParams)
    -> ClinetWithInput
    -> ServerWithOutput
    -> IO ()
runTLS params tlsClient tlsServer = runTLSN 1 params tlsClient tlsServer
-- | Run @n@ random payloads through a client/server pair that is connected
-- by an in-memory pipe.
--
-- The payloads are generated with 'someWords8' 256 (from the project's
-- Arbitrary module) and queued on the client's input channel before the
-- pair starts. The client and server actions then run concurrently. Once
-- both have finished, the server is expected to have written the list of
-- payloads it received to the output channel, and that list must equal
-- what was queued.
--
-- An exception raised on either side is printed, together with that side's
-- 'Supported' settings, and then rethrown.
runTLSN
    :: Int
    -> (ClientParams, ServerParams)
    -> ClinetWithInput
    -> ServerWithOutput
    -> IO ()
runTLSN n params tlsClient tlsServer = do
    inputChan <- newChan
    outputChan <- newChan
    -- generate some data to send
    ds <- replicateM n $ B.pack <$> generate (someWords8 256)
    forM_ ds $ writeChan inputChan
    -- run client and server
    withPairContext params $ \(cCtx, sCtx) ->
        concurrently_ (server sCtx outputChan) (client inputChan cCtx)
    -- read result; give up after the documented 60-second budget
    mDs <- timeout resultTimeoutMicros $ readChan outputChan
    expectMaybe "timeout" ds mDs
  where
    -- 'timeout' counts microseconds, so this is 60 seconds.
    resultTimeoutMicros :: Int
    resultTimeoutMicros = 60 * 1000000
    server sCtx outputChan =
        E.catch
            (tlsServer sCtx outputChan)
            (printAndRaise "S: " (serverSupported $ snd params))
    client inputChan cCtx =
        E.catch
            (tlsClient inputChan cCtx)
            (printAndRaise "C: " (clientSupported $ fst params))
    -- Log which side failed and with which 'Supported' settings, then
    -- propagate the exception unchanged.
    printAndRaise :: String -> Supported -> E.SomeException -> IO ()
    printAndRaise s supported e = do
        putStrLn $
            s
                ++ " exception: "
                ++ show e
                ++ ", supported: "
                ++ show supported
        E.throwIO e
----------------------------------------------------------------
-- | Handshake both sides and exchange one payload. Whatever 'Information'
-- was negotiated is accepted.
runTLSSimple :: (ClientParams, ServerParams) -> IO ()
runTLSSimple params = runTLSPredicate params acceptAnything
  where
    acceptAnything _ = True
-- | Like 'runTLSSimple', but after each side's handshake the predicate is
-- applied to that side's 'contextGetInformation' result. A 'False' verdict
-- fails that side via 'fail', with the offending information in the
-- message.
runTLSPredicate
    :: (ClientParams, ServerParams) -> (Maybe Information -> Bool) -> IO ()
runTLSPredicate params p =
    runTLSSuccess params handshakeAndVerify handshakeAndVerify
  where
    -- Client and server perform the same steps, so one action serves both.
    handshakeAndVerify ctx = handshake ctx >> verifyInfo ctx
    verifyInfo ctx = do
        info <- contextGetInformation ctx
        if p info
            then return ()
            else fail ("unexpected information: " ++ show info)
----------------------------------------------------------------
-- | Handshake both sides, exchange one payload, and require that each side
-- reports the given TLS 1.3 handshake mode.
runTLSSimple13
    :: (ClientParams, ServerParams)
    -> HandshakeMode13
    -> IO ()
runTLSSimple13 params mode =
    runTLSSuccess
        params
        (handshakeExpectingMode "C: mode should be Just")
        (handshakeExpectingMode "S: mode should be Just")
  where
    -- @tag@ is the failure message used when no mode is reported at all.
    handshakeExpectingMode tag ctx = do
        handshake ctx
        info <- contextGetInformation ctx
        expectMaybe tag mode (info >>= infoTLS13HandshakeMode)
-- | Handshake both sides, then have the client send @earlyData@ and the
-- server echo it back.
--
-- The server expects to receive the payload as consecutive 'recvData'
-- chunks of at most 16384 bytes each. Both sides must then report the given
-- TLS 1.3 handshake mode.
--
-- NOTE(review): the payload is sent with 'sendData' after 'handshake'
-- returns; whether it actually travels as 0-RTT data presumably depends on
-- the supplied params — confirm against callers.
runTLS0RTT
    :: (ClientParams, ServerParams)
    -> HandshakeMode13
    -> ByteString
    -> IO ()
runTLS0RTT params mode earlyData =
    withPairContext params $ \(cCtx, sCtx) ->
        concurrently_ (serverSide sCtx) (clientSide cCtx)
  where
    payload = L.fromStrict earlyData
    clientSide ctx = do
        handshake ctx
        sendData ctx payload
        void $ recvData ctx
        bye ctx
        requireMode "C: mode should be Just" ctx
    serverSide ctx = do
        handshake ctx
        let expectedLens = recordLengths (B.length earlyData)
        received <- replicateM (length expectedLens) (recvData ctx)
        (map B.length received, B.concat received)
            `shouldBe` (expectedLens, earlyData)
        sendData ctx payload
        bye ctx
        requireMode "S: mode should be Just" ctx
    requireMode tag ctx = do
        info <- contextGetInformation ctx
        expectMaybe tag mode (info >>= infoTLS13HandshakeMode)
    -- Split a byte count into full 16384-byte pieces followed by the
    -- remainder; non-positive counts give no pieces.
    recordLengths :: Int -> [Int]
    recordLengths total =
        map (min maxRecord) $ takeWhile (> 0) $ iterate (subtract maxRecord) total
    maxRecord :: Int
    maxRecord = 16384
-- | Fail with the message @tag@ when given 'Nothing'. Otherwise expect the
-- wrapped value to equal the expected one.
expectMaybe :: (Show a, Eq a) => String -> a -> Maybe a -> Expectation
expectMaybe tag expected = maybe (expectationFailure tag) (`shouldBe` expected)
-- | Run a handshake plus one payload exchange (via 'runTLSSuccess') while
-- recording every 'Handshake13' message each side's
-- 'contextHookSetHandshake13Recv' hook sees.
--
-- Returns @(serverReceived, clientReceived)@, each in arrival order.
runTLSCapture13
    :: (ClientParams, ServerParams) -> IO ([Handshake13], [Handshake13])
runTLSCapture13 params = do
    sRef <- newIORef []
    cRef <- newIORef []
    runTLSSuccess params (capturingHandshake cRef) (capturingHandshake sRef)
    sReceived <- readIORef sRef
    cReceived <- readIORef cRef
    return (reverse sReceived, reverse cReceived)
  where
    -- Client and server do the same thing, differing only in the log ref.
    capturingHandshake ref ctx = do
        installHook ctx ref
        handshake ctx
    -- Prepend each received message (newest first; reversed above) and pass
    -- it through unchanged. The strict 'modifyIORef'' avoids accumulating a
    -- chain of unevaluated updates inside the ref.
    installHook ctx ref =
        let recv hss = modifyIORef' ref (hss :) >> return hss
         in contextHookSetHandshake13Recv ctx recv
-- | Exchange three payloads with a randomly chosen key update ('OneWay' or
-- 'TwoWay') on each side.
--
-- The client issues its update between the 2nd and 3rd payload. The server
-- issues its update after receiving the 1st payload.
runTLSSimpleKeyUpdate :: (ClientParams, ServerParams) -> IO ()
runTLSSimpleKeyUpdate params = runTLSN 3 params clientSide serverSide
  where
    randomKeyUpdate ctx = do
        req <- generate $ elements [OneWay, TwoWay]
        void $ updateKey ctx req
    clientSide queue ctx = do
        let sendNext = readChan queue >>= \d -> sendData ctx (L.fromChunks [d])
        handshake ctx
        sendNext
        sendNext
        randomKeyUpdate ctx
        sendNext
        checkCtxFinished ctx
        bye ctx
    serverSide ctx queue = do
        handshake ctx
        initial <- recvData ctx
        randomKeyUpdate ctx
        rest <- replicateM 2 (recvData ctx)
        writeChan queue (initial : rest)
        checkCtxFinished ctx
        bye ctx
----------------------------------------------------------------
-- | Run the given client and server handshake actions, then move one queued
-- payload from client to server.
--
-- Each side finishes with 'checkCtxFinished' and 'bye'.
runTLSSuccess
    :: (ClientParams, ServerParams)
    -> (Context -> IO ())
    -> (Context -> IO ())
    -> IO ()
runTLSSuccess params hsClient hsServer = runTLS params clientSide serverSide
  where
    clientSide queue ctx = do
        hsClient ctx
        payload <- readChan queue
        sendData ctx $ L.fromChunks [payload]
        finish ctx
    serverSide ctx queue = do
        hsServer ctx
        received <- recvData ctx
        writeChan queue [received]
        finish ctx
    finish ctx = checkCtxFinished ctx >> bye ctx
-- | Run the given client and server actions concurrently and require that
-- each of them throws some 'TLSException'.
runTLSFailure
    :: (ClientParams, ServerParams)
    -> (Context -> IO c)
    -> (Context -> IO s)
    -> IO ()
runTLSFailure params hsClient hsServer =
    withPairContext params $ \(cCtx, sCtx) ->
        concurrently_
            (hsServer sCtx `shouldThrow` anyTLSException)
            (hsClient cCtx `shouldThrow` anyTLSException)

-- | Selector that accepts every 'TLSException'.
anyTLSException :: Selector TLSException
anyTLSException _ = True
----------------------------------------------------------------
-- | Switch for packet tracing. When 'True', 'newPairContext' installs
-- logging hooks that print every packet sent and received, prefixed with
-- the side's name.
debug :: Bool
debug = False
-- | Create a connected client/server 'Context' pair, run @body@ on it, and
-- always kill the two pipe threads afterwards, even if @body@ throws.
withPairContext
    :: (ClientParams, ServerParams) -> ((Context, Context) -> IO ()) -> IO ()
withPairContext params body = E.bracket (newPairContext params) release use
  where
    release ((pipeTid1, pipeTid2), _) = do
        killThread pipeTid1
        killThread pipeTid2
    use (_, ctxs) = body ctxs
-- | Build a client and a server 'Context' that talk to each other through an
-- in-memory pipe.
--
-- Returns the two thread ids started by 'runPipe' (the caller must kill
-- them) together with the @(client, server)@ contexts. Flush and close on
-- both backends are no-ops. If building the contexts fails, the pipe threads
-- are killed before the exception propagates, because the caller never
-- receives their ids in that case and could not clean them up.
newPairContext
    :: (ClientParams, ServerParams)
    -> IO ((ThreadId, ThreadId), (Context, Context))
newPairContext (cParams, sParams) = do
    pipe <- newPipe
    tids@(t1, t2) <- runPipe pipe
    ctxs <- mkContexts pipe `E.onException` (killThread t1 >> killThread t2)
    return (tids, ctxs)
  where
    mkContexts pipe = do
        let noFlush = return ()
        let noClose = return ()
        let cBackend = Backend noFlush noClose (writePipeC pipe) (readPipeC pipe)
        let sBackend = Backend noFlush noClose (writePipeS pipe) (readPipeS pipe)
        cCtx' <- contextNew cBackend cParams
        sCtx' <- contextNew sBackend sParams
        contextHookSetLogging cCtx' (logging "client: ")
        contextHookSetLogging sCtx' (logging "server: ")
        return (cCtx', sCtx')
    -- Packet tracing is only installed when 'debug' is switched on.
    logging pre =
        if debug
            then
                defaultLogging
                    { loggingPacketSent = putStrLn . ((pre ++ ">> ") ++)
                    , loggingPacketRecv = putStrLn . ((pre ++ "<< ") ++)
                    }
            else defaultLogging
|