File: PMComp.hs

package info (click to toggle)
kaya 0.4.2-4
  • links: PTS
  • area: main
  • in suites: lenny
  • size: 4,448 kB
  • ctags: 1,694
  • sloc: cpp: 9,536; haskell: 7,461; sh: 3,013; yacc: 910; makefile: 816; perl: 90
file content (419 lines) | stat: -rw-r--r-- 17,009 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
{-
    Kaya - My favourite toy language.
    Copyright (C) 2004-2007 Edwin Brady

    This file is distributed under the terms of the GNU General
    Public Licence. See COPYING for licence.
-}

-- Pattern matching compiler, generating simple case trees from
-- case on complex expressions. Based on Wadler's match compiler from
-- "The Implementation of Functional Programming Languages", 
-- Simon Peyton Jones, 1987

module PMComp where

import Language

import Control.Monad.State
import Debug.Trace
import List(group,sort,sortBy,transpose,nub)

-- Entry point of the match compiler: compile the alternatives 'ms' for the
-- single scrutinee 'v' into a simple case tree. The fresh variable counter
-- starts at 'name'; the result pairs the compiled term with the counter's
-- final value.
mkCase :: String -> Int -> Context -> Name -> -- for finding constructor names
          Int -> -- first variable to introduce (we might need this several times per function, so names we introduce need to be unique!)
          Raw -> [MatchAlt] -> (Raw, Int)
mkCase f l ctxt mod name v ms = (fst compiled, nextVar (snd compiled))
  where
    compiled = runState (caseComp f l ctxt mod [v] (withCatchAll ms)) (CS name)

    -- What the added default alternative does: throw 'missingCase'.
    missing = RThrow f l (RApply f l (RQVar f l missingCase) [])

    -- If some alternative already catches everything the list is left as
    -- it is; otherwise a throwing default is appended at the end.
    withCatchAll alts
        | any catchesAll alts = alts
        | otherwise = alts ++ [MAlt f l [RUnderscore f l] missing]

    -- An alternative is a catch all when its only pattern is an underscore,
    -- a qualified name, or an unqualified variable name.
    catchesAll (MAlt _ _ [RUnderscore _ _] _) = True
    catchesAll (MAlt _ _ [RQVar _ _ _] _) = True
    catchesAll (MAlt _ _ [RVar _ _ x] _) = isVarName mod ctxt x
    catchesAll _ = False

-- Reject a non-linear alternative, i.e. one whose patterns mention the
-- same variable name more than once. Succeeds with () otherwise; the
-- failure message carries the file and line stored in the alternative.

checkLinear :: Monad m => Name -> Context -> MatchAlt -> m ()
checkLinear mod ctxt (MAlt f l ps _) =
    -- Sorting and grouping puts equal names together; any group with at
    -- least two members is a repeated name. Report the first one.
    case [ n | (n:_:_) <- group (sort bound) ] of
      [] -> return ()
      (dup:_) -> fail $ f ++ ":" ++ show l ++ 
                        ":Name '" ++ 
                        showuser dup ++ 
                        "' is repeated in pattern"
  where -- all variable (non-constructor) names occurring in the patterns
        bound = filter (isVarName mod ctxt) (concatMap getPvars ps)

-- Collect every name mentioned in a raw term used as a pattern: plain and
-- qualified names directly, and for an application the names of the
-- function part followed by those of each argument. Any other form of
-- term contributes nothing.

getPvars :: Raw -> [Name]
getPvars pat = case pat of
    RVar _ _ x -> [x]
    RQVar _ _ x -> [x]
    RApply _ _ fn args -> getPvars fn ++ concatMap getPvars args
    _ -> []

-- State threaded through the match compiler: the number that getVar will
-- use for the next generated pattern variable.
data CaseState = CS { nextVar :: Int }

-- Hand out a fresh generated name, MN ("pv", n) for the current counter
-- value n, and advance the counter by one.
getVar :: State CaseState Name
getVar = get >>= \(CS cur) ->
         put (CS (cur + 1)) >>
         return (MN ("pv", cur))

