1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
|
{-# LANGUAGE OverloadedStrings #-}
import Network.Wai
import Network.Wai.Handler.Warp
import qualified Data.IORef as I
import Control.Monad.IO.Class (MonadIO, liftIO)
import Network.HTTP.Types
import Control.Concurrent (forkIO, killThread, threadDelay)
import Control.Monad (forM_)
import System.IO (hFlush, hClose)
import System.IO.Unsafe (unsafePerformIO)
import Data.ByteString (ByteString, hPutStr, hGetSome)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Network (connectTo, PortID (PortNumber))
import Test.Hspec.Monadic
import Test.Hspec.HUnit ()
import Test.HUnit
import Data.Conduit (($$))
import qualified Data.Conduit.List
-- | Shared result cell for one test run: @Right n@ is the number of requests
-- the application has accepted so far, @Left msg@ is the message written by
-- the most recent 'err' call.
type Counter = I.IORef (Either String Int)

-- | A WAI 'Application' that reports what it saw into a 'Counter'.
type CounterApplication
    = Counter
    -> Application
-- | Record one successfully handled request.  If the counter already holds
-- an error ('Left'), it is left unchanged.
incr :: MonadIO m => Counter -> m ()
incr icount =
    liftIO $ I.atomicModifyIORef icount bump
  where
    -- 'fmap' on 'Either' only touches the 'Right' count.
    bump state = (fmap (+ 1) state, ())
-- | Mark the test as failed by overwriting the counter with the 'show'n
-- form of the given value (any earlier count or error is replaced).
err :: (MonadIO m, Show a) => Counter -> a -> m ()
err icount = liftIO . I.writeIORef icount . Left . show
-- | Consume the whole request body, then validate the request:
--
-- * a request for @/hello@ must carry exactly the body @Hello@;
-- * a GET request must have an empty body;
-- * only GET and POST are accepted.
--
-- A valid request increments the counter; otherwise the first failing rule
-- is recorded with 'err'.  Always answers 200 with @Read the body@.
readBody :: CounterApplication
readBody icount req = do
    chunks <- requestBody req $$ Data.Conduit.List.consume
    let method = requestMethod req
        received = L.fromChunks chunks
        -- Rules are checked top to bottom; the first match decides.
        verdict
            | pathInfo req == ["hello"] && received /= "Hello" =
                err icount ("Invalid hello" :: String, chunks)
            | method == "GET" && received /= "" =
                err icount ("Invalid GET" :: String, chunks)
            | method `notElem` ["GET", "POST"] =
                err icount ("Invalid request method (readBody)" :: String, method)
            | otherwise = incr icount
    verdict
    return $ responseLBS status200 [] "Read the body"
-- | Answer without reading the request body.  GET and POST requests are
-- counted; any other method is recorded as an error.  Always answers 200
-- with @Ignored the body@.
ignoreBody :: CounterApplication
ignoreBody icount req = do
    let method = requestMethod req
    if method `notElem` ["GET", "POST"]
        then err icount ("Invalid request method" :: String, method)
        else incr icount
    return $ responseLBS status200 [] "Ignored the body"
-- | Drain the request body twice, then count the request and answer 200
-- with @double connect@.  Presumably this checks that reading an
-- already-consumed body returns instead of blocking — TODO confirm intent.
doubleConnect :: CounterApplication
doubleConnect icount req = do
    -- Run the same body-consuming action two times, discarding both results.
    forM_ [1 :: Int, 2] (const drain)
    incr icount
    return $ responseLBS status200 [] "double connect"
  where
    drain = requestBody req $$ Data.Conduit.List.consume
-- | Next TCP port to hand out to a test server; 'getPort' reads and advances
-- it so that each test uses a distinct port, starting at 5000.
--
-- This is a global mutable cell built with 'unsafePerformIO'.  The NOINLINE
-- pragma is required: without it GHC may inline the definition and allocate
-- a fresh IORef at each use site, so 'getPort' could return the same port
-- to several tests.
{-# NOINLINE nextPort #-}
nextPort :: I.IORef Int
nextPort = unsafePerformIO $ I.newIORef 5000
-- | Atomically reserve a port: return the current value of 'nextPort' and
-- advance the counter by one.
getPort :: IO Int
getPort = I.atomicModifyIORef nextPort reserve
  where
    reserve port = (port + 1, port)
-- | Start @app@ on a fresh port and send @chunks@ over a single client
-- connection, flushing after each chunk.  Then wait for the first piece of
-- the response, stop the server, and check how many requests the
-- application counted.
--
-- If the application recorded an error via 'err', the test fails with that
-- message.
runTest :: Int -- ^ expected number of requests
        -> CounterApplication
        -> [ByteString] -- ^ chunks to send
        -> IO ()
runTest expected app chunks = do
    port <- getPort
    ref <- I.newIORef (Right 0)
    tid <- forkIO $ run port $ app ref
    -- Give the server time to start listening.
    -- NOTE(review): 1ms is very short; confirm this is not a source of flakiness.
    threadDelay 1000
    handle <- connectTo "127.0.0.1" $ PortNumber $ fromIntegral port
    forM_ chunks $ \chunk -> hPutStr handle chunk >> hFlush handle
    -- Block until at least part of a response has arrived.
    _ <- hGetSome handle 4096
    -- Presumably lets the server finish any remaining (pipelined) requests
    -- before it is stopped.
    threadDelay 1000
    killThread tid
    -- Close the client socket so it is not leaked across tests.  This is done
    -- only after the delay, so closing cannot cut short in-flight requests.
    hClose handle
    res <- I.readIORef ref
    case res of
        Left s -> assertFailure s
        Right i -> i @?= expected
-- | Application that ignores its request and always answers 200 with the
-- body @foo@.
dummyApp :: Application
dummyApp = const $ return $ responseLBS status200 [] "foo"
-- | Run 'dummyApp' on a fresh port with an exception handler that stores the
-- last exception it is given.  Send @input@ and close the connection, then
-- check that the stored exception 'show's the same as @Just expected@.
runTerminateTest :: InvalidRequest
                 -> ByteString
                 -> IO ()
runTerminateTest expected input = do
    port <- getPort
    caught <- I.newIORef Nothing
    let settings = defaultSettings
            { settingsOnException = I.writeIORef caught . Just
            , settingsPort = port
            }
    server <- forkIO $ runSettings settings dummyApp
    threadDelay 1000
    conn <- connectTo "127.0.0.1" $ PortNumber $ fromIntegral port
    -- Send the (possibly truncated) request and hang up immediately.
    hPutStr conn input >> hFlush conn >> hClose conn
    threadDelay 1000
    killThread server
    observed <- I.readIORef caught
    -- Compared via 'show' because the stored value is a 'SomeException'.
    show observed @?= show (Just expected)
-- | One complete HTTP/1.1 GET request for @/@ with no body.
singleGet :: ByteString
singleGet =
    "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"

-- | One complete HTTP/1.1 POST to @/hello@ whose 5-byte body is @Hello@.
singlePostHello :: ByteString
singlePostHello =
    "POST /hello HTTP/1.1\r\nHost: localhost\r\nContent-length: 5\r\n\r\nHello"
-- | Test-suite entry point: request counting with and without pipelining,
-- bodies that arrive slowly, abrupt connection termination, and chunked
-- transfer-encoded bodies.
main :: IO ()
main = hspecX $ do
    describe "non-pipelining" $ do
        it "no body, read" $ runTest 5 readBody $ replicate 5 singleGet
        it "no body, ignore" $ runTest 5 ignoreBody $ replicate 5 singleGet
        it "has body, read" $ runTest 2 readBody
            [ singlePostHello
            , singleGet
            ]
        it "has body, ignore" $ runTest 2 ignoreBody
            [ singlePostHello
            , singleGet
            ]
    -- Same requests, but all sent in a single write.
    describe "pipelining" $ do
        it "no body, read" $ runTest 5 readBody [S.concat $ replicate 5 singleGet]
        it "no body, ignore" $ runTest 5 ignoreBody [S.concat $ replicate 5 singleGet]
        it "has body, read" $ runTest 2 readBody
            [S.concat [singlePostHello, singleGet]]
        it "has body, ignore" $ runTest 2 ignoreBody
            [S.concat [singlePostHello, singleGet]]
    describe "no hanging" $ do
        -- The request arrives one byte per write.
        it "has body, read" $ runTest 1 readBody $ map S.singleton $ S.unpack singlePostHello
        it "double connect" $ runTest 1 doubleConnect [singlePostHello]
    describe "connection termination" $ do
        it "ConnectionClosedByPeer" $ runTerminateTest ConnectionClosedByPeer "GET / HTTP/1.1\r\ncontent-length: 10\r\n\r\nhello"
        it "IncompleteHeaders" $ runTerminateTest IncompleteHeaders "GET / HTTP/1.1\r\ncontent-length: 10\r\n"
    describe "chunked bodies" $ do
        it "works" $ runChunkedTest [twoChunkedPosts] twoChunkedBodies
        it "lots of chunks" $ runChunkedTest manySmallChunks
            (replicate 2 (S.concat $ replicate 50 "12345"))
        -- Same input as "works", but written and flushed one byte at a time.
        it "in chunks" $ runChunkedTest
            (map S.singleton $ S.unpack twoChunkedPosts)
            twoChunkedBodies
  where
    -- Two chunked POSTs back to back.  The first body arrives as chunks of
    -- 0xc and 0x3 bytes, the second as one 0xb-byte chunk.
    twoChunkedPosts :: ByteString
    twoChunkedPosts = S.concat
        [ "POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n"
        , "c\r\nHello World\n\r\n3\r\nBye\r\n0\r\n"
        , "POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n"
        , "b\r\nHello World\r\n0\r\n"
        ]

    -- Bodies the server should see for 'twoChunkedPosts', in order.
    twoChunkedBodies :: [ByteString]
    twoChunkedBodies =
        [ "Hello World\nBye"
        , "Hello World"
        ]

    -- Two POSTs, each carrying fifty 5-byte chunks; every element is written
    -- and flushed separately.
    manySmallChunks :: [ByteString]
    manySmallChunks = concat $ replicate 2 $
        ["POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n"]
        ++ replicate 50 "5\r\n12345\r\n"
        ++ ["0\r\n"]

-- | Start a server on a fresh port whose application records every request
-- body (its chunks concatenated) in arrival order.  Write each element of
-- @writes@ to one client connection, flushing after each, then close it,
-- stop the server, and compare the recorded bodies with @expected@.
runChunkedTest :: [ByteString] -- ^ raw request data, written and flushed piece by piece
               -> [ByteString] -- ^ expected request bodies, in order
               -> IO ()
runChunkedTest writes expected = do
    -- Difference list so bodies can be appended in arrival order.
    ifront <- I.newIORef id
    port <- getPort
    tid <- forkIO $ run port $ \req -> do
        bss <- requestBody req $$ Data.Conduit.List.consume
        liftIO $ I.atomicModifyIORef ifront $ \front -> (front . (S.concat bss:), ())
        return $ responseLBS status200 [] ""
    threadDelay 1000
    handle <- connectTo "127.0.0.1" $ PortNumber $ fromIntegral port
    mapM_ (\bs -> hPutStr handle bs >> hFlush handle) writes
    hClose handle
    threadDelay 1000
    killThread tid
    front <- I.readIORef ifront
    front [] @?= expected
|