1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419
|
{-
Kaya - My favourite toy language.
Copyright (C) 2004-2007 Edwin Brady
This file is distributed under the terms of the GNU General
Public Licence. See COPYING for licence.
-}
-- Pattern matching compiler: translates case expressions over complex
-- (nested) patterns into trees of simple case expressions. Based on
-- Wadler's match compiler, described in "The Implementation of Functional
-- Programming Languages" (Simon Peyton Jones, 1987).
module PMComp where
import Language
import Control.Monad.State
import Debug.Trace
import List(group,sort,sortBy,transpose,nub)
-- | Compile the alternatives of a case over a single scrutinee into a
-- simple case tree.
--
-- If no alternative is a catch-all (a lone @_@, a lone qualified name, or a
-- lone unqualified non-constructor name), one is appended whose body throws
-- 'missingCase'.
--
-- Returns the compiled term together with the variable counter after
-- compilation, i.e. the next unused fresh-variable index, so that several
-- calls within one function can keep the names they introduce unique.
mkCase :: String        -- file name used in generated terms
       -> Int           -- line number used in generated terms
       -> Context       -- for finding constructor names
       -> Name          -- current module, also for finding constructor names
       -> Int           -- first variable to introduce (we might need this several times per function, so names we introduce need to be unique!)
       -> Raw           -- the scrutinee
       -> [MatchAlt]    -- the alternatives, in source order
       -> (Raw, Int)
mkCase f l ctxt mod name v ms = (tm, nextVar finalState)
  where
    (tm, finalState) = runState (caseComp f l ctxt mod [v] alts) (CS name)

    -- A catch-all already present makes the implicit one unnecessary.
    alts | any isCatchAll ms = ms
         | otherwise         = ms ++ [catchAll]

    catchAll = MAlt f l [RUnderscore f l]
                    (RThrow f l (RApply f l (RQVar f l missingCase) []))

    isCatchAll (MAlt _ _ [RUnderscore _ _] _) = True
    isCatchAll (MAlt _ _ [RQVar _ _ _] _)     = True
    isCatchAll (MAlt _ _ [RVar _ _ var] _)    = isVarName mod ctxt var
    isCatchAll _                              = False
-- | Fail if the alternative is non-linear, i.e. some variable name is bound
-- more than once across its patterns.  Only names that 'isVarName' accepts
-- are considered, so constructor names may repeat freely.  The error names
-- the alphabetically first repeated variable.
checkLinear :: Monad m => Name -> Context -> MatchAlt -> m ()
checkLinear mod ctxt (MAlt f l ps _) =
    case repeated of
      []         -> return ()
      (name : _) -> fail $ f ++ ":" ++ show l ++
                           ":Name '" ++
                           showuser name ++
                           "' is repeated in pattern"
  where
    boundNames = filter (isVarName mod ctxt) (concatMap getPvars ps)
    -- Groups of length two or more are the repeated names.
    repeated   = [ name | (name : _ : _) <- group (sort boundNames) ]
-- | Get the names used in a raw term representing a pattern: every 'RVar'
-- and 'RQVar' name, including constructor heads of applications and the
-- elements of array patterns.  Callers filter out constructor names
-- themselves (see 'checkLinear').
getPvars :: Raw -> [Name]
getPvars (RVar _ _ x)         = [x]
getPvars (RQVar _ _ x)        = [x]
getPvars (RApply _ _ fn args) = getPvars fn ++ concatMap getPvars args
-- Array patterns contain sub-patterns too; without this clause a repeated
-- variable inside an array pattern escaped the linearity check.
getPvars (RArrayInit _ _ ps)  = concatMap getPvars ps
getPvars _                    = []
-- | State threaded through case compilation: the index of the next
-- machine-generated pattern variable (@MN ("pv", i)@) that 'getVar' will
-- hand out.
data CaseState = CS
    { nextVar :: Int
    }
-- | Allocate a fresh pattern variable @MN ("pv", i)@, where @i@ is the
-- current counter, and advance the counter by one.
getVar :: State CaseState Name
getVar = do
    i <- gets nextVar
    put (CS (i + 1))
    return (MN ("pv", i))