-- Pick one binder name for each argument pattern, left to right. A
-- qualified name, or an unqualified variable name, is reused as it is;
-- every other pattern gets a fresh generated name from getVar.
getNewVars :: Name -> Context -> [Raw] -> State CaseState [Name]
getNewVars mod ctxt pats = mapM binder pats
  where binder (RQVar _ _ n) = return n
        binder (RVar _ _ n) | isVarName mod ctxt n = return n
        binder _ = getVar

-- A scrutinee needs to be a variable for the algorithm to work. If it's
-- not, make a new variable for it. Return a triple of the term, the variable
-- to examine, and whether we made a new variable for it.

{-
getScrutinee :: Name -> Context -> Raw -> State CaseState (Raw, Name, Bool)
getScrutinee mod ctxt r@(RQVar _ _ n) = return (r, n, False)
getScrutinee mod ctxt r@(RVar _ _ n) 
    | isVarName mod ctxt n = return (r, n, False)
getScrutinee _ _ r = do v <- getVar
                        return (r, v, True)
-}

-- Compile a list of scrutinees (rs) and the alternatives matching against
-- them (ms) into a simple case tree. Every pattern variable generated
-- during compilation is declared explicitly around the resulting term.

caseComp :: String -> Int -> Context -> Name ->
            [Raw] -> [MatchAlt] -> 
            State CaseState Raw
caseComp f l ctxt n rs ms = do
     -- Remember the counter before and after compiling, so we know which
     -- generated names were introduced.
     (CS before) <- get
--     rvs <- mapM (getScrutinee n ctxt) rs
     body <- match f l ctxt n ordRs ordMs missing
--     bn <- bindNames rvs tm
     (CS after) <- get
     -- Need to add explicit declarations for the names we introduced
     return (declare after before body)
   where (ordRs, ordMs) = reorder n ctxt rs ms -- optimise ordering

         -- What to do when nothing matches.
         missing = RThrow f l (RApply f l (RQVar f l missingCase) [])

         -- Wrap the term in a declaration for each generated variable
         -- numbered from i up to (but excluding) stop; the lowest number
         -- ends up innermost.
         declare stop i body
             | i == stop = body
             | otherwise = declare stop (i+1)
                               (RDeclare f l (MN ("pv", i),False) UnknownType body)

         {- bindNames [] tm = return tm
         bindNames ((r,v, True):rvs) tm
             = bindNames rvs (RSeq f l (RAssign f l (RAName f l v) r) tm)
         bindNames ((r,v, False):rvs) tm
             = bindNames rvs tm -}

-- Order the arguments so that the one with the largest number of distinct
-- constructors is first. This is a greedy way of making the case tree branch
-- less to make the resulting code smaller. It's not optimal, but it's better
-- than a naive left to right ordering from the original source.
--
-- Takes the scrutinees and the alternatives (one pattern per scrutinee) and
-- returns both with their columns permuted in the same way. A single
-- scrutinee is returned untouched.
--
-- Named constructors are counted when the head of the pattern is a plain or
-- qualified name for which isConName holds. (Previously the RVar clause was
-- duplicated and guarded by isVarName, so constructor names were never
-- counted, variable names were, and qualified names were ignored.)

