File: TypesTests.hs

package info (click to toggle)
haskell-futhark 0.25.32-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 18,236 kB
  • sloc: haskell: 100,484; ansic: 12,100; python: 3,440; yacc: 785; sh: 561; javascript: 558; lisp: 399; makefile: 277
file content (185 lines) | stat: -rw-r--r-- 6,437 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
module Language.Futhark.TypeChecker.TypesTests (tests) where

import Data.Bifunctor
import Data.List (isInfixOf)
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.FreshNames
import Futhark.Util.Pretty (docText, prettyTextOneLine)
import Language.Futhark
import Language.Futhark.Semantic
import Language.Futhark.SyntaxTests ()
import Language.Futhark.TypeChecker (initialEnv)
import Language.Futhark.TypeChecker.Monad
import Language.Futhark.TypeChecker.Names (resolveTypeExp)
import Language.Futhark.TypeChecker.Terms
import Language.Futhark.TypeChecker.Types
import Test.Tasty
import Test.Tasty.HUnit

-- | Build a test case that resolves and type-checks the type expression
-- @te@ in a hand-crafted environment and compares the outcome with
-- @expected@:
--
-- * @Right (svars, t)@ — checking must succeed and yield exactly these
--   size variables and this type.
--
-- * @Left msg@ — checking must fail with an error whose rendered text
--   contains @msg@ as a substring.
--
-- Every failure message reports both the expected and the actual outcome.
evalTest :: TypeExp (ExpBase NoInfo Name) Name -> Either String ([VName], ResRetType) -> TestTree
evalTest te expected =
  testCase (prettyString te) $
    case (fmap (extract . fst) (run (checkTypeExp checkSizeExp =<< resolveTypeExp te)), expected) of
      (Left got_e, Left expected_e) ->
        let got_e_s = renderError got_e
         in (expected_e `isInfixOf` got_e_s)
              @? ("Expected error containing: " <> expected_e <> "\nGot: " <> got_e_s)
      (Left got_e, Right expected_t) ->
        assertFailure $
          "Failed: " <> renderError got_e <> "\nExpected: " <> show expected_t
      (Right actual_t, Right expected_t) ->
        actual_t @?= expected_t
      (Right actual_t, Left expected_e) ->
        assertFailure $
          "Expected error containing: " <> expected_e <> "\nGot: " <> show actual_t
  where
    -- Keep only the size variables and the resulting type from the
    -- checker's result tuple.
    extract (_, svars, t, _) = (svars, t)
    -- Render a type error as plain text, used both for substring
    -- matching and for failure messages.
    renderError err = T.unpack $ docText $ prettyTypeError err
    run = snd . runTypeM env mempty (mkInitialImport "") (newNameSource 100)
    -- We hack up an environment with some predefined type
    -- abbreviations for testing.  This is all pretty sensitive to the
    -- specific unique names, so we have to be careful!
    env =
      initialEnv
        { envTypeTable =
            M.fromList
              [ ( "square_1000",
                  TypeAbbr
                    Unlifted
                    [TypeParamDim "n_1001" mempty]
                    "[n_1001][n_1001]i32"
                ),
                ( "fun_1100",
                  TypeAbbr
                    Lifted
                    [ TypeParamType Lifted "a_1101" mempty,
                      TypeParamType Lifted "b_1102" mempty
                    ]
                    "a_1101 -> b_1102"
                ),
                ( "pair_1200",
                  TypeAbbr
                    SizeLifted
                    []
                    "?[n_1201][m_1202].([n_1201]i64, [m_1202]i64)"
                )
              ]
              <> envTypeTable initialEnv,
          envNameMap =
            M.fromList
              [ ((Type, "square"), "square_1000"),
                ((Type, "fun"), "fun_1100"),
                ((Type, "pair"), "pair_1200")
              ]
              <> envNameMap initialEnv
        }