-- | Choose a binder name for each argument pattern, left to right.
--
-- An argument that is a plain variable keeps its own name.  A plain
-- variable is an 'RVar' or 'RQVar' whose name 'isVarName' accepts.
-- Anything else gets a fresh variable from 'getVar'.  "Anything else"
-- includes nested constructor applications, constants, wildcards and
-- constructor names.
--
-- 'RQVar' is checked with 'isVarName' exactly like 'RVar'.  Otherwise a
-- qualified nullary constructor in argument position would be used as a
-- binder name.
getNewVars :: Name -> Context -> [Raw] -> State CaseState [Name]
getNewVars mod ctxt = mapM nameFor
  where
    nameFor (RVar _ _ n)  | isVarName mod ctxt n = return n
    nameFor (RQVar _ _ n) | isVarName mod ctxt n = return n
    nameFor _                                    = getVar
-- A scrutinee needs to be a variable for the algorithm to work. If it's
-- not, make a new variable for it. Return a triple of the term, the variable
-- to examine, and whether we made a new variable for it.
{-
getScrutinee :: Name -> Context -> Raw -> State CaseState (Raw, Name, Bool)
getScrutinee mod ctxt r@(RQVar _ _ n) = return (r, n, False)
getScrutinee mod ctxt r@(RVar _ _ n)
| isVarName mod ctxt n = return (r, n, False)
getScrutinee _ _ r = do v <- getVar
return (r, v, True)
-}
-- | Take a list of arguments (@rs@) and a list of possible matches for
-- those arguments (@ms@) and convert them into a simple case tree.
--
-- Every fresh variable allocated while compiling gets an explicit 'RDeclare'
-- wrapped around the tree.  These are the counter values from the value on
-- entry up to, but not including, the value on exit.  The most recently
-- allocated variable is declared outermost.  If nothing matches, the tree
-- throws 'missingCase'.
--
-- Scrutinees are used as given.  An earlier scheme (getScrutinee /
-- bindNames) bound non-variable scrutinees to fresh variables first; it is
-- currently disabled.
caseComp :: String -> Int -> Context -> Name ->
            [Raw] -> [MatchAlt] ->
            State CaseState Raw
caseComp f l ctxt n rs ms = do
    let (scrutinees, alts) = reorder n ctxt rs ms -- optimise ordering
    firstv <- gets nextVar
    tree   <- match f l ctxt n scrutinees alts err
    lastv  <- gets nextVar
    return (foldr declare tree (reverse [firstv .. lastv - 1]))
  where
    declare i body = RDeclare f l (MN ("pv", i), False) UnknownType body
    err = RThrow f l (RApply f l (RQVar f l missingCase) [])
-- | Order the arguments so that the one with the largest number of disjoint
-- cases comes first.  This greedy choice makes the case tree branch less,
-- so the resulting code is smaller.  It is not optimal, but it beats a
-- naive left-to-right ordering taken from the original source.
--
-- Only constructor-like patterns count as cases: the same shapes that
-- 'isCon' accepts.  Variables and wildcards contribute nothing.  Columns
-- with equal counts end up in reverse source order (stable sort, then
-- reverse).  Arguments and every alternative's patterns are permuted the
-- same way.
reorder :: Name -> Context -> [Raw] -> [MatchAlt] -> ([Raw], [MatchAlt])
reorder mod ctxt [x] ms = ([x], ms) -- don't waste time on it!
reorder mod ctxt raws [] = (raws, []) -- no rows to measure: keep the order
reorder mod ctxt raws ms
    = let msp = transpose (map getPat ms)
          -- how many distinct constructors each argument position has
          dist = zip [0..] (map (distinct []) msp)
          -- argument positions, most distinct constructors first
          argOrder = map fst $ reverse $ sortBy ordDist dist
      in (orderList argOrder raws, orderAll argOrder ms)
  where
    getPat (MAlt _ _ ps _) = ps
    distinct cs [] = length (nub cs)
    distinct cs (r:rest) = distinct (addCon r cs) rest
    ordDist (_, x) (_, y) = compare x y
    -- Record the constructor a pattern starts with, mirroring 'isCon'.
    -- Plain names count only if they name a constructor.  Counting
    -- variables (and skipping named constructors) would favour columns of
    -- distinct variable names, which is the opposite of the intent.
    addCon (RApply _ _ (RVar _ _ n) _) cs  = CName n : cs
    addCon (RApply _ _ (RQVar _ _ n) _) cs = CName n : cs
    addCon (RVar _ _ n) cs  | isConName mod ctxt n = CName n : cs
    addCon (RQVar _ _ n) cs | isConName mod ctxt n = CName n : cs
    addCon (RConst _ _ c) cs       = CConst c : cs
    addCon (RArrayInit _ _ a) cs   = CArray (length a) : cs
    addCon _ cs                    = cs
    orderAll args = map (\(MAlt f l ps res) -> MAlt f l (orderList args ps) res)
    orderList idxs zs = map (zs !!) idxs