reorder :: Name -> Context -> [Raw] -> [MatchAlt] -> ([Raw], [MatchAlt])
reorder mod ctxt [x] ms = ([x], ms) -- don't waste time on it!
reorder mod ctxt raws ms 
    = let -- one list of patterns per argument position
          msp = transpose (map getPat ms)
          -- dist tells us how many distinct constructors each argument
          -- position has. So let's sort them according to that number and
          -- choose the argument order
          dist = zip [0..] (map distinct msp)
          -- sortBy is ascending and stable, so after the reverse the
          -- largest count comes first and positions with equal counts come
          -- out in reverse source order.
          argOrder = map fst $ reverse $ sortBy ordDist dist in
      -- Pick out the arguments in the right order
      (orderList argOrder raws, orderAll argOrder ms)

   where getPat (MAlt _ _ ps _) = ps

         -- number of different constructors appearing in one column
         distinct col = length (nub (foldr addCon [] col))
         ordDist (_,x) (_,y) = compare x y

         -- Record the constructor a pattern starts with, if it has one.
         addCon (RVar _ _ n) cs | isConName mod ctxt n = (CName n):cs
         addCon (RQVar _ _ n) cs | isConName mod ctxt n = (CName n):cs
         addCon (RApply _ _ r _) cs = addCon r cs
         addCon (RConst _ _ c) cs = (CConst c):cs
         addCon (RArrayInit _ _ a) cs = (CArray (length a)):cs
         addCon _ cs = cs

         -- Permute the patterns of every alternative.
         orderAll args alts
             = [ MAlt f l (orderList args ps) res | MAlt f l ps res <- alts ]

         -- Select the elements of zs at the given positions, in that order.
         orderList args zs = map (zs !!) args


-- Compile the scrutinees 'vs' against the alternatives 'ms', falling back
-- to 'err' when nothing matches. Once there are no scrutinees left and the
-- first alternative has no patterns left either, that alternative's result
-- is the answer. Otherwise the first patterns of the alternatives could be
-- all variables, all constructors, or a mixture, so the alternatives are
-- partitioned and handed to the mixture rule.

match :: String -> Int -> Context -> Name ->
         [Raw] -> [MatchAlt] -> Raw ->
         State CaseState Raw
match _ _ ctxt n [] ((MAlt f l [] res):_) err = return res
match f l ctxt n vs ms err
    = mixture f l ctxt n ordVs (partition n ctxt ordMs) err
  where -- optimise ordering of arguments
        (ordVs, ordMs) = reorder n ctxt vs ms

-- Mixture rule: work through the partitions in order, applying the
-- constructor rule or the variable rule to each as appropriate. The
-- fallthrough handed to a partition (what to do if it fails to match) is
-- whatever the remaining partitions compile to; after the last partition
-- it is 'err'.

mixture :: String -> Int -> Context -> Name -> [Raw] ->
             [Partition] -> Raw -> State CaseState Raw
mixture f l ctxt n vs parts err = foldr step (return err) parts
  where -- 'rest' compiles the later partitions; its result becomes the
        -- fallthrough of the current one.
        step (Cons ms) rest = rest >>= conRule f l ctxt n vs ms
        step (Vars ms) rest = rest >>= varRule f l ctxt n vs ms

-- A run of consecutive alternatives, as produced by 'partition', whose
-- first patterns are all of the same kind.
data Partition = Cons [MatchAlt] -- alternatives for which isCon holds
               | Vars [MatchAlt] -- alternatives for which isVar holds
  deriving Show

-- Split the alternatives into maximal consecutive runs that all start
-- with a variable pattern or all start with a constructor pattern,
-- keeping their order. Runs of variable alternatives go through
-- checkOverlap. An alternative which starts with neither is an internal
-- error.
partition :: Name -> Context -> [MatchAlt] -> [Partition]
partition mod ctxt alts = case alts of
    [] -> []
    (first:_)
        | varAlt first -> let (run, others) = span varAlt alts in
                              checkOverlap mod ctxt run (partition mod ctxt others)
        | conAlt first -> let (run, others) = span conAlt alts in
                              (Cons run):(partition mod ctxt others)
    stuck -> error (show stuck)
  where varAlt = isVar mod ctxt
        conAlt = isCon mod ctxt

-- Put the run of variable alternatives 'vs' in front of the partitions
-- that follow it ('prest'). If there are partitions following and some
-- alternative in 'vs' consists purely of variable patterns, the later
-- alternatives overlap with it, so a warning is traced for the first such
-- alternative. The partitions returned are the same either way.
-- NOTE(review): an underscore pattern is not counted as a variable pattern
-- here, although isVar does treat it as one -- confirm this is intended.
checkOverlap mod ctxt vs [] = [Vars vs]
checkOverlap mod ctxt vs prest = case filter onlyVars vs of
    [] -> (Vars vs):prest
    ((MAlt f l _ _):_) ->
        (trace (f ++ ":" ++ show l ++ ":Warning -- overlapping patterns") (Vars vs)):prest
  where onlyVars (MAlt _ _ ps _) = all varpat ps
        varpat (RQVar _ _ v) = isVarName mod ctxt v
        varpat (RVar _ _ v) = isVarName mod ctxt v
        varpat _ = False