-- | Elaboration tests for type expressions: inputs that must elaborate to
-- a given type, and inputs that must be rejected with a given error.
evalTests :: TestTree
evalTests =
  testGroup
    "Type expression elaboration"
    [ testGroup "Positive tests" [evalTest te (Right want) | (te, want) <- positive],
      testGroup "Negative tests" [evalTest te (Left msg) | (te, msg) <- negative]
    ]
  where
    -- Each entry pairs a type expression with the expected size
    -- variables and elaborated type.
    positive =
      [ ("[]i32", ([], "?[d_100].[d_100]i32")),
        ("[][]i32", ([], "?[d_100][d_101].[d_100][d_101]i32")),
        ("bool -> []i32", ([], "bool -> ?[d_100].[d_100]i32")),
        ("bool -> []f32 -> []i32", (["d_100"], "bool -> [d_100]f32 -> ?[d_101].[d_101]i32")),
        ("([]i32,[]i32)", ([], "?[d_100][d_101].([d_100]i32, [d_101]i32)")),
        ("{a:[]i32,b:[]i32}", ([], "?[d_100][d_101].{a:[d_100]i32, b:[d_101]i32}")),
        ("?[n].[n][n]bool", ([], "?[n_100].[n_100][n_100]bool")),
        ( "([]i32 -> []i32) -> bool -> []i32",
          (["d_100"], "([d_100]i32 -> ?[d_101].[d_101]i32) -> bool -> ?[d_102].[d_102]i32")
        ),
        ( "((k: i64) -> [k]i32 -> [k]i32) -> []i32 -> bool",
          (["d_101"], "((k_100: i64) -> [k_100]i32 -> [k_100]i32) -> [d_101]i32 -> bool")
        ),
        ("square [10]", ([], "[10][10]i32")),
        ("square []", ([], "?[d_100].[d_100][d_100]i32")),
        ("bool -> square []", ([], "bool -> ?[d_100].[d_100][d_100]i32")),
        ("(k: i64) -> square [k]", ([], "(k_100: i64) -> [k_100][k_100]i32")),
        ("fun i32 bool", ([], "i32 -> bool")),
        ("fun ([]i32) bool", ([], "?[d_100].[d_100]i32 -> bool")),
        ("fun bool ([]i32)", ([], "?[d_100].bool -> [d_100]i32")),
        ("bool -> fun ([]i32) bool", ([], "bool -> ?[d_100].[d_100]i32 -> bool")),
        ("bool -> fun bool ([]i32)", ([], "bool -> ?[d_100].bool -> [d_100]i32")),
        ("pair", ([], "?[n_100][m_101].([n_100]i64, [m_101]i64)")),
        ( "(pair,pair)",
          ([], "?[n_100][m_101][n_102][m_103].(([n_100]i64, [m_101]i64), ([n_102]i64, [m_103]i64))")
        )
      ]
    -- Each entry pairs a type expression with a substring that the
    -- resulting error message must contain.
    negative =
      [ ("?[n].bool", "Existential size \"n\""),
        ("?[n].bool -> [n]bool", "Existential size \"n\""),
        ("?[n].[n]bool -> [n]bool", "Existential size \"n\""),
        ("?[n].[n]bool -> bool", "Existential size \"n\"")
      ]

-- | Build a test case asserting that applying the substitution map @m@ to
-- the type @t@ produces @expected@.  The test is named after the
-- pretty-printed map followed by the single-line rendering of @t@.
substTest :: M.Map VName (Subst StructRetType) -> StructRetType -> StructRetType -> TestTree
substTest m t expected =
  testCase testName $ applySubst lookupVar t @?= expected
  where
    lookupVar v = M.lookup v m
    testName = T.unpack (prettyText substDesc) <> ": " <> T.unpack (prettyTextOneLine t)
    -- The map's entries, keyed by each VName's 'toName' for display.
    substDesc = [(toName v, s) | (v, s) <- M.toList m]

-- Some expected types below depend on how size names are renumbered
-- internally (e.g. the n_2 in the last case, which does not occur in the
-- input), and that numbering can be arbitrary, so these tests may be
-- fragile.
substTests :: TestTree
substTests =
  testGroup "Type substitution" $
    map (uncurry (substTest toI64)) i64Cases
      <> map (uncurry (substTest toBoolArray)) boolArrayCases
  where
    -- t_0 := i64
    toI64 = M.fromList [("t_0", Subst [] "i64")]
    -- t_0 := an existentially sized bool array
    toBoolArray = M.fromList [("t_0", Subst [] "?[n_1].[n_1]bool")]
    -- (input type, expected result) pairs for each substitution map.
    i64Cases =
      [ ("t_0", "i64"),
        ("[1]t_0", "[1]i64"),
        ("?[n_10].[n_10]t_0", "?[n_10].[n_10]i64")
      ]
    boolArrayCases =
      [ ("t_0", "?[n_1].[n_1]bool"),
        ("f32 -> t_0", "f32 -> ?[n_1].[n_1]bool"),
        ("f32 -> f64 -> t_0", "f32 -> f64 -> ?[n_1].[n_1]bool"),
        ("f32 -> t_0 -> bool", "?[n_1].f32 -> [n_1]bool -> bool"),
        ("f32 -> t_0 -> t_0", "?[n_1].f32 -> [n_1]bool -> ?[n_2].[n_2]bool")
      ]

-- | All tests in this module: type-expression elaboration, then type
-- substitution.
tests :: TestTree
tests =
  testGroup "Basic type operations" subtrees
  where
    subtrees = [evalTests, substTests]