-- | Compile alternatives against a list of scrutinees.  The alternatives'
-- first patterns may be all variables, all constructors, or a mixture.
--
-- With no scrutinees left, the first alternative (which has no patterns
-- left either) is the result.  Otherwise the arguments are reordered, the
-- alternatives are split into partitions, and the mixture rule combines
-- them with @err@ as the final fallthrough.
match :: String -> Int -> Context -> Name ->
         [Raw] -> [MatchAlt] -> Raw ->
         State CaseState Raw
match _ _ _ _ [] (MAlt _ _ [] res : _) _ = return res
match f l ctxt n vs ms err = mixture f l ctxt n vs' (partition n ctxt ms') err
  where
    -- optimise ordering of arguments
    (vs', ms') = reorder n ctxt vs ms
-- | Mixture rule: apply the variable or constructor rule to each partition.
--
-- Partitions are compiled from the last one backwards.  Each partition's
-- fallthrough (what happens when it fails to match) is the compiled form of
-- the partitions after it; the last partition falls through to @err@.
mixture :: String -> Int -> Context -> Name -> [Raw] ->
           [Partition] -> Raw -> State CaseState Raw
mixture f l ctxt n vs parts err = go parts
  where
    go [] = return err
    go (p : ps) = do
        fallthrough <- go ps
        case p of
          Cons ms -> conRule f l ctxt n vs ms fallthrough
          Vars ms -> varRule f l ctxt n vs ms fallthrough
-- | A maximal run of consecutive alternatives whose first patterns are of
-- the same kind, as produced by 'partition'.
data Partition
    = Cons [MatchAlt]  -- ^ first patterns accepted by 'isCon'
    | Vars [MatchAlt]  -- ^ first patterns accepted by 'isVar'
    deriving Show
-- | Split alternatives into maximal runs of variable-headed ('Vars') and
-- constructor-headed ('Cons') alternatives, preserving order.  Variable runs
-- go through 'checkOverlap'.  An alternative that is neither kind is a
-- fatal error.
partition :: Name -> Context -> [MatchAlt] -> [Partition]
partition _ _ [] = []
partition mod ctxt alts@(first:_)
    | isVar mod ctxt first =
        let (vars, others) = span (isVar mod ctxt) alts
        in checkOverlap mod ctxt vars (partition mod ctxt others)
    | isCon mod ctxt first =
        let (cons, others) = span (isCon mod ctxt) alts
        in Cons cons : partition mod ctxt others
    | otherwise = error (show alts)
-- | Wrap a run of variable-headed alternatives @vs@ as a 'Vars' partition in
-- front of the partitions that follow it (@prest@).
--
-- Suppose @prest@ is non-empty and some alternative in @vs@ consists purely
-- of variable or wildcard patterns.  That alternative matches
-- unconditionally, so what follows it overlaps, and a warning is emitted
-- via 'trace'.
checkOverlap mod ctxt vs [] = [Vars vs]
checkOverlap mod ctxt vs prest = co vs vs prest
  where
    co vs [] prest = (Vars vs):prest
    co vs ((MAlt f l ps _):ms) prest
        | all varpat ps = trace (f ++ ":" ++ show l ++ ":Warning -- overlapping patterns") (Vars vs):prest
        | otherwise     = co vs ms prest
    varpat (RQVar _ _ v)     = isVarName mod ctxt v
    varpat (RVar _ _ v)      = isVarName mod ctxt v
    -- A wildcard matches anything, like an unbound variable; 'isVar' treats
    -- it the same way, so it must count here too.
    varpat (RUnderscore _ _) = True
    varpat _                 = False