-- Do all of the alternatives start with a variable pattern?
allVars n c alts = all (isVar n c) alts
-- Do all of the alternatives start with a constructor pattern?
allCons n c alts = all (isCon n c) alts

-- Does the alternative start with a variable pattern? That is: a plain or
-- qualified name which is a variable name (see isVarName), or an
-- underscore. An alternative with no patterns does not.
isVar :: Name -> Context -> MatchAlt -> Bool
isVar mod ctxt (MAlt _ _ (p:_) _) = case p of
    RVar _ _ v -> isVarName mod ctxt v
    RQVar _ _ v -> isVarName mod ctxt v
    RUnderscore _ _ -> True
    _ -> False
isVar _ _ _ = False

-- A name is a variable name if nothing that lookupname finds for it is
-- marked as a Constructor (so a name which is not found at all counts as
-- a variable).
isVarName mod ctxt v
     = not (any isCtor (lookupname mod v ctxt))
   where isCtor (_,(_,opts)) = Constructor `elem` opts

-- A name is a constructor name if at least one of the entries lookupname
-- finds for it is marked as a Constructor.
isConName mod ctxt v
     = or [ Constructor `elem` opts | (_,(_,opts)) <- lookupname mod v ctxt ]

-- Does the alternative start with a constructor pattern? That is: an
-- application whose function part is a plain or qualified name, a plain
-- or qualified name which is a constructor name (see isConName), a
-- constant, or an array initialiser. An alternative with no patterns
-- does not.
isCon :: Name -> Context -> MatchAlt -> Bool
isCon mod ctxt (MAlt _ _ (p:_) _) = case p of
    RApply _ _ (RVar _ _ _) _ -> True
    RApply _ _ (RQVar _ _ _) _ -> True
    RVar _ _ v -> isConName mod ctxt v
    RQVar _ _ v -> isConName mod ctxt v
    RConst _ _ _ -> True
    RArrayInit _ _ _ -> True
    _ -> False
isCon _ _ _ = False
                         
-- Variable rule: every alternative starts with a variable pattern (or an
-- underscore), so the first scrutinee always matches. Drop the first
-- pattern of each alternative, replacing the pattern's name by the
-- scrutinee in the alternative's result, and carry on matching the
-- remaining scrutinees. The rewritten alternative takes its file and line
-- from the pattern that was dropped.
varRule :: String -> Int -> Context -> Name ->
           [Raw] -> [MatchAlt] -> Raw ->
           State CaseState Raw
varRule f l c n (v:vs) alts err = match f l c n vs (map dropFirst alts) err
  where dropFirst (MAlt _ _ (p:ps) res) = case p of
            -- replace the bound name with v in res
            RVar pf pl x -> MAlt pf pl ps (rawSubst x v res)
            RQVar pf pl x -> MAlt pf pl ps (rawSubst x v res)
            -- nothing is bound, so nothing to replace
            RUnderscore pf pl -> MAlt pf pl ps res

-- The things we treat as constructors when grouping alternatives and when
-- counting the distinct constructors in an argument position.
data ConType = CName Name -- ordinary named constructor
             | CConst Const -- constant pattern
             | CArray Int -- array pattern with length
   deriving (Show, Eq)

-- All the alternatives of a partition which start with the same
-- constructor, in their original order (built by groupCons).
data Group = ConGroup ConType -- constructor
             -- arguments and rest of alternative for each instance
             -- (the rest being the alternative without its first pattern)
                   [([Raw], MatchAlt)] 
   deriving Show

