1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
|
{-# LANGUAGE CPP #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Tests.Term (
Term
, serialise
, deserialise
, toRefTerm
, fromRefTerm
, eqTerm
, canonicaliseTerm
, prop_fromToRefTerm
, prop_toFromRefTerm
) where
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Text as T
import qualified Data.Text.Lazy as LT
import Data.Word
import qualified Numeric.Half as Half
import Codec.CBOR.Term
import Codec.CBOR.Read
import Codec.CBOR.Write
import Test.QuickCheck
import qualified Tests.Reference.Implementation as Ref
import Tests.Reference.Generators
( floatToWord, doubleToWord, canonicalNaN
, HalfSpecials(..), FloatSpecials(..), DoubleSpecials(..) )
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative
#endif
import Control.Exception (throw)
------------------------------------------------------------------------------
-- | Encode a 'Term' as a lazy CBOR byte string with the library encoder:
-- 'encodeTerm' run through 'toLazyByteString'.
serialise :: Term -> LBS.ByteString
serialise term = toLazyByteString (encodeTerm term)
-- | Decode exactly one 'Term' from a lazy byte string.
--
-- A decoding failure is rethrown as an exception via 'throw'; input that
-- decodes successfully but leaves unconsumed bytes behind is an 'error'.
deserialise :: LBS.ByteString -> Term
deserialise bytes =
    either throw finish (deserialiseFromBytes decodeTerm bytes)
  where
    finish (rest, term)
      | LBS.null rest = term
      | otherwise     = error "Test.deserialise: trailing data"
------------------------------------------------------------------------------
-- | Convert a library 'Term' into the reference implementation's term.
--
-- Integers ('TInt' and 'TInteger' alike) become 'Ref.TUInt' or 'Ref.TNInt'
-- when their magnitude fits in a 'Word64', and 'Ref.TBigInt' otherwise.
-- @TSimple 23@ becomes 'Ref.TUndef'. Every NaN, whatever its width, is
-- replaced by the half-precision 'canonicalNaN'.
toRefTerm :: Term -> Ref.Term
toRefTerm term = case term of
    TInt n         -> refInteger (toInteger n)
    TInteger n     -> refInteger n
    TBytes bs      -> Ref.TBytes (BS.unpack bs)
    TBytesI bs     -> Ref.TBytess (map BS.unpack (LBS.toChunks bs))
    TString st     -> Ref.TString (T.unpack st)
    TStringI st    -> Ref.TStrings (map T.unpack (LT.toChunks st))
    TList xs       -> Ref.TArray (map toRefTerm xs)
    TListI xs      -> Ref.TArrayI (map toRefTerm xs)
    TMap kvs       -> Ref.TMap (map refPair kvs)
    TMapI kvs      -> Ref.TMapI (map refPair kvs)
    TTagged tag t  -> Ref.TTagged (Ref.toUInt (fromIntegral tag)) (toRefTerm t)
    TBool b        -> if b then Ref.TTrue else Ref.TFalse
    TNull          -> Ref.TNull
    TSimple w
      | w == 23    -> Ref.TUndef
      | otherwise  -> Ref.TSimple (Ref.toSimple w)
    THalf f
      | isNaN f    -> refNaN
      | otherwise  -> Ref.TFloat16 (HalfSpecials (Half.toHalf f))
    TFloat f
      | isNaN f    -> refNaN
      | otherwise  -> Ref.TFloat32 (FloatSpecials f)
    TDouble f
      | isNaN f    -> refNaN
      | otherwise  -> Ref.TFloat64 (DoubleSpecials f)
  where
    -- The one reference term all NaNs collapse to.
    refNaN = Ref.TFloat16 canonicalNaN

    refPair (k, v) = (toRefTerm k, toRefTerm v)

    -- Pick the reference encoding for an integer of arbitrary size.
    refInteger :: Integer -> Ref.Term
    refInteger n
      | n >= 0 && n <= word64Max
      = Ref.TUInt (Ref.toUInt (fromIntegral n))
      | n < 0 && n >= negate word64Max - 1
      = Ref.TNInt (Ref.toUInt (fromIntegral (negate n - 1)))
      | otherwise
      = Ref.TBigInt n

    word64Max :: Integer
    word64Max = toInteger (maxBound :: Word64)
-- | Convert a reference-implementation term back into a library 'Term'.
--
-- 'Ref.TUInt' and 'Ref.TNInt' become 'TInt' when their magnitude is at most
-- @maxBound :: Int@ and 'TInteger' otherwise; 'Ref.TBigInt' always becomes
-- 'TInteger'. 'Ref.TUndef' is represented as @TSimple 23@.
fromRefTerm :: Ref.Term -> Term
fromRefTerm ref = case ref of
    Ref.TUInt u       -> nonNegative (Ref.fromUInt u)
    Ref.TNInt u       -> negative (Ref.fromUInt u)
    Ref.TBigInt n     -> TInteger n
    Ref.TBytes bs     -> TBytes (BS.pack bs)
    Ref.TBytess bss   -> TBytesI (LBS.fromChunks (map BS.pack bss))
    Ref.TString s     -> TString (T.pack s)
    Ref.TStrings ss   -> TStringI (LT.fromChunks (map T.pack ss))
    Ref.TArray xs     -> TList (map fromRefTerm xs)
    Ref.TArrayI xs    -> TListI (map fromRefTerm xs)
    Ref.TMap kvs      -> TMap (map termPair kvs)
    Ref.TMapI kvs     -> TMapI (map termPair kvs)
    Ref.TTagged tag t -> TTagged (Ref.fromUInt tag) (fromRefTerm t)
    Ref.TFalse        -> TBool False
    Ref.TTrue         -> TBool True
    Ref.TNull         -> TNull
    Ref.TUndef        -> TSimple 23
    Ref.TSimple w     -> TSimple (Ref.fromSimple w)
    Ref.TFloat16 h    -> THalf (Half.fromHalf (getHalfSpecials h))
    Ref.TFloat32 f    -> TFloat (getFloatSpecials f)
    Ref.TFloat64 d    -> TDouble (getDoubleSpecials d)
  where
    -- Does a magnitude read from the reference side fit in a non-negative Int?
    fitsInt m = m <= fromIntegral (maxBound :: Int)

    nonNegative m
      | fitsInt m = TInt (fromIntegral m)
      | otherwise = TInteger (fromIntegral m)

    -- CBOR negative integers store -1 - value, so undo that offset here.
    negative m
      | fitsInt m = TInt (-1 - fromIntegral m)
      | otherwise = TInteger (-1 - fromIntegral m)

    termPair (k, v) = (fromRefTerm k, fromRefTerm v)
-- | Compare terms for equality.
--
-- It does exact bit for bit equality of floats. This means we can compare
-- NaNs, and different NaNs do not compare equal. If you need equality
-- modulo different NaNs then use 'canonicaliseTerm'.
--
-- If you need equality modulo different representations of 'TInt' vs 'TInteger'
-- then use 'canonicaliseTerm'.
--
-- Lists and maps are compared element by element, in order. Two lists (or
-- two maps) with a different number of entries are never equal.
--
eqTerm :: Term -> Term -> Bool
eqTerm (TList ts)    (TList ts')     = eqTermsWith eqTerm ts ts'
eqTerm (TListI ts)   (TListI ts')    = eqTermsWith eqTerm ts ts'
eqTerm (TMap ts)     (TMap ts')      = eqTermsWith eqTermPair ts ts'
eqTerm (TMapI ts)    (TMapI ts')     = eqTermsWith eqTermPair ts ts'
eqTerm (TTagged w t) (TTagged w' t') = w == w' && eqTerm t t'
eqTerm (THalf f)     (THalf f')      = floatToWord f == floatToWord f'
eqTerm (TFloat f)    (TFloat f')     = floatToWord f == floatToWord f'
eqTerm (TDouble f)   (TDouble f')    = doubleToWord f == doubleToWord f'
eqTerm a b = a == b

-- | Pointwise list equality under the given element comparison.
--
-- Unlike @and (zipWith eq xs ys)@, which silently ignores the tail of the
-- longer list, this returns 'False' when the lists differ in length.
eqTermsWith :: (a -> a -> Bool) -> [a] -> [a] -> Bool
eqTermsWith eq = go
  where
    go (x:xs) (y:ys) = eq x y && go xs ys
    go []     []     = True
    go _      _      = False
-- | Compare two map entries with 'eqTerm': keys first, then values.
eqTermPair :: (Term, Term) -> (Term, Term) -> Bool
eqTermPair (key, val) (key', val') = and [eqTerm key key', eqTerm val val']
-- | Normalise a 'Term' so that it can be compared, using 'eqTerm', with the
-- result of a round trip through 'toRefTerm' or through encoding and
-- decoding. Both of those round trips canonicalise NaNs.
--
-- Two things are normalised, recursively through lists, maps and tags:
--
-- * every NaN ('THalf', 'TFloat' or 'TDouble') becomes 'canonicalTermNaN';
--
-- * a 'TInteger' whose value fits in an 'Int' becomes the equivalent 'TInt'.
--
canonicaliseTerm :: Term -> Term
canonicaliseTerm term = case term of
    THalf f   | isNaN f -> canonicalTermNaN
    TFloat f  | isNaN f -> canonicalTermNaN
    TDouble f | isNaN f -> canonicalTermNaN
    TInteger n
      | n >= fromIntegral (minBound :: Int)
      , n <= fromIntegral (maxBound :: Int)
      -> TInt (fromIntegral n)
    TList xs      -> TList (map canonicaliseTerm xs)
    TListI xs     -> TListI (map canonicaliseTerm xs)
    TMap kvs      -> TMap (map canonicaliseTermPair kvs)
    TMapI kvs     -> TMapI (map canonicaliseTermPair kvs)
    TTagged tag t -> TTagged tag (canonicaliseTerm t)
    -- Non-NaN floats, out-of-range integers and all atoms pass through.
    _             -> term
-- | The single NaN term that 'canonicaliseTerm' maps every NaN to: the
-- generators' 'canonicalNaN' stored in a 'THalf'.
canonicalTermNaN :: Term
canonicalTermNaN = THalf nan
  where
    nan = canonicalNaN
-- | Apply 'canonicaliseTerm' to both the key and the value of a map entry.
canonicaliseTermPair :: (Term, Term) -> (Term, Term)
canonicaliseTermPair (key, value) = (canon key, canon value)
  where
    canon = canonicaliseTerm
-- | Converting a reference term to a 'Term' and back yields the reference
-- term's canonical form (per 'Ref.canonicaliseTerm').
prop_fromToRefTerm :: Ref.Term -> Bool
prop_fromToRefTerm ref = roundTripped == Ref.canonicaliseTerm ref
  where
    roundTripped = toRefTerm (fromRefTerm ref)
-- | Converting a 'Term' to a reference term and back is, under exact
-- 'eqTerm' comparison, the same as canonicalising it with 'canonicaliseTerm'.
prop_toFromRefTerm :: Term -> Bool
prop_toFromRefTerm term = viaRef `eqTerm` expected
  where
    viaRef   = fromRefTerm (toRefTerm term)
    expected = canonicaliseTerm term
-- | Terms are generated by generating a reference term and converting it
-- with 'fromRefTerm'.
--
-- Shrink candidates must respect the same validity rules: shrunk tags are
-- dropped when 'Ref.reservedTag' flags them, and a simple value only shrinks
-- to values accepted by 'Ref.unassignedSimple' or to 23, which 'toRefTerm'
-- maps to 'Ref.TUndef'.
instance Arbitrary Term where
  arbitrary = fromRefTerm <$> arbitrary

  shrink (TInt n)       = [ TInt n'     | n' <- shrink n ]
  shrink (TInteger n)   = [ TInteger n' | n' <- shrink n ]

  shrink (TBytes ws)    = [ TBytes (BS.pack ws') | ws' <- shrink (BS.unpack ws) ]
  shrink (TBytesI wss)  = [ TBytesI (LBS.fromChunks (map BS.pack wss'))
                          | wss' <- shrink (map BS.unpack (LBS.toChunks wss)) ]

  shrink (TString cs)   = [ TString (T.pack cs') | cs' <- shrink (T.unpack cs) ]
  shrink (TStringI css) = [ TStringI (LT.fromChunks (map T.pack css'))
                          | css' <- shrink (map T.unpack (LT.toChunks css)) ]

  -- A singleton container may also shrink to its only element (for a map,
  -- to either half of its only entry).
  shrink (TList xs@[x])      = x : [ TList xs' | xs' <- shrink xs ]
  shrink (TList xs)          = [ TList xs' | xs' <- shrink xs ]
  shrink (TListI xs@[x])     = x : [ TListI xs' | xs' <- shrink xs ]
  shrink (TListI xs)         = [ TListI xs' | xs' <- shrink xs ]
  shrink (TMap xys@[(x,y)])  = x : y : [ TMap xys' | xys' <- shrink xys ]
  shrink (TMap xys)          = [ TMap xys' | xys' <- shrink xys ]
  shrink (TMapI xys@[(x,y)]) = x : y : [ TMapI xys' | xys' <- shrink xys ]
  shrink (TMapI xys)         = [ TMapI xys' | xys' <- shrink xys ]

  -- A tagged term may drop its tag entirely; a shrunk tag must not be reserved.
  shrink (TTagged w t) = t : [ TTagged w' t' | (w', t') <- shrink (w, t)
                                             , not (Ref.reservedTag (fromIntegral w')) ]

  shrink (TBool _) = []
  shrink TNull     = []

  -- The validity check must be on the candidate w', not on the original w:
  -- w is already valid, so testing it would accept every shrink of w,
  -- including values Ref.unassignedSimple rejects.
  shrink (TSimple w) = [ TSimple w' | w' <- shrink w
                                    , Ref.unassignedSimple w' || w' == 23 ]

  shrink (THalf _f) = []
  shrink (TFloat f)  = [ TFloat f'  | f' <- shrink f ]
  shrink (TDouble f) = [ TDouble f' | f' <- shrink f ]
|