-- | True when every alternative's first pattern is a variable or wildcard.
allVars mod ctxt = all (isVar mod ctxt)

-- | True when every alternative's first pattern is constructor-like.
allCons mod ctxt = all (isCon mod ctxt)
-- | Does the alternative's first pattern bind a variable?  This covers a
-- wildcard, or an unqualified/qualified name that is not a constructor.
-- It is False when there are no patterns.
isVar :: Name -> Context -> MatchAlt -> Bool
isVar mod ctxt (MAlt _ _ (p:_) _) =
    case p of
      RVar _ _ v      -> isVarName mod ctxt v
      RQVar _ _ v     -> isVarName mod ctxt v
      RUnderscore _ _ -> True
      _               -> False
isVar _ _ _ = False
-- | A name is a variable unless some context entry it resolves to (via
-- 'lookupname') carries the 'Constructor' option.  Names with no entries at
-- all are therefore variables.
isVarName mod ctxt v = not (isConName mod ctxt v)

-- | A name is a constructor when at least one context entry it resolves to
-- carries the 'Constructor' option.
isConName mod ctxt v = any hasConstructorOpt (lookupname mod v ctxt)
  where
    hasConstructorOpt (_, (_, opts)) = Constructor `elem` opts
-- | Is the alternative's first pattern constructor-like?  That means one of:
--
-- * an application headed by a name;
-- * a name that 'isConName' accepts;
-- * a constant;
-- * an array pattern.
--
-- It is False when there are no patterns.
isCon :: Name -> Context -> MatchAlt -> Bool
isCon mod ctxt (MAlt _ _ (p:_) _) =
    case p of
      RApply _ _ (RVar _ _ _) _  -> True
      RApply _ _ (RQVar _ _ _) _ -> True
      RVar _ _ v                 -> isConName mod ctxt v
      RQVar _ _ v                -> isConName mod ctxt v
      RConst _ _ _               -> True
      RArrayInit _ _ _           -> True
      _                          -> False
isCon _ _ _ = False
-- | Variable rule: every alternative's first pattern is a variable or
-- wildcard.  Drop that column, substituting the scrutinee @v@ for each bound
-- variable in the alternative's result, then match the rest against @vs@.
varRule :: String -> Int -> Context -> Name ->
           [Raw] -> [MatchAlt] -> Raw ->
           State CaseState Raw
varRule f l c n (v:vs) alts err = match f l c n vs (map bindHead alts) err
  where
    -- The rebuilt alternative takes its position from the dropped pattern.
    bindHead (MAlt _ _ (pat : rest) res) =
        case pat of
          RVar pf pl name   -> MAlt pf pl rest (rawSubst name v res)
          RQVar pf pl name  -> MAlt pf pl rest (rawSubst name v res)
          RUnderscore pf pl -> MAlt pf pl rest res
-- | The things the constructor rule treats as constructors.
data ConType
    = CName Name    -- ^ ordinary named constructor
    | CConst Const  -- ^ constant pattern
    | CArray Int    -- ^ array pattern, identified by its length
    deriving (Show, Eq)
-- | All alternatives of a constructor partition that start with the same
-- 'ConType'.  Each entry pairs that occurrence's argument patterns with the
-- rest of its alternative (first pattern removed).
data Group
    = ConGroup ConType [([Raw], MatchAlt)]
    deriving Show
-- | Constructor rule: group the alternatives by their leading constructor
-- and build one case expression over the first scrutinee.  @err@ is the
-- fallthrough used when no group matches.
conRule :: String -> Int -> Context -> Name ->
           [Raw] -> [MatchAlt] -> Raw ->
           State CaseState Raw
conRule f l ctxt mod scrutinees@(_:_) alts err =
    groupCons alts ctxt mod >>= \groups ->
        caseGroups f l ctxt mod scrutinees groups err