-- Constructor rule: every alternative starts with a constructor pattern.
-- Group the alternatives by constructor and build a case on the first
-- scrutinee with one branch per group, using 'err' as the fallthrough.
conRule :: String -> Int -> Context -> Name ->
           [Raw] -> [MatchAlt] -> Raw ->
           State CaseState Raw
conRule f l ctxt mod scruts@(_:_) alts err =
    groupCons alts ctxt mod >>= \groups ->
    caseGroups f l ctxt mod scruts groups err

-- Build a case expression on the first scrutinee with one branch per
-- group, in the order given, followed by a default branch which yields
-- 'err'. The body of each branch is the result of matching the
-- constructor's arguments and the remaining scrutinees against the
-- alternatives of that group.
caseGroups :: String -> Int -> Context -> Name ->
              [Raw] -> [Group] -> Raw ->
              State CaseState Raw
caseGroups f l ctxt mod (v:vs) gs err 
    = do branches <- mapM branch gs
         return $ RCase f l v (branches ++ [RDefault f l err])
   where -- Name the constructor's arguments, extend the alternatives with
         -- the argument patterns, then compile the body of the branch.
         branch (ConGroup con rows)
             = do (binders, nextMs) <- argsToAlt mod ctxt rows
                  case con of
                    CName n ->
                        do body <- inner (asVars binders ++ vs) nextMs
                           return $ RAlt f l n binders body
                    -- a constant has no arguments to match on
                    CConst cval ->
                        do body <- inner vs nextMs
                           return $ RConstAlt f l cval body
                    CArray _ ->
                        do body <- inner (asVars binders ++ vs) nextMs
                           return $ RArrayAlt f l binders body

         -- the new argument names become scrutinees themselves
         asVars = map (RQVar f l)
         inner scruts alts = match f l ctxt mod scruts alts err

-- Prepare the alternatives of one constructor group for further matching.
-- Returns names for the constructor's arguments (chosen by getNewVars
-- from the argument patterns of the first alternative) together with the
-- alternatives, each with its own argument patterns put in front of its
-- remaining patterns. No alternatives gives no names and no alternatives.
argsToAlt :: Name -> Context -> 
             [([Raw], MatchAlt)] -> State CaseState ([Name], [MatchAlt])
argsToAlt mod ctxt [] = return ([],[])
argsToAlt mod ctxt rows@((firstArgs,_):_) 
    = getNewVars mod ctxt firstArgs >>= \fresh ->
      return (fresh, map prepend rows)
  where prepend (args, MAlt f l ps res) = MAlt f l (args++ps) res

-- Divide the alternatives into Groups, one per constructor that appears
-- as the first pattern. Groups are kept in order of first appearance and
-- the alternatives inside a group keep their original order; each is
-- stored as the constructor's argument patterns plus the alternative
-- without its first pattern. Fails if a first pattern is not a
-- constructor application or name, a constant, or an array pattern.
--
-- All three kinds of constructor go through the same insertion function,
-- which compares ConTypes for equality. A column mixing kinds (say a
-- constant and a named constructor) therefore just gives separate groups,
-- where the previous kind-specific functions hit a pattern match failure
-- on meeting a group of another kind.

groupCons :: Monad m => [MatchAlt] -> Context -> Name -> m [Group]
groupCons ms ctxt mod = gc [] ms
  where
   gc acc [] = return acc
   gc acc ((MAlt f l (p:ps) res):ms) = do
       acc' <- addGroup f l p ps res acc
       gc acc' ms

   addGroup f l p ps rval acc = case isPatt p ctxt mod of
       ConPatt con conargs -> return $ addg (CName con) conargs (MAlt f l ps rval) acc
       -- a constant has no arguments
       Constant cval -> return $ addg (CConst cval) [] (MAlt f l ps rval) acc
       ArrayPatt len conargs -> return $ addg (CArray len) conargs (MAlt f l ps rval) acc
       pat -> fail $ f ++ ":" ++ show l ++ ":I don't understand this pattern " -- ++ show pat

