File: Connection.hs

package info (click to toggle)
haskell-connection 0.3.1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 124 kB
  • sloc: haskell: 333; makefile: 2
file content (427 lines) | stat: -rw-r--r-- 17,991 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- Module      : Network.Connection
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : portable
--
-- Simple connection abstraction
--
module Network.Connection
    (
    -- * Type for a connection
      Connection
    , connectionID
    , ConnectionParams(..)
    , TLSSettings(..)
    , ProxySettings(..)
    , SockSettings

    -- * Exceptions
    , LineTooLong(..)
    , HostNotResolved(..)
    , HostCannotConnect(..)

    -- * Library initialization
    , initConnectionContext
    , ConnectionContext

    -- * Connection operation
    , connectFromHandle
    , connectFromSocket
    , connectTo
    , connectionClose

    -- * Sending and receiving data
    , connectionGet
    , connectionGetExact
    , connectionGetChunk
    , connectionGetChunk'
    , connectionGetLine
    , connectionWaitForInput
    , connectionPut

    -- * TLS related operation
    , connectionSetSecure
    , connectionIsSecure
    , connectionSessionManager
    ) where

import Control.Concurrent.MVar
import Control.Monad (join)
import qualified Control.Exception as E
import qualified System.IO.Error as E (mkIOError, eofErrorType)

import qualified Network.TLS as TLS
import qualified Network.TLS.Extra as TLS

import System.X509 (getSystemCertificateStore)

import Network.Socks5 (defaultSocksConf, socksConnectWithSocket, SocksAddress(..), SocksHostAddress(..))
import Network.Socket
import qualified Network.Socket.ByteString as N

import Data.Tuple (swap)
import Data.Default.Class
import Data.Data
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as L

import System.Environment
import System.Timeout
import System.IO
import qualified Data.Map as M

import Network.Connection.Types

-- | Mutable store of TLS sessions, keyed by session identifier; this is the
-- state backing the 'TLS.SessionManager' built by 'connectionSessionManager'.
type Manager = MVar (M.Map TLS.SessionID TLS.SessionData)

-- | Raised by 'connectionGetLine' when the user-specified line length limit
-- is reached before a line terminator has been seen.
data LineTooLong = LineTooLong
    deriving (Show, Typeable)

instance E.Exception LineTooLong

-- | Raised when name resolution yields no address at all for a host.
data HostNotResolved = HostNotResolved String
    deriving (Show, Typeable)

instance E.Exception HostNotResolved

-- | Raised when every resolved address of a host was tried and none of them
-- could be connected to; carries the host name and the individual failures.
data HostCannotConnect = HostCannotConnect String [E.IOException]
    deriving (Show, Typeable)

instance E.Exception HostCannotConnect

-- | Build a TLS session manager whose sessions live in the given shared map.
--
-- Every operation goes through the 'MVar', so the same 'Manager' can be
-- shared by several connections.
connectionSessionManager :: Manager -> TLS.SessionManager
connectionSessionManager sessions = TLS.SessionManager
    { TLS.sessionResume     = lookupSession
    , TLS.sessionEstablish  = storeSession
    , TLS.sessionInvalidate = dropSession
#if MIN_VERSION_tls(1,5,0)
    , TLS.sessionResumeOnlyOnce = takeSession
#endif
    }
  where
    -- Look a session up, leaving the store untouched.
    lookupSession sid = withMVar sessions $ \m -> return (M.lookup sid m)

    -- Record (or overwrite) the data of a session.
    storeSession sid sdata =
        modifyMVar_ sessions $ \m -> return (M.insert sid sdata m)

    -- Forget a session.
    dropSession sid = modifyMVar_ sessions $ \m -> return (M.delete sid m)

#if MIN_VERSION_tls(1,5,0)
    -- Look a session up and remove it from the store in the same step, so
    -- that it cannot be handed out a second time.
    takeSession sid = modifyMVar sessions $ \m ->
        pure (swap (M.updateLookupWithKey (\_ _ -> Nothing) sid m))
#endif

-- | Initialize the library state shared between connections.
--
-- This loads the system certificate store and wraps it in a
-- 'ConnectionContext'.
initConnectionContext :: IO ConnectionContext
initConnectionContext = do
    systemStore <- getSystemCertificateStore
    return (ConnectionContext systemStore)

-- | Compute the final TLS 'TLS.ClientParams' for a destination.
--
-- * For 'TLSSettingsSimple', start from the TLS library's default client
--   parameters for the destination, select the default cipher suites, use the
--   certificate store held by the context, and, when certificate validation
--   is disabled in the settings, install a validation cache that accepts
--   everything.
--
-- * For 'TLSSettings', keep the user-supplied parameters and only overwrite
--   the server identification with the destination host and port.
makeTLSParams :: ConnectionContext -> ConnectionID -> TLSSettings -> TLS.ClientParams
makeTLSParams cg cid settings =
    case settings of
        TLSSettingsSimple {} ->
            simpleBase
                { TLS.clientSupported = supported
                , TLS.clientShared    = shared
                }
        TLSSettings custom ->
            custom { TLS.clientServerIdentification = (host, portBytes) }
  where
    (host, port) = cid
    -- The port is passed to the TLS library in its decimal textual form.
    portBytes    = BC.pack (show port)

    simpleBase = TLS.defaultParamsClient host portBytes
    supported  = def { TLS.supportedCiphers = TLS.ciphersuite_default }
    -- NOTE: no shared session manager is installed here.
    shared     = def
        { TLS.sharedCAStore         = globalCertificateStore cg
        , TLS.sharedValidationCache = certCache
        }

    -- Only consulted for 'TLSSettingsSimple'.
    certCache
        | settingDisableCertificateValidation settings =
            TLS.ValidationCache (\_ _ _ -> return TLS.ValidationCachePass)
                                (\_ _ _ -> return ())
        | otherwise = def

-- | Run an action on the backend currently installed in the connection.
--
-- The backend is only read, not taken, so the 'MVar' is not held while the
-- action runs.
withBackend :: (ConnectionBackend -> IO a) -> Connection -> IO a
withBackend f conn = do
    backend <- readMVar (connectionBackend conn)
    f backend

-- | Wrap a backend into a fresh 'Connection' that starts with an open, empty
-- receive buffer.
connectionNew :: ConnectionID -> ConnectionBackend -> IO Connection
connectionNew cid backend = do
    backendVar <- newMVar backend
    bufferVar  <- newMVar (Just B.empty)
    return (Connection backendVar bufferVar cid)

-- | Create a connection object on top of an already established 'Handle'.
--
-- If TLS settings are present in the parameters, the TLS handshake with the
-- server is performed before returning.  The SOCKS settings have no effect
-- here, as the handle is already established.
connectFromHandle :: ConnectionContext
                  -> Handle
                  -> ConnectionParams
                  -> IO Connection
connectFromHandle cg h p =
    case connectionUseSecure p of
        Nothing          -> connectionNew cid (ConnectionStream h)
        Just tlsSettings -> do
            ctx <- tlsEstablish h (makeTLSParams cg cid tlsSettings)
            connectionNew cid (ConnectionTLS ctx)
  where
    cid = (connectionHostname p, connectionPort p)

-- | Create a connection object on top of an already established 'Socket'.
--
-- If TLS settings are present in the parameters, the TLS handshake with the
-- server is performed before returning.  The SOCKS settings have no effect
-- here, as the socket is already established.
connectFromSocket :: ConnectionContext
                  -> Socket
                  -> ConnectionParams
                  -> IO Connection
connectFromSocket cg sock p =
    case connectionUseSecure p of
        Nothing          -> connectionNew cid (ConnectionSocket sock)
        Just tlsSettings -> do
            ctx <- tlsEstablish sock (makeTLSParams cg cid tlsSettings)
            connectionNew cid (ConnectionTLS ctx)
  where
    cid = (connectionHostname p, connectionPort p)

-- | Connect to a destination using the given parameters.
--
-- Depending on the proxy settings of the parameters, the socket is connected
-- directly to the destination, to the host and port of an 'OtherProxy', or
-- to the destination through a SOCKS server.  The resulting socket is then
-- handed to 'connectFromSocket', which performs the TLS handshake when TLS
-- settings are present.
--
-- Throws 'HostNotResolved' when name resolution returns no address, and
-- 'HostCannotConnect' when none of the resolved addresses could be connected
-- to.  Any socket opened here is closed again if a later step fails.
connectTo :: ConnectionContext -- ^ The global context of this connection.
          -> ConnectionParams  -- ^ The parameters for this connection (where to connect, and such).
          -> IO Connection     -- ^ The new established connection on success.
connectTo cg cParams = do
    let conFct = doConnect (connectionUseSocks cParams)
                           (connectionHostname cParams)
                           (connectionPort cParams)
    -- Close the freshly connected socket if building the 'Connection'
    -- (e.g. the TLS handshake) fails.
    E.bracketOnError conFct (close . fst) $ \(h, _) ->
        connectFromSocket cg h cParams
  where
    -- Connect to the SOCKS server, then ask it to reach the destination.
    --
    -- The negotiation runs under its own 'E.bracketOnError': this action is
    -- used as the /acquire/ step of the bracket in 'connectTo', whose release
    -- handler does not run when the acquire step itself throws, so without
    -- this the socket to the SOCKS server would leak on negotiation failure.
    sockConnect sockHost sockPort h p =
        E.bracketOnError (resolve' sockHost sockPort) (close . fst) $ \(sockServ, servAddr) -> do
            let sockConf = defaultSocksConf servAddr
                destAddr = SocksAddress (SocksAddrDomainName $ BC.pack h) p
            (dest, _) <- socksConnectWithSocket sockServ sockConf destAddr
            case dest of
                SocksAddrIPV4 h4 -> return (sockServ, SockAddrInet p h4)
                SocksAddrIPV6 h6 -> return (sockServ, SockAddrInet6 p 0 h6 0)
                SocksAddrDomainName _ -> error "internal error: socks connect return a resolved address as domain name"

    -- Pick the way to obtain a connected socket from the proxy settings.
    doConnect proxy h p =
        case proxy of
            Nothing                 -> resolve' h p
            Just (OtherProxy proxyHost proxyPort) -> resolve' proxyHost proxyPort
            Just (SockSettingsSimple sockHost sockPort) ->
                sockConnect sockHost sockPort h p
            Just (SockSettingsEnvironment envName) -> do
                -- If the environment variable cannot be read, or its content
                -- cannot be parsed, we connect directly.
                let name = maybe "SOCKS_SERVER" id envName
                evar <- E.try (getEnv name)
                case evar of
                    Left (_ :: E.IOException) -> resolve' h p
                    Right var                 ->
                        case parseSocks var of
                            Nothing                   -> resolve' h p
                            Just (sockHost, sockPort) -> sockConnect sockHost sockPort h p

    -- Try to parse "host:port" or "host".
    -- If the port is omitted then the default SOCKS port (1080) is assumed.
    parseSocks :: String -> Maybe (String, PortNumber)
    parseSocks s =
        case break (== ':') s of
            (sHost, "")        -> Just (sHost, 1080)
            (sHost, ':':portS) ->
                case reads portS of
                    [(sPort,"")] -> Just (sHost, sPort)
                    _            -> Nothing
            _                  -> Nothing

    -- Resolve the host/port into addresses (zero to many of them), then try
    -- to connect to each address in turn, returning the first socket that
    -- connects successfully.
    resolve' :: String -> PortNumber -> IO (Socket, SockAddr)
    resolve' host port = do
        let hints = defaultHints { addrSocketType = Stream }
        addrs <- getAddrInfo (Just hints) (Just host) (Just $ show port)
        firstSuccessful $ map tryToConnect addrs
      where
        -- Open a socket for one address; close it again if 'connect' fails.
        tryToConnect addr =
            E.bracketOnError
                (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
                (close)
                (\sock -> connect sock (addrAddress addr) >> return (sock, addrAddress addr))

        -- Run the actions in order until one succeeds, collecting the
        -- 'E.IOException' of every failed attempt (most recent first).
        firstSuccessful = go []
          where
            go :: [E.IOException] -> [IO a] -> IO a
            -- Nothing was attempted: resolution produced no address.
            go []      [] = E.throwIO $ HostNotResolved host
            -- Every attempt failed.
            go l@(_:_) [] = E.throwIO $ HostCannotConnect host l
            go acc     (act:followingActs) = do
                er <- E.try act
                case er of
                    Left err -> go (err:acc) followingActs
                    Right r  -> return r

-- | Send a block of data over the connection.
--
-- On a handle-backed connection the handle is flushed after the write.
connectionPut :: Connection -> ByteString -> IO ()
connectionPut connection content = withBackend send connection
  where
    send backend =
        case backend of
            ConnectionStream h -> do
                B.hPut h content
                hFlush h
            ConnectionSocket s -> N.sendAll s content
            ConnectionTLS ctx  -> TLS.sendData ctx (L.fromChunks [content])

-- | Get an exact count of bytes from a connection.
--
-- The size argument is the exact amount that must be returned to the user.
-- The call will wait until all data is available.  Hence, it behaves like
-- 'B.hGet'.
--
-- On end of input, 'connectionGetExact' will throw an 'E.isEOFError'
-- exception.  A negative size is rejected by 'connectionGet'.
connectionGetExact :: Connection -> Int -> IO ByteString
connectionGetExact conn x = loop [] 0
  where
    -- The received pieces are kept in reverse order and concatenated once at
    -- the end, so that data already received is not copied again for every
    -- additional piece.
    loop chunks !got
      | got == x  = return $! B.concat (reverse chunks)
      | otherwise = do
          next <- connectionGet conn (x - got)
          loop (next : chunks) (got + B.length next)

-- | Get some bytes from a connection.
--
-- The size argument is only the maximum that may be returned to the user.
-- The call returns as soon as there is data, even if there is less than
-- requested.  Hence, it behaves like 'B.hGetSome'.
--
-- A size of zero returns an empty 'ByteString' without touching the
-- connection; a negative size fails.
--
-- On end of input, 'connectionGet' returns an empty 'ByteString', but
-- subsequent calls will throw an 'E.isEOFError' exception.
connectionGet :: Connection -> Int -> IO ByteString
connectionGet conn size =
    case compare size 0 of
        LT -> fail "Network.Connection.connectionGet: size < 0"
        EQ -> return B.empty
        GT -> connectionGetChunkBase "connectionGet" conn (B.splitAt size)

-- | Get the next block of data from the connection, whatever its size.
connectionGetChunk :: Connection -> IO ByteString
connectionGetChunk conn =
    connectionGetChunkBase "connectionGetChunk" conn takeAll
  where
    -- Hand the whole chunk to the caller and leave nothing in the buffer.
    takeAll chunk = (chunk, B.empty)

-- | Like 'connectionGetChunk', but the supplied function decides what to
-- return and which unused portion goes back to the buffer, where it will be
-- the next chunk read.
connectionGetChunk' :: Connection -> (ByteString -> (a, ByteString)) -> IO a
connectionGetChunk' conn f = connectionGetChunkBase "connectionGetChunk'" conn f

-- | Wait for input to become available on a connection.
--
-- As with 'hWaitForInput', the timeout value is given in milliseconds.  If the
-- timeout value is less than zero, then 'connectionWaitForInput' waits
-- indefinitely.
--
-- Unlike 'hWaitForInput', this function does not do any decoding, so it
-- returns true when there is /any/ available input, not just full characters.
-- The input itself is left in the connection buffer for the next read.
connectionWaitForInput :: Connection -> Int -> IO Bool
connectionWaitForInput conn timeout_ms = do
    outcome <- timeout timeoutMicros awaitChunk
    return $ case outcome of
        Nothing -> False
        Just () -> True
  where
    -- 'timeout' expects microseconds.
    timeoutMicros = timeout_ms * 1000
    -- Make sure a chunk is buffered, then put all of it back untouched.
    awaitChunk = connectionGetChunkBase "connectionWaitForInput" conn $
        \pending -> ((), pending)

-- | Common implementation of all the read operations.
--
-- The supplied function receives the buffered data, or a freshly received
-- chunk when the buffer is empty, and returns the result for the caller
-- together with the leftover that becomes the new buffer.
--
-- When the backend delivers an empty chunk, the function is still applied
-- to it, but its leftover is dropped and the buffer is marked as closed; any
-- later call then throws an EOF error tagged with the given location.
connectionGetChunkBase :: String -> Connection -> (ByteString -> (a, ByteString)) -> IO a
connectionGetChunkBase loc conn f =
    modifyMVar (connectionBuffer conn) step
  where
    -- Buffer already closed by an earlier end of input.
    step Nothing = throwEOF conn loc
    step (Just pending)
      | not (B.null pending) = consume pending
      | otherwise = do
          fresh <- withBackend fetch conn
          if B.null fresh
             then finish fresh
             else consume fresh

    -- Read whatever the backend has to offer next.
    fetch (ConnectionTLS tlsctx)   = TLS.recvData tlsctx
    fetch (ConnectionSocket sock)  = N.recv sock 1500
    fetch (ConnectionStream h)     = B.hGetSome h (16 * 1024)

    -- Keep the (forced) leftover as the new buffer.
    consume buf = case f buf of
        (a, !rest) -> return (Just rest, a)

    -- End of input: close the buffer, ignoring any leftover.
    finish buf = case f buf of
        (a, _) -> return (Nothing, a)

-- | Get the next line, using ASCII LF as the line terminator.
--
-- This throws an 'isEOFError' exception on end of input, and 'LineTooLong'
-- when the number of bytes gathered is over the limit without a line
-- terminator.
--
-- The actual line returned can be bigger than the limit specified, provided
-- that the last chunk returned by the underlying backend contains a LF.
-- In other words, the 'LineTooLong' exception is only raised when more input
-- is needed and the limit has been reached.
--
-- An end of file will be considered as a line terminator too, if the line is
-- not empty.
connectionGetLine :: Int           -- ^ Maximum number of bytes before raising a LineTooLong exception
                  -> Connection    -- ^ Connection
                  -> IO ByteString -- ^ The received line with the LF trimmed
connectionGetLine limit conn = more (throwEOF conn loc) 0 id
  where
    loc = "connectionGetLine"
    lineTooLong = E.throwIO LineTooLong

    -- Accumulate chunks using a difference list, and concatenate them
    -- when an end-of-line indicator is reached.
    --
    -- 'eofK' is what to do if the end of input is hit at this point: before
    -- any byte of the line has been read it raises the EOF error; once some
    -- data has been gathered it returns that data as the (last) line.
    more eofK !currentSz !dl =
        getChunk (\s -> let len = B.length s
                            dl' = dl . (s:)
                         in if currentSz + len > limit
                               then lineTooLong
                               else more (done dl') (currentSz + len) dl')
                 (\s -> done (dl . (s:)))
                 eofK

    done :: ([ByteString] -> [ByteString]) -> IO ByteString
    done dl = return $! B.concat $ dl []

    -- Get another chunk, and call one of the continuations
    getChunk :: (ByteString -> IO r) -- moreK: need more input
             -> (ByteString -> IO r) -- doneK: end of line (line terminator found)
             -> IO r                 -- eofK:  end of file
             -> IO r
    getChunk moreK doneK eofK =
      join $ connectionGetChunkBase loc conn $ \s ->
        if B.null s
          then (eofK, B.empty)
          else case B.break (== 10) s of
                 (a, b)
                   | B.null b  -> (moreK a, B.empty)
                   -- Drop the LF itself; the rest goes back to the buffer.
                   | otherwise -> (doneK a, B.tail b)

-- | Throw an end-of-file 'IOError' on behalf of the named function of this
-- module, reporting the connection's @host:port@ as the path of the error.
throwEOF :: Connection -> String -> IO a
throwEOF conn loc = E.throwIO eofError
  where
    (host, port) = connectionID conn
    eofError = E.mkIOError E.eofErrorType
                           ("Network.Connection." ++ loc)
                           Nothing
                           (Just (host ++ ":" ++ show port))

-- | Close a connection.
--
-- For a TLS connection a closing notification is sent first; an
-- 'E.IOException' raised while sending it is ignored, and the TLS context is
-- closed in every case.
connectionClose :: Connection -> IO ()
connectionClose conn = withBackend shutdown conn
  where
    shutdown backend =
        case backend of
            ConnectionTLS ctx     ->
                E.finally (E.catch (TLS.bye ctx) swallowIO) (TLS.contextClose ctx)
            ConnectionSocket sock -> close sock
            ConnectionStream h    -> hClose h

    swallowIO :: E.IOException -> IO ()
    swallowIO _ = return ()

-- | Activate the secure layer using the parameters specified.
--
-- This is typically used to negotiate a TLS channel on an already
-- established channel, e.g. to support a STARTTLS command.  It also
-- resets the receive buffer, to prevent the application from confusing
-- data received before and after the 'connectionSetSecure' call.
--
-- If the connection is already using TLS, nothing happens.
connectionSetSecure :: ConnectionContext
                    -> Connection
                    -> TLSSettings
                    -> IO ()
connectionSetSecure cg connection params =
    -- The buffer is held for the whole upgrade, before the backend is taken.
    modifyMVar_ (connectionBuffer connection) $ \pending ->
        modifyMVar (connectionBackend connection) (upgrade pending)
  where
    tlsParams = makeTLSParams cg (connectionID connection) params

    upgrade pending backend =
        case backend of
            ConnectionTLS _    -> return (backend, pending)
            ConnectionStream h -> secured =<< tlsEstablish h tlsParams
            ConnectionSocket s -> secured =<< tlsEstablish s tlsParams

    -- Install the TLS context as the new backend and start with an empty buffer.
    secured ctx = return (ConnectionTLS ctx, Just B.empty)

-- | Tell whether the connection is currently running over TLS.
connectionIsSecure :: Connection -> IO Bool
connectionIsSecure conn = withBackend (return . usesTLS) conn
  where
    usesTLS (ConnectionTLS _)    = True
    usesTLS (ConnectionStream _) = False
    usesTLS (ConnectionSocket _) = False

-- | Create a TLS context over the given backend and perform the handshake,
-- returning the context once the handshake has completed.
tlsEstablish :: TLS.HasBackend backend => backend -> TLS.ClientParams -> IO TLS.Context
tlsEstablish handle tlsParams =
    TLS.contextNew handle tlsParams >>= \ctx ->
        TLS.handshake ctx >> return ctx