-- | Build the case expression over scrutinee @v@ for a list of groups.  The
-- result has one alternative per group, in order, followed by a default
-- alternative yielding the fallthrough term @err@.
--
-- For named constructors and arrays, the argument variables chosen by
-- 'argsToAlt' become new scrutinees in front of the remaining ones (@vs@).
-- Constants have no arguments, so matching continues on @vs@ alone.
caseGroups :: String -> Int -> Context -> Name ->
              [Raw] -> [Group] -> Raw ->
              State CaseState Raw
caseGroups f l ctxt mod (v:vs) gs err = do
    alts <- mapM compileGroup gs
    return (RCase f l v (alts ++ [RDefault f l err]))
  where
    compileGroup (ConGroup con rows) = do
        (newArgs, nextMs) <- argsToAlt mod ctxt rows
        case con of
          CName n -> do
              body <- matchUnder newArgs nextMs
              return (RAlt f l n newArgs body)
          CConst cval -> do
              body <- match f l ctxt mod vs nextMs err
              return (RConstAlt f l cval body)
          CArray _ -> do
              body <- matchUnder newArgs nextMs
              return (RArrayAlt f l newArgs body)
    -- Continue matching with the group's argument variables in front.
    matchUnder newArgs nextMs =
        match f l ctxt mod (map (RQVar f l) newArgs ++ vs) nextMs err
-- | Prepare the rows of one constructor group for the next round of
-- matching.  Returns the new scrutinee variables and the rows, each with
-- its argument patterns prepended to its remaining patterns.
--
-- One fresh variable is allocated per argument of the first row.
-- NOTE(review): all rows of a group are assumed to share that arity; this
-- is not checked here.
--
-- Fresh names are used instead of reusing variable names from the first
-- row.  Reuse let later rows' substitutions capture names.  For example,
-- with rows @Cons x Nil@ and @Cons xs x@, the second row's @xs@ was
-- rewritten to @x@ and then, in the next column, that @x@ was rewritten
-- again.  Reuse could also shadow an outer variable of the same name that
-- another row refers to.  The fresh variables are declared by 'caseComp'
-- like every other variable from 'getVar'.
argsToAlt :: Name -> Context ->
             [([Raw], MatchAlt)] -> State CaseState ([Name], [MatchAlt])
argsToAlt _mod _ctxt [] = return ([], [])
argsToAlt _mod _ctxt rs@((r, _):_) = do
    -- generate new argument names, one per argument position
    newArgs <- mapM (const getVar) r
    -- generate new match alternatives, by prepending each row's arguments
    return (newArgs, map addArgs rs)
  where
    addArgs (args, MAlt f l ps res) = MAlt f l (args ++ ps) res
-- | Divide the alternatives of a constructor partition into Groups.  There
-- is one group per distinct constructor, constant or array length, in
-- order of first appearance, and each group keeps its rows in source order.
-- A first pattern that 'isPatt' cannot classify as constructor-like is
-- reported with 'fail'.
groupCons :: Monad m => [MatchAlt] -> Context -> Name -> m [Group]
groupCons ms ctxt mod = gc [] ms
  where
    gc acc [] = return acc
    gc acc ((MAlt f l (p:ps) res):ms) = do
        acc' <- addGroup f l p ps res acc
        gc acc' ms
    addGroup f l p ps rval acc = case isPatt p ctxt mod of
        ConPatt con conargs   -> return $ addTo (CName con) conargs (MAlt f l ps rval) acc
        Constant cval         -> return $ addTo (CConst cval) [] (MAlt f l ps rval) acc
        ArrayPatt len conargs -> return $ addTo (CArray len) conargs (MAlt f l ps rval) acc
        pat -> fail $ f ++ ":" ++ show l ++ ":I don't understand this pattern " -- ++ show pat
    -- Append (args, alt) to the group whose key equals @key@, or start a new
    -- group at the end.  Groups with any other key, including a different
    -- kind of key (say, a constant next to a named constructor), are skipped.
    -- Such a column is ill-typed, and it is left for the typechecker instead
    -- of failing here with a non-exhaustive pattern.
    addTo key args alt [] = [ConGroup key [(args, alt)]]
    addTo key args alt (ConGroup k rows : gs)
        | k == key  = ConGroup k (rows ++ [(args, alt)]) : gs
        | otherwise = ConGroup k rows : addTo key args alt gs
-- match f l ctxt mod [] _ = fail "Can't match with no scrutinee"
-- match f l ctxt mod (e:[]) ms = do
-- (unsimple, groups) <- groupCons ms ctxt mod
-- trace (show groups) $
-- if (unsimple == 0) then return $ mkSimpleCase f l e groups
-- else do fail "unfinished"
{-
mkSimpleCase :: String -> Int -> Raw -> [Group] -> Raw
mkSimpleCase f l e gs = RCase f l e (map mkAlt gs)
where mkAlt (Simple f l c args ret) = RAlt f l c args ret
mkAlt _ = error "Can't happen PMComp mkAlt"
-}
-- | Classification of a single pattern term, as computed by 'isPatt'.
data Patt
    = ConPatt Name [Raw]   -- ^ constructor with its argument patterns
    | VarPatt Name         -- ^ pattern variable
    | Constant Const       -- ^ constant pattern
    | ArrayPatt Int [Raw]  -- ^ array pattern: length and element patterns
    | DefaultPatt          -- ^ not produced by 'isPatt'
    | NoPatt               -- ^ not a recognised pattern
    deriving (Show, Eq)
-- Test fixtures: small pattern terms at a dummy location ("foo", line 1).

-- | @Cons hd tl@ pattern.
pcons hd tl = RApply "foo" 1 consHead [hd, tl]
  where
    consHead = RVar "foo" 1 (UN "Cons")

-- | Variable pattern with the given name.
pvar name = RVar "foo" 1 (UN name)

-- | @Nil@ pattern.
pnil = RVar "foo" 1 (UN "Nil")

-- | Numeric constant term.
pint k = RConst "foo" 1 (Num k)

-- | Alternatives for a list-length style match:
-- two or more elements -> 2, exactly one -> 1, empty -> 0.
testAlts =
    [ MAlt "foo" 1 [pcons (pvar "x") (pcons (pvar "y") (pvar "ys"))] (pint 2)
    , MAlt "foo" 1 [pcons (pvar "x") pnil] (pint 1)
    , MAlt "foo" 1 [pnil] (pint 0)
    ]
-- | Classify a pattern term.
--
-- * An application headed by a name is a constructor pattern.
-- * A bare name is classified by 'vPatt'.
-- * Constants and array patterns map directly.
-- * Anything else is 'NoPatt'.
isPatt pat ctxt mod =
    case pat of
      RApply _ _ (RVar _ _ con) args  -> ConPatt con args
      RApply _ _ (RQVar _ _ con) args -> ConPatt con args
      RVar _ _ v                      -> vPatt v ctxt mod
      RQVar _ _ v                     -> vPatt v ctxt mod
      RConst _ _ c                    -> Constant c
      RArrayInit _ _ ps               -> ArrayPatt (length ps) ps
      _                               -> NoPatt
-- | Classify a bare name used as a pattern.
--
-- If @v@ is found in the context, it cannot be a local variable, so it is
-- treated as a nullary constructor pattern.  If it is not really a
-- constructor, this fails harmlessly and the typechecker catches it, so
-- nothing fancier is needed.  Otherwise it is a pattern variable.
vPatt v ctxt mod =
    maybe (VarPatt v) (const (ConPatt v [])) (ctxtlookup mod v ctxt Nothing [])
-- If all the groups are in simple case expression form, we've won. Each
-- group is guaranteed to begin with a different constructor, so no need
-- to worry there. Returns the number of groups still to deal with, and
-- the new groupings after simplification.
{-
simplify :: [Group] -> (Int, [Group])
simplify gs = simpl gs [] (length gs)
where simpl [] acc i = (i, reverse acc)
simpl ((ConGroup n [(args, MAlt f l [] res)]):gs) acc i
| Just argnames <- mapM getName args
= simpl gs ((Simple f l n argnames res):acc) (i-1)
simpl (g:gs) acc i = simpl gs (g:acc) i
getName (RVar _ _ n) = Just n
getName (RQVar _ _ n) = Just n
getName _ = Nothing
-}
|