-- Unit tests for haskell-src-meta: round-trips declarations between
-- haskell-src-exts and Template Haskell, and checks translated type tuples.
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where
import qualified Control.Monad.Fail as Fail
import Data.Data (Data, cast, gfoldl)
import Data.Functor.Const
(Const (Const, getConst))
import qualified Language.Haskell.Exts as Exts
import qualified Language.Haskell.Exts.Extension as Extension
import qualified Language.Haskell.Exts.Parser as Parser
import Language.Haskell.Meta.Parse
import Language.Haskell.Meta.Syntax.Translate
import qualified Language.Haskell.TH as TH
import Test.HUnit (Assertion, (@?=))
import Test.Tasty
(TestTree, defaultMain, testGroup)
import Test.Tasty.HUnit (testCase)
-- | Local alias for tasty's 'TestTree'; used as the type of every test
-- defined in this module.
type Test = TestTree
-- | Entry point: hands the test cases in 'tests', wrapped in a single group
-- labelled @"unit"@, to tasty's 'defaultMain'.
main :: IO ()
main = defaultMain unitGroup
  where
    unitGroup = testGroup "unit" tests
-- | The test cases that 'main' places under the top-level group.
tests :: [Test]
tests =
  [ derivingClausesTest
  , typeAppTest
  , orderInTypeTuples
  ]
-- | Runs 'roundTripDecls' on a data declaration that derives several classes,
-- so the deriving clause has to survive the round trip for the test to pass.
derivingClausesTest :: Test
derivingClausesTest =
  testCase "Deriving clauses preserved" (roundTripDecls source)
  where
    source = "data Foo = Foo deriving (A, B, C)"
-- | Checks that the components of a promoted type tuple keep their source
-- order after translation with 'toExp'.
--
-- An expression applying @foo@ to a promoted tuple of two string literals is
-- parsed with @TypeApplications@ and @DataKinds@ enabled, translated, and every
-- 'TH.TyLit' in the result is gathered with 'collectAll'. The literals must
-- come out as @"a"@ followed by @"b"@.
orderInTypeTuples :: Test
orderInTypeTuples =
  testCase "Ensure that type tuples reconstructed in proper order" $
    -- HUnit's convention is @actual \@?= expected@: computed value on the left.
    foundLits @?= wantedLits
  where
    -- Type-level literals actually present in the translated expression.
    foundLits :: [TH.TyLit]
    foundLits = collectAll (toExp parsedExp)

    -- The literals in the order they appear in the source text.
    wantedLits :: [TH.TyLit]
    wantedLits = [TH.StrTyLit "a", TH.StrTyLit "b"]

    -- A parse failure aborts the test with the shown parse result.
    parsedExp :: Exts.Exp Exts.SrcSpanInfo
    parsedExp =
      case Exts.parseExpWithMode parseMode "foo @'(\"a\", \"b\")" of
        Exts.ParseOk ok -> ok
        failure -> error (show failure)

    parseMode :: Exts.ParseMode
    parseMode =
      Exts.defaultParseMode
        { Exts.extensions =
            map Exts.EnableExtension [Exts.TypeApplications, Exts.DataKinds]
        }
-- | Collect every sub-term of type @b@ found inside a value of type @a@.
--
-- The value is walked with 'gfoldl', visiting constructor fields left to
-- right. A node whose type is @b@ (as decided by 'cast') is added to the
-- result and is /not/ searched any further; any other node is searched
-- through its immediate children.
collectAll :: (Data a, Data b) => a -> [b]
collectAll root = walk root []
  where
    -- Builds the result as a function on lists, so the matches of successive
    -- children are combined by function composition.
    walk :: forall a b. (Data a, Data b) => a -> [b] -> [b]
    walk node =
      maybe (getConst (gfoldl visit (const (Const id)) node)) (:) (cast node)
      where
        -- Append the matches found in one child after those found so far.
        visit
          :: Data x
          => Const ([b] -> [b]) (x -> y)
          -> x
          -> Const ([b] -> [b]) y
        visit (Const pending) child = Const (pending . walk child)
-- | The default parse mode with its extension list replaced by a single
-- entry enabling @TypeApplications@.
typeAppMode :: Exts.ParseMode
typeAppMode =
  Parser.defaultParseMode
    { Parser.extensions = [typeApplications]
    }
  where
    typeApplications = Extension.EnableExtension Extension.TypeApplications
-- | Runs 'roundTripDeclsWithMode' under 'typeAppMode' on a binding whose
-- right-hand side contains a visible type application.
typeAppTest :: Test
typeAppTest =
  testCase "Type app preserved" (roundTripDeclsWithMode typeAppMode source)
  where
    source = "tenStr = show @Int 10"
-- | Assert that declarations survive a trip through Template Haskell.
--
-- The source text is parsed directly with 'parseHsDecls'. It is also parsed
-- with 'parseDecs', pretty-printed with 'TH.pprint' and parsed again with
-- 'parseHsDecls'. The two results must be equal. Any parse error ('Left') is
-- turned into a failure by 'liftEither'.
roundTripDecls :: String -> Assertion
roundTripDecls s = do
  direct <- liftEither (parseHsDecls s)
  viaTH <- liftEither (reparse =<< parseDecs s)
  viaTH @?= direct
  where
    -- Render Template Haskell declarations and parse the rendered text back.
    reparse decs = parseHsDecls (TH.pprint decs)
-- | Like a plain declaration round trip, but every parse uses the supplied
-- parse mode.
--
-- The source text is parsed directly with 'parseHsDeclsWithMode'. It is also
-- parsed with 'parseDecsWithMode', pretty-printed with 'TH.pprint' and parsed
-- again with 'parseHsDeclsWithMode'. The two results must be equal. Any parse
-- error ('Left') is turned into a failure by 'liftEither'.
roundTripDeclsWithMode :: Exts.ParseMode -> String -> Assertion
roundTripDeclsWithMode mode s = do
  direct <- liftEither (parseHsDeclsWithMode mode s)
  viaTH <- liftEither (reparse =<< parseDecsWithMode mode s)
  viaTH @?= direct
  where
    -- Render Template Haskell declarations and parse the rendered text back.
    reparse decs = parseHsDeclsWithMode mode (TH.pprint decs)
-- | Move an 'Either' into any 'Fail.MonadFail' monad: a 'Left' message is
-- passed to 'fail', a 'Right' value is returned unchanged.
liftEither :: Fail.MonadFail m => Either String a -> m a
liftEither (Left message) = fail message
liftEither (Right value) = return value
|