--   addgAll var res [] = [OneVar var res]
--   addgAll var res (g@(ConGroup n cs):gs)
--       = (ConGroup n (cs ++ [([], res)

   -- Append to the group for this constructor if there is one already,
   -- otherwise start a new group at the end.
   addg con conargs res [] = [ConGroup con [(conargs, res)]]
   addg con conargs res (g@(ConGroup n cs):gs) 
      | con == n = (ConGroup n (cs ++ [(conargs, res)])):gs
      | otherwise = g:(addg con conargs res gs)

-- match f l ctxt mod [] _ = fail "Can't match with no scrutinee" 
-- match f l ctxt mod (e:[]) ms = do
--      (unsimple, groups) <- groupCons ms ctxt mod
--      trace (show groups) $
--       if (unsimple == 0) then return $ mkSimpleCase f l e groups
--         else do fail "unfinished"

{-
mkSimpleCase :: String -> Int -> Raw -> [Group] -> Raw
mkSimpleCase f l e gs = RCase f l e (map mkAlt gs)
   where mkAlt (Simple f l c args ret) = RAlt f l c args ret
         mkAlt _ = error "Can't happen PMComp mkAlt"
-}

-- Classification of a raw pattern, as computed by isPatt.
data Patt = ConPatt Name [Raw] -- constructor name and its argument patterns
          | VarPatt Name -- a name which vPatt did not find in the context
          | Constant Const -- constant pattern
          | ArrayPatt Int [Raw] -- array pattern: length and element patterns
          | DefaultPatt
          | NoPatt -- anything isPatt does not recognise
   deriving (Show, Eq)

-- Hand-built raw patterns (all located at "foo", line 1) and a small set
-- of alternatives over a list built from Cons and Nil, for trying out the
-- match compiler.
pvar x = RVar "foo" 1 (UN x)
pcons x xs = RApply "foo" 1 (pvar "Cons") [x, xs]
pnil = pvar "Nil"
pint n = RConst "foo" 1 (Num n)

-- Two or more elements gives 2, exactly one gives 1, none gives 0.
testAlts = [clause (pcons (pvar "x") (pcons (pvar "y") (pvar "ys"))) 2,
            clause (pcons (pvar "x") pnil) 1,
            clause pnil 0]
  where clause pat result = MAlt "foo" 1 [pat] (pint result)

-- Classify a raw pattern. An application of a plain or qualified name is
-- a constructor pattern, a bare name is classified by vPatt, and
-- constants and array initialisers are recognised directly. Anything else
-- is NoPatt.
isPatt pat ctxt mod = case pat of
    RApply _ _ (RVar _ _ con) args -> ConPatt con args
    RApply _ _ (RQVar _ _ con) args -> ConPatt con args
    RVar _ _ v -> vPatt v ctxt mod
    RQVar _ _ v -> vPatt v ctxt mod
    RConst _ _ c -> Constant c
    RArrayInit _ _ ps -> ArrayPatt (length ps) ps
    _ -> NoPatt

-- Classify a bare name. If v is in the context it can't be a local
-- variable, so treat it as a constructor pattern with no arguments. If it
-- isn't really a constructor this fails harmlessly (and gets caught by
-- the typechecker), so there is no need for anything more fancy. A name
-- which is not in the context is a variable pattern.
vPatt v ctxt mod
    = maybe (VarPatt v) (const (ConPatt v [])) (ctxtlookup mod v ctxt Nothing [])

-- If all the groups are in simple case expression form, we've won. Each
-- group is guaranteed to begin with a different constructor, so no need
-- to worry there. Returns the number of groups still to deal with, and
-- the new groupings after simplification.

{-
simplify :: [Group] -> (Int, [Group])
simplify gs = simpl gs [] (length gs)
   where simpl [] acc i = (i, reverse acc)
         simpl ((ConGroup n [(args, MAlt f l [] res)]):gs) acc i
             | Just argnames <- mapM getName args 
                 = simpl gs ((Simple f l n argnames res):acc) (i-1)
         simpl (g:gs) acc i = simpl gs (g:acc) i

getName (RVar _ _ n) = Just n
getName (RQVar _ _ n) = Just n
getName _ = Nothing
-}