1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
module Network.TLS.IO (
sendPacket12,
sendPacket13,
recvPacket12,
recvPacket13,
--
isRecvComplete,
checkValid,
-- * Grouping multiple packets in the same flight
PacketFlightM,
runPacketFlight,
loadPacket13,
) where
import Control.Exception (finally, throwIO)
import Control.Monad.Reader
import Control.Monad.State.Strict
import qualified Data.ByteString as B
import Data.IORef
import Network.TLS.Context.Internal
import Network.TLS.Hooks
import Network.TLS.IO.Decode
import Network.TLS.IO.Encode
import Network.TLS.Imports
import Network.TLS.Parameters
import Network.TLS.Record
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
----------------------------------------------------------------
-- | Send one packet to the context
sendPacket12 :: Context -> Packet -> IO ()
sendPacket12 ctx@Context{ctxRecordLayer = recordLayer} pkt = do
-- in ver <= TLS1.0, block ciphers using CBC are using CBC residue as IV, which can be guessed
-- by an attacker. Hence, an empty packet is sent before a normal data packet, to
-- prevent guessability.
when (isNonNullAppData pkt) $ do
withEmptyPacket <- readIORef $ ctxNeedEmptyPacket ctx
when withEmptyPacket $
writePacketBytes12 ctx recordLayer (AppData B.empty)
>>= recordSendBytes recordLayer ctx
writePacketBytes12 ctx recordLayer pkt >>= recordSendBytes recordLayer ctx
where
isNonNullAppData (AppData b) = not $ B.null b
isNonNullAppData _ = False
writePacketBytes12
:: Monoid bytes
=> Context
-> RecordLayer bytes
-> Packet
-> IO bytes
writePacketBytes12 ctx recordLayer pkt = do
withLog ctx $ \logging -> loggingPacketSent logging (show pkt)
edataToSend <- encodePacket12 ctx recordLayer pkt
either throwCore return edataToSend
----------------------------------------------------------------
sendPacket13 :: Context -> Packet13 -> IO ()
sendPacket13 ctx@Context{ctxRecordLayer = recordLayer} pkt =
writePacketBytes13 ctx recordLayer pkt >>= recordSendBytes recordLayer ctx
writePacketBytes13
:: Monoid bytes
=> Context
-> RecordLayer bytes
-> Packet13
-> IO bytes
writePacketBytes13 ctx recordLayer pkt = do
withLog ctx $ \logging -> loggingPacketSent logging (show pkt)
edataToSend <- encodePacket13 ctx recordLayer pkt
either throwCore return edataToSend
----------------------------------------------------------------
-- | receive one packet from the context that contains 1 or
-- many messages (many only in case of handshake). if will returns a
-- TLSError if the packet is unexpected or malformed
recvPacket12 :: Context -> IO (Either TLSError Packet)
recvPacket12 ctx@Context{ctxRecordLayer = recordLayer} = loop 0
where
lim = limitHandshakeFragment $ sharedLimit $ ctxShared ctx
loop count
| count > lim = do
let err = Error_Packet "too many handshake fragment"
logPacket ctx $ show err
return $ Left err
loop count = do
hrr <- usingState_ ctx getTLS13HRR
erecord <- recordRecv12 recordLayer ctx
case erecord of
Left err -> do
logPacket ctx $ show err
return $ Left err
Right record
| hrr && isCCS record -> loop (count + 1)
| otherwise -> do
pktRecv <- decodePacket12 ctx record
if isEmptyHandshake pktRecv
then do
logPacket ctx "Handshake fragment"
-- When a handshake record is fragmented
-- we continue receiving in order to feed
-- stHandshakeRecordCont
loop (count + 1)
else case pktRecv of
Right (Handshake hss) -> do
pktRecv'@(Right pkt) <- ctxWithHooks ctx $ \hooks ->
Right . Handshake <$> mapM (hookRecvHandshake hooks) hss
logPacket ctx $ show pkt
return pktRecv'
Right pkt -> do
logPacket ctx $ show pkt
return pktRecv
Left err -> do
logPacket ctx $ show err
return pktRecv
isCCS :: Record a -> Bool
isCCS (Record ProtocolType_ChangeCipherSpec _ _) = True
isCCS _ = False
isEmptyHandshake :: Either TLSError Packet -> Bool
isEmptyHandshake (Right (Handshake [])) = True
isEmptyHandshake _ = False
logPacket :: Context -> String -> IO ()
logPacket ctx msg = withLog ctx $ \logging -> loggingPacketRecv logging msg
----------------------------------------------------------------
recvPacket13 :: Context -> IO (Either TLSError Packet13)
recvPacket13 ctx@Context{ctxRecordLayer = recordLayer} = loop 0
where
lim = limitHandshakeFragment $ sharedLimit $ ctxShared ctx
loop count
| count > lim =
return $ Left $ Error_Packet "too many handshake fragment"
loop count = do
erecord <- recordRecv13 recordLayer ctx
case erecord of
Left err@(Error_Protocol _ BadRecordMac) -> do
-- If the server decides to reject RTT0 data but accepts RTT1
-- data, the server should skip all records for RTT0 data.
logPacket ctx $ show err
established <- ctxEstablished ctx
case established of
EarlyDataNotAllowed n
| n > 0 -> do
setEstablished ctx $ EarlyDataNotAllowed (n - 1)
loop (count + 1)
_ -> return $ Left err
Left err -> do
logPacket ctx $ show err
return $ Left err
Right record -> do
pktRecv <- decodePacket13 ctx record
if isEmptyHandshake13 pktRecv
then do
logPacket ctx "Handshake fragment"
-- When a handshake record is fragmented we
-- continue receiving in order to feed
-- stHandshakeRecordCont13
loop (count + 1)
else do
case pktRecv of
Right (Handshake13 hss) -> do
pktRecv'@(Right pkt) <- ctxWithHooks ctx $ \hooks ->
Right . Handshake13 <$> mapM (hookRecvHandshake13 hooks) hss
logPacket ctx $ show pkt
return pktRecv'
Right pkt -> do
logPacket ctx $ show pkt
return pktRecv
Left err -> do
logPacket ctx $ show err
return pktRecv
isEmptyHandshake13 :: Either TLSError Packet13 -> Bool
isEmptyHandshake13 (Right (Handshake13 [])) = True
isEmptyHandshake13 _ = False
----------------------------------------------------------------
isRecvComplete :: Context -> IO Bool
isRecvComplete ctx = usingState_ ctx $ do
cont <- gets stHandshakeRecordCont
cont13 <- gets stHandshakeRecordCont13
return $ isNothing cont && isNothing cont13
checkValid :: Context -> IO ()
checkValid ctx = do
established <- ctxEstablished ctx
when (established == NotEstablished) $ throwIO ConnectionNotEstablished
eofed <- ctxEOF ctx
when eofed $ throwIO $ PostHandshake Error_EOF
----------------------------------------------------------------
type Builder b = [b] -> [b]
-- | State monad used to group several packets together and send them on wire as
-- single flight. When packets are loaded in the monad, they are logged
-- immediately, update the context digest and transcript, but actual sending is
-- deferred. Packets are sent all at once when the monadic computation ends
-- (normal termination but also if interrupted by an exception).
newtype PacketFlightM b a
= PacketFlightM (ReaderT (RecordLayer b, IORef (Builder b)) IO a)
deriving (Functor, Applicative, Monad, MonadFail, MonadIO)
runPacketFlight :: Context -> (forall b. Monoid b => PacketFlightM b a) -> IO a
runPacketFlight ctx@Context{ctxRecordLayer = recordLayer} (PacketFlightM f) = do
ref <- newIORef id
runReaderT f (recordLayer, ref) `finally` sendPendingFlight ctx recordLayer ref
sendPendingFlight
:: Monoid b => Context -> RecordLayer b -> IORef (Builder b) -> IO ()
sendPendingFlight ctx recordLayer ref = do
build <- readIORef ref
let bss = build []
unless (null bss) $ recordSendBytes recordLayer ctx $ mconcat bss
loadPacket13 :: Monoid b => Context -> Packet13 -> PacketFlightM b ()
loadPacket13 ctx pkt = PacketFlightM $ do
(recordLayer, ref) <- ask
liftIO $ do
bs <- writePacketBytes13 ctx recordLayer pkt
modifyIORef ref (. (bs :))
|