File: Parsers.hs

package info (click to toggle)
haskell-dns 4.2.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 380 kB
  • sloc: haskell: 3,298; ansic: 46; makefile: 2
file content (508 lines) | stat: -rw-r--r-- 18,130 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
{-# LANGUAGE BangPatterns, LambdaCase, OverloadedStrings #-}

module Network.DNS.Decode.Parsers (
    getResponse
  , getDNSFlags
  , getHeader
  , getResourceRecord
  , getResourceRecords
  , getDomain
  , getMailbox
  ) where

import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BS
import qualified Data.CaseInsensitive as CI
import qualified Data.IP
import Data.IP (IP(..), toIPv4, toIPv6b, makeAddrRange)

import Network.DNS.Imports
import Network.DNS.StateBinary
import Network.DNS.Types.Internal

----------------------------------------------------------------

getResponse :: SGet DNSMessage
getResponse = do
    header        <- getHeader
    qdCount       <- getInt16
    anCount       <- getInt16
    nsCount       <- getInt16
    arCount       <- getInt16
    questions     <- getQueries qdCount
    answerRRs     <- getResourceRecords anCount
    authorityRRs  <- getResourceRecords nsCount
    additionalRRs <- getResourceRecords arCount
    let (optRRs, otherRRs) = partition ((OPT ==) . rrtype) additionalRRs
        hdrFlags           = flags header
        (ednsHdr, fullRC)  = getEDNS (fromRCODE (rcode hdrFlags)) optRRs
        header'            = header { flags = hdrFlags { rcode = fullRC } }
    pure $ DNSMessage header' ednsHdr questions answerRRs authorityRRs
         $ ifEDNS ednsHdr otherRRs additionalRRs

  where

    -- Turn the OPT records pulled out of the additional section, together
    -- with the 4-bit RCODE from the header, into an EDNS pseudo-header and
    -- the full RCODE:
    --   * no OPT record       -> 'NoEDNS', header RCODE unchanged;
    --   * one well-formed OPT -> its EDNS data, RCODE extended from the TTL;
    --   * anything else       -> 'InvalidEDNS' with 'BadRCODE'.
    getEDNS :: Word16 -> AdditionalRecords -> (EDNSheader, RCODE)
    getEDNS hdrRC = \case
        []   -> (NoEDNS, toRCODE hdrRC)
        [rr] | Just (edns, extRC) <- optEDNS hdrRC rr
             -> (EDNSheader edns, toRCODE extRC)
        _    -> (InvalidEDNS, BadRCODE)

    -- Extract EDNS information from an OPT RR owned by the root name.  The
    -- record's class field carries the UDP payload size.  Its TTL packs the
    -- upper eight RCODE bits (31..24), the EDNS version (23..16) and the DO
    -- flag (bit 15).
    optEDNS :: Word16 -> ResourceRecord -> Maybe (EDNS, Word16)
    optEDNS hdrRC (ResourceRecord "." OPT udpSize ttlField (RD_OPT opts)) =
        Just (EDNS version udpSize doBit opts, fromIntegral extRC)
      where
        lowBits = fromIntegral hdrRC .&. 0x0f
        -- RCODE bits 11..4 come from TTL bits 31..24, hence the shift by 20.
        extRC   = (ttlField .&. 0xff000000) `shiftR` 20 .|. lowBits
        doBit   = testBit ttlField 15
        version = fromIntegral $ (ttlField .&. 0x00ff0000) `shiftR` 16
    optEDNS _ _ = Nothing

----------------------------------------------------------------

-- | Decode the 16-bit flags word of a DNS header.
--
-- Bit layout (15 = most significant): QR(15), Opcode(14..11), AA(10),
-- TC(9), RD(8), RA(7), AD(5), CD(4), RCODE(3..0).  Fails if the opcode
-- is not recognised by 'toOPCODE'.
getDNSFlags :: SGet DNSFlags
getDNSFlags = do
    word <- get16
    opc  <- decodeOpcode word
    let bitSet = testBit word
        qr | bitSet 15 = QR_Response
           | otherwise = QR_Query
    pure $ DNSFlags qr
                    opc
                    (bitSet 10)
                    (bitSet 9)
                    (bitSet 8)
                    (bitSet 7)
                    (toRCODE (word .&. 0x0f))
                    (bitSet 5)
                    (bitSet 4)
  where
    -- The 4-bit opcode field, rejected when 'toOPCODE' does not know it.
    decodeOpcode word = do
        let code = shiftR word 11 .&. 0x0f
        case toOPCODE code of
            Just opc -> pure opc
            Nothing  -> failSGet $ "Unsupported header opcode: " ++ show code

----------------------------------------------------------------

-- | Decode the fixed part of a DNS header: the 16-bit message identifier
-- followed by the flags word.  The section counts are read separately by
-- the caller.
getHeader :: SGet DNSHeader
getHeader = do
    ident <- get16
    flgs  <- getDNSFlags
    pure (DNSHeader ident flgs)

----------------------------------------------------------------

-- | Decode exactly @count@ consecutive question-section entries.
getQueries :: Int -> SGet [Question]
getQueries count = count `replicateM` getQuery

-- | Read a 16-bit RR type code and convert it with 'toTYPE'.
getTYPE :: SGet TYPE
getTYPE = fmap toTYPE get16

-- | Decode one question-section entry: QNAME and QTYPE.  The 16-bit QCLASS
-- that follows them on the wire is consumed and discarded.
--
-- XXX: Include the class when implemented, or otherwise perhaps check the
-- implicit assumption that the class is classIN.
--
getQuery :: SGet Question
getQuery = do
    qname <- getDomain
    qtype <- getTYPE
    _     <- get16 -- QCLASS, currently ignored
    pure (Question qname qtype)

-- | Decode exactly @count@ consecutive resource records.
getResourceRecords :: Int -> SGet [ResourceRecord]
getResourceRecords count = count `replicateM` getResourceRecord

-- | Decode a single resource record.
--
-- The order is: owner name, TYPE, 16-bit CLASS, 32-bit TTL, 16-bit
-- RDLENGTH, then the RDATA.  The RDATA is decoded by 'getRData' under
-- 'fitSGet' with RDLENGTH as its size.
getResourceRecord :: SGet ResourceRecord
getResourceRecord = do
    owner   <- getDomain
    rrType  <- getTYPE
    rrClass <- get16
    rrTTL   <- get32
    rdLen   <- getInt16
    ResourceRecord owner rrType rrClass rrTTL
        <$> fitSGet rdLen (getRData rrType rdLen)

----------------------------------------------------------------

-- | Offset of the first byte /after/ the current RData.  It is computed as
-- the current position plus the number of RData bytes still to be read.
-- Decoders use it to size a trailing field that follows variable-length
-- data.
--
rdataEnd :: Int      -- ^ number of bytes left from current position
         -> SGet Int -- ^ end position
rdataEnd !len = do
    here <- getPosition
    pure (here + len)

-- | Decode the RDATA of a resource record of the given 'TYPE'.
--
-- The 'Int' argument is the RDLENGTH from the record header.  Some record
-- types end with a field that carries no length prefix of its own, for
-- example a DS digest or a DNSKEY public key.  Decoders size that field
-- from RDLENGTH, either as @len - k@ after fixed-width fields, or via
-- 'rdataEnd' when a variable-length field precedes it.  Types without a
-- dedicated decoder are returned verbatim as 'UnknownRData'.
getRData :: TYPE -> Int -> SGet RData
getRData NS _    = RD_NS    <$> getDomain
getRData MX _    = RD_MX    <$> get16 <*> getDomain
getRData CNAME _ = RD_CNAME <$> getDomain
getRData DNAME _ = RD_DNAME <$> getDomain
getRData TXT len = RD_TXT   <$> getTXT len
getRData A _     = RD_A . toIPv4 <$> getNBytes 4
getRData AAAA _  = RD_AAAA . toIPv6b <$> getNBytes 16
getRData SOA _   = RD_SOA  <$> getDomain
                           <*> getMailbox
                           <*> decodeSerial
                           <*> decodeRefresh
                           <*> decodeRetry
                           <*> decodeExpire
                           <*> decodeMinimum
  where
    decodeSerial  = get32
    decodeRefresh = get32
    decodeRetry   = get32
    decodeExpire  = get32
    decodeMinimum = get32
getRData PTR _ = RD_PTR <$> getDomain
getRData SRV _ = RD_SRV <$> decodePriority
                        <*> decodeWeight
                        <*> decodePort
                        <*> getDomain
  where
    decodePriority = get16
    decodeWeight   = get16
    decodePort     = get16
--
getRData RP _   = RD_RP <$> getMailbox
                        <*> getDomain
--
getRData OPT len   = RD_OPT <$> getOpts len
--
getRData TLSA len = RD_TLSA <$> decodeUsage
                            <*> decodeSelector
                            <*> decodeMType
                            <*> decodeADF
  where
    decodeUsage    = get8
    decodeSelector = get8
    decodeMType    = get8
    -- Association data fills the rest: RDLENGTH minus the 3 octets above.
    decodeADF      = getNByteString (len - 3)
--
-- DS and CDS share one wire layout.
getRData DS len  = getDSFields RD_DS len
getRData CDS len = getDSFields RD_CDS len
--
getRData RRSIG len = RD_RRSIG <$> decodeRRSIG
  where
    decodeRRSIG = do
        -- The signature follows a variable length zone name
        -- and occupies the rest of the RData.  Simplest to
        -- checkpoint the position at the start of the RData,
        -- and after reading the zone name, and subtract that
        -- from the RData length.
        --
        end <- rdataEnd len
        typ <- getTYPE
        alg <- get8
        cnt <- get8
        ttl <- get32
        tex <- getDnsTime
        tin <- getDnsTime
        tag <- get16
        dom <- getDomain -- XXX: Enforce no compression?
        pos <- getPosition
        val <- getNByteString $ end - pos
        return $ RDREP_RRSIG typ alg cnt ttl tex tin tag dom val
    -- Signature expiration/inception: a 32-bit wire value interpreted by
    -- 'dnsTime' relative to the time supplied by 'getAtTime'.
    getDnsTime   = do
        tnow <- getAtTime
        tdns <- get32
        return $! dnsTime tdns tnow
--
getRData NULL len = RD_NULL <$> getNByteString len
getRData NSEC len = do
    end <- rdataEnd len
    dom <- getDomain
    pos <- getPosition
    RD_NSEC dom <$> getNsecTypes (end - pos)
--
-- DNSKEY and CDNSKEY share one wire layout.
getRData DNSKEY len  = getKeyFields RD_DNSKEY len
getRData CDNSKEY len = getKeyFields RD_CDNSKEY len
--
getRData NSEC3 len = do
    dend <- rdataEnd len
    halg <- get8
    flgs <- get8
    iter <- get16
    salt <- getInt8 >>= getNByteString
    hash <- getInt8 >>= getNByteString
    tpos <- getPosition
    RD_NSEC3 halg flgs iter salt hash <$> getNsecTypes (dend - tpos)
--
getRData NSEC3PARAM _ = RD_NSEC3PARAM <$> decodeHashAlg
                                      <*> decodeFlags
                                      <*> decodeIterations
                                      <*> decodeSalt
  where
    decodeHashAlg    = get8
    decodeFlags      = get8
    decodeIterations = get16
    decodeSalt       = getInt8 >>= getNByteString
--
getRData CAA len = do
    dend <- rdataEnd len
    -- Named caaFlags so as not to shadow the DNSFlags 'flags' selector.
    caaFlags <- get8
    tag <- getInt8 >>= getNByteString
    tpos <- getPosition
    RD_CAA caaFlags (CI.mk tag) <$> getNByteString (dend - tpos)
--
getRData _  len = UnknownRData <$> getNByteString len

-- | Shared decoder for DS and CDS RData.
--
-- The layout is a 16-bit key tag, an 8-bit algorithm, an 8-bit digest type,
-- and then the digest filling the remaining @len - 4@ bytes.  The first
-- argument is the constructor to apply ('RD_DS' or 'RD_CDS').
getDSFields :: (Word16 -> Word8 -> Word8 -> ByteString -> RData)
            -> Int
            -> SGet RData
getDSFields mkRData len =
    mkRData <$> get16                 -- key tag
            <*> get8                  -- algorithm
            <*> get8                  -- digest type
            <*> getNByteString (len - 4)

-- | Shared decoder for DNSKEY and CDNSKEY RData.
--
-- The layout is 16-bit flags, an 8-bit protocol, an 8-bit algorithm, and
-- then the public key filling the remaining @len - 4@ bytes.  The first
-- argument is the constructor to apply ('RD_DNSKEY' or 'RD_CDNSKEY').
getKeyFields :: (Word16 -> Word8 -> Word8 -> ByteString -> RData)
             -> Int
             -> SGet RData
getKeyFields mkRData len =
    mkRData <$> get16                 -- flags
            <*> get8                  -- protocol
            <*> get8                  -- algorithm
            <*> getNByteString (len - 4)

----------------------------------------------------------------

-- $
--
-- >>> import Network.DNS.StateBinary
-- >>> let Right ((t,_),l) = runSGetWithLeftovers (getTXT 8) "\3foo\3barbaz"
-- >>> (t, l) == ("foobar", "baz")
-- True

-- | Decode TXT RDATA: a run of length-prefixed character-strings spanning
-- @len@ bytes.  The strings are returned concatenated into a single
-- 'ByteString', so the boundaries between them are not preserved.
-- <https://tools.ietf.org/html/rfc1035#section-3.3>
--
getTXT :: Int -> SGet ByteString
getTXT !len = do
    pieces <- sGetMany "TXT RR string" len characterString
    pure (B.concat pieces)
  where
    -- One character-string: a length octet followed by that many bytes.
    characterString = do
        n <- getInt8
        getNByteString n

-- | Decode the list of EDNS options in an OPT record's RDATA of @len@ bytes.
--
-- Each option is a 16-bit option code and a 16-bit option length, followed
-- by that many bytes of data decoded by 'getOData'.
-- <https://tools.ietf.org/html/rfc6891#section-6.1.2>
--
getOpts :: Int -> SGet [OData]
getOpts !len = sGetMany "EDNS option" len $ do
    optCode <- toOptCode <$> get16
    optLen  <- getInt16
    getOData optCode optLen

-- | Decode an NSEC/NSEC3 type bitmap occupying @len@ bytes.
--
-- The bitmap is a sequence of windows.  Each window is a window number,
-- a block length (at most 32, otherwise decoding fails) and that many
-- octets.  Bit @i@ of the block, counted from the most significant bit of
-- its first octet, marks type @window * 256 + i@ as present.
-- <https://tools.ietf.org/html/rfc4034#section-4.1>
--
getNsecTypes :: Int -> SGet [TYPE]
getNsecTypes !len = concat <$> sGetMany "NSEC type bitmap" len windowBlock
  where
    windowBlock = do
        winNum <- getInt8
        blkLen <- getInt8
        when (blkLen > 32) $
            failSGet $ "NSEC bitmap block too long: " ++ show blkLen
        octets <- getNBytes blkLen
        let base = winNum * 256
        pure [ toTYPE (fromIntegral (base + 8 * k + i))
             | (k, octet) <- zip [0 ..] octets
             , i <- [0 .. 7]
             , testBit octet (7 - i)
             ]

----------------------------------------------------------------

-- | Decode the payload of one EDNS option, given its code and length.
--
-- NSID is kept as raw bytes.  DAU, DHU and N3U are decoded with
-- 'getNoctets'.  Client Subnet (RFC 7871) is validated as described in the
-- comment below.  Any other option is returned verbatim as 'UnknownOData'.
getOData :: OptCode -> Int -> SGet OData
getOData NSID len = OD_NSID <$> getNByteString len
getOData DAU  len = OD_DAU  <$> getNoctets len
getOData DHU  len = OD_DHU  <$> getNoctets len
getOData N3U  len = OD_N3U  <$> getNoctets len
getOData ClientSubnet len = do
        family  <- get16
        srcBits <- get8
        scpBits <- get8
        addrbs  <- getNByteString (len - 4) -- 4 = 2 + 1 + 1
        --
        -- https://tools.ietf.org/html/rfc7871#section-6
        --
        -- o  ADDRESS, variable number of octets, contains either an IPv4 or
        --    IPv6 address, depending on FAMILY, which MUST be truncated to the
        --    number of bits indicated by the SOURCE PREFIX-LENGTH field,
        --    padding with 0 bits to pad to the end of the last octet needed.
        --
        -- o  A server receiving an ECS option that uses either too few or too
        --    many ADDRESS octets, or that has non-zero ADDRESS bits set beyond
        --    SOURCE PREFIX-LENGTH, SHOULD return FORMERR to reject the packet,
        --    as a signal to the software developer making the request to fix
        --    their implementation.
        --
        -- To avoid needless decoding errors when the ECS encoding
        -- requirements are violated, we construct an OD_ECSgeneric OData
        -- instead of an IP-specific OD_ClientSubnet OData; the latter is
        -- only used for valid inputs.  The OD_ECSgeneric data contains the
        -- verbatim input from the peer whenever:
        --   * the family is neither IPv4(1) nor IPv6(2);
        --   * either prefix length exceeds the address width (32 or 128
        --     bits); or
        --   * the address prefix is not correctly encoded (too long, too
        --     short, or with bits set beyond SOURCE PREFIX-LENGTH).
        --
        case BS.length addrbs == (fromIntegral srcBits + 7) `div` 8 of
            True | Just ip <- bstoip family addrbs srcBits scpBits
                -> pure $ OD_ClientSubnet srcBits scpBits ip
            _   -> pure $ OD_ECSgeneric family srcBits scpBits addrbs
  where
    -- The address with everything past the first @bits@ bits cleared.
    prefix addr bits = Data.IP.addr $ makeAddrRange addr $ fromIntegral bits
    -- Widen the received octets and extend them with an endless run of zeros.
    zeropad = (++ repeat 0). map fromIntegral. B.unpack
    -- Accept the address only when both prefix lengths fit within the
    -- address width and no bits are set past SOURCE PREFIX-LENGTH.
    -- The width checks come first, so the masking is only done with an
    -- in-range prefix length.  Without the srcBits check, an oversized
    -- SOURCE PREFIX-LENGTH (e.g. 40 for IPv4, i.e. 5 octets) could be
    -- accepted after 'take' had silently dropped the excess octets.
    checkBits fromBytes toIP srcBits scpBits bytes =
        let addr       = fromBytes bytes
            maskedAddr = prefix addr srcBits
            maxBits    = fromIntegral $ 8 * length bytes
         in if srcBits <= maxBits && scpBits <= maxBits && addr == maskedAddr
            then Just $ toIP addr
            else Nothing
    bstoip :: Word16 -> B.ByteString -> Word8 -> Word8 -> Maybe IP
    bstoip family bs srcBits scpBits = case family of
        1 -> checkBits toIPv4  IPv4 srcBits scpBits $ take 4  $ zeropad bs
        2 -> checkBits toIPv6b IPv6 srcBits scpBits $ take 16 $ zeropad bs
        _ -> Nothing
getOData opc len = UnknownOData (fromOptCode opc) <$> getNByteString len

----------------------------------------------------------------

-- | Decode a domain name at the current position.
--
-- Per RFC1035 Section 4.1.4, compression pointers MUST point back into the
-- packet.  The DNS community (from a discussion on the IETF DNSOP mailing
-- list) further reads this as meaning they never point back into the same
-- domain.  So when parsing of a domain begins, the current offset is a
-- strict upper bound on the target of any pointer met while processing it.
-- When a pointer is followed, its target becomes the new strict upper bound
-- for any later pointers.  This gives a simple loop-prevention algorithm:
-- every sequence of valid pointer values is necessarily strictly
-- decreasing.  The second argument to 'getDomain'' is this strict pointer
-- upper bound.  It is set here to the position at which parsing of the
-- domain (or, in 'getMailbox', the mailbox) starts.
--
-- Note: the separator passed to 'getDomain'' is required to be either \'.\'
-- or \'\@\', or else 'unparseLabel' needs to be modified to handle the new
-- value.
--
getDomain :: SGet Domain
getDomain = do
    start <- getPosition
    getDomain' dot start

-- | Decode a mailbox name, such as the RNAME of an SOA record.
--
-- It is parsed by the same 'getDomain'' routine as a domain name, with the
-- same pointer bound.  The first label is separated from the rest by
-- \'\@\' instead of \'.\'.
getMailbox :: SGet Mailbox
getMailbox = do
    start <- getPosition
    getDomain' atsign start

-- | Byte values of the two first-label separators accepted by 'getDomain'':
-- \'.\' for domain names and \'\@\' for mailboxes.
dot, atsign :: Word8
dot    = 0x2e -- '.' (46)
atsign = 0x40 -- '@' (64)

-- $
-- Pathological case: pointer embedded inside a label!  The pointer points
-- behind the start of the domain and is then absorbed into the initial label!
-- Though we don't IMHO have to support this, it is not manifestly illegal, and
-- does exercise the code in an interesting way.  Ugly as this is, it also
-- "works" the same in Perl's Net::DNS and reportedly in ISC's BIND.
--
-- >>> :{
-- let input = "\6\3foo\192\0\3bar\0"
--     parser = skipNBytes 1 >> getDomain' dot 1
--     Right (output, _) = runSGet parser input
--  in output == "foo.\\003foo\\192\\000.bar."
-- :}
-- True
--
-- The case below fails to point far enough back, and triggers the loop
-- prevention code-path.
--
-- >>> :{
-- let input = "\6\3foo\192\1\3bar\0"
--     parser = skipNBytes 1 >> getDomain' dot 1
--     Left (DecodeError err) = runSGet parser input
--  in err
-- :}
-- "invalid name compression pointer"

-- | Get a domain name, using sep1 as the separator between the 1st and 2nd
-- label.  Subsequent labels (and always the trailing label) are terminated
-- with a ".".
--
-- Note: the separator is required to be either \'.\' or \'\@\', or else
-- 'unparseLabel' needs to be modified to handle the new value.
--
-- Domain name compression pointers must always refer to a position that
-- precedes the start of the current domain name.  The starting offsets form a
-- strictly decreasing sequence, which prevents pointer loops.  The second
-- argument, @ptrLimit@, is the current strict upper bound on pointer targets.
--
-- Decoded names are cached by wire offset with 'push', and compression
-- pointers consult that cache with 'pop'.  Only names decoded with the
-- \'.\' separator are ever cached.  A mailbox (sep1 = \'\@\') presents its
-- first label differently.  Caching it would make a later /domain name/
-- pointer to the same offset decode with a spurious \'\@\'.
--
getDomain' :: Word8 -> Int -> SGet ByteString
getDomain' sep1 ptrLimit = do
    pos <- getPosition
    c <- getInt8
    let n = getValue c
    getdomain pos c n
  where
    -- Reprocess the same ByteString starting at the pointer
    -- target (offset).
    getPtr pos offset = do
        msg <- getInput
        let parser = skipNBytes offset >> getDomain' sep1 offset
        case runSGet parser msg of
            Left (DecodeError err) -> failSGet err
            Left err               -> fail $ show err
            Right o                -> do
                -- Cache only the presentation form decoding of domain names,
                -- mailboxes (e.g. SOA rname) are less frequently reused, and
                -- have a different presentation form, so must not share the
                -- same cache.
                when (sep1 == dot) $
                    push pos (fst o)
                return (fst o)

    getdomain pos c n
      | c == 0 = return "." -- Perhaps the root domain?
      | isPointer c = do
          d <- getInt8
          let offset = n * 256 + d
          when (offset >= ptrLimit) $
              failSGet "invalid name compression pointer"
          if sep1 /= dot
              then getPtr pos offset
              else pop offset >>= \case
                  Nothing -> getPtr pos offset
                  Just o  -> return o
      -- As for now, extended labels have no use.
      -- This may change some time in the future.
      | isExtLabel c = return ""
      | otherwise = do
          hs <- unparseLabel sep1 <$> getNByteString n
          ds <- getDomain' dot ptrLimit
          let dom = case ds of -- avoid trailing ".."
                  "." -> hs <> "."
                  _   -> hs <> B.singleton sep1 <> ds
          -- As in getPtr, only the '.'-separated form may enter the cache.
          -- Otherwise a mailbox's "user@example.com." would be served to a
          -- later domain-name pointer targeting this same offset.
          when (sep1 == dot) $
              push pos dom
          return dom
    -- Low six bits of the length/pointer octet.
    getValue c = c .&. 0x3f
    -- Top two bits 11: compression pointer.
    isPointer c = testBit c 7 && testBit c 6
    -- Top two bits 01: extended label type.
    isExtLabel c = not (testBit c 7) && testBit c 6