1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
|
(*
 * PolyGen: generates traversal functions over MDL datatypes.
 * For every datatype binding T in a rule set, gen emits one function
 * named  name'T  (name taken from the hook), all in a single FUNdecl.
 *)
functor PolyGen
  (structure AstPP    : MDL_AST_PRETTY_PRINTER
   structure AstTrans : MDL_AST_TRANSLATION
     sharing AstPP.Ast = AstTrans.Ast
  ) : POLY_GEN =
struct
   structure Ast = AstPP.Ast
   structure A   = Ast
   structure T   = AstTrans
   structure H   = HashTable

   (* Internal invariant violation. *)
   fun bug msg = MLRiscErrorMsg.error("PolyGen",msg)

   exception PolyGen

   (* Report an error through MDLError, then abort generation. *)
   fun error msg = (MDLError.error msg; raise PolyGen)

   (* Description of the traversal to generate. *)
   datatype hook =
      HOOK of
      { name   : string,      (* name of function *)
        factor : bool,        (* factor rules by constructor? *)
        args   : string list, (* function arguments; the first is the redex *)
        ret    : string,      (* return argument name *)
        (* applied to a field, or to the redex, that needs no recursive call *)
        unit   : Ast.exp -> Ast.exp,
        (* builds the expression for one constructor, given a callback that
         * maps (field type, field expression) to the transformed field *)
        gen    : (Ast.ty * Ast.exp -> Ast.exp) * Ast.consbind -> Ast.exp
      }

   (* An unqualified identifier expression. *)
   fun ID id = A.IDexp(A.IDENT([],id))

   val argExp = ID "arg"
   val argPat = A.IDpat "arg"

   (*
    * gen hook isNonTerm ruleSet
    *
    * Each element of ruleSet pairs a DATATYPEbind with the user's clauses
    * for it.  isNonTerm id tells whether type id has a generated traversal
    * function of its own, i.e. whether fields of that type are recursed into.
    * Returns one FUNdecl holding a function  name'T  per datatype T.
    *)
   fun gen(HOOK{name, args, ret, unit, factor, gen, ...}) isNonTerm ruleSet =
   let val redex = ID(hd args) (* the redex must be the first argument *)

       (* Identity transformer, used for type arguments needing no traversal. *)
       val nullTrans = A.LAMBDAexp[A.CLAUSE([argPat],NONE,argExp)]

       (*
        * ty2Exp ty returns SOME f, where f is an expression transforming a
        * value of type ty, or NONE when such values are left as they are.
        *  - a non-terminal type T maps to  name'T ;
        *  - a non-terminal application T(t1,...,tn) applies  name'T  to the
        *    tuple of argument transformers (nullTrans where none is needed),
        *    and is NONE if no argument needs a transformer;
        *  - tuples and records map to a lambda that rebuilds the value,
        *    transforming only the components that need it;
        *  - a type variable maps to the curried parameter built by curriedArg.
        * Any other type is reported through error.
        *)
       fun ty2Exp(A.IDty(A.IDENT(_,id))) =
             if isNonTerm id then SOME(ID(name^"'"^id)) else NONE
         | ty2Exp(A.APPty(A.IDENT(_,id),args)) =
             if isNonTerm id then
             let val args = map ty2Exp args
             in  if List.exists Option.isSome args then
                    SOME(A.APPexp(ID(name^"'"^id),
                           A.TUPLEexp(map (fn SOME f => f | NONE => nullTrans)
                                          args)))
                 else NONE
             end
             else NONE
         | ty2Exp(A.TUPLEty tys) = (* create a functional *)
             let val args = map ty2Exp tys
                 (* positional components are bound to v_0, v_1, ... *)
                 fun bind([], i, pats, exps, some) = (rev pats, rev exps, some)
                   | bind(arg::args, i, pats, exps, some) =
                     let val v   = "v_"^Int.toString i
                         val pat = A.IDpat v
                         val exp = ID v
                         val (exp, some) =
                             case arg of NONE   => (exp, some)
                                       | SOME f => (A.APPexp(f,exp), true)
                     in  bind(args, i+1, pat::pats, exp::exps, some) end
                 val (pats, exps, some) = bind(args, 0, [], [], false)
             in  if some then
                    SOME(A.LAMBDAexp[A.CLAUSE([A.TUPLEpat pats],NONE,
                                              A.TUPLEexp exps)])
                 else NONE
             end
         | ty2Exp(A.RECORDty ltys) =
             let val args = map (fn (l,t) => (l, ty2Exp t)) ltys
                 (* record fields are bound to variables named after the label *)
                 fun bind([], pats, exps, some) = (rev pats, rev exps, some)
                   | bind((l,arg)::args, pats, exps, some) =
                     let val pat = (l, A.IDpat l)
                         val exp = ID l
                         val (exp, some) =
                             case arg of NONE   => (exp, some)
                                       | SOME f => (A.APPexp(f,exp), true)
                     in  bind(args, pat::pats, (l,exp)::exps, some)
                     end
                 val (pats, exps, some) = bind(args, [], [], false)
             in  if some then
                    SOME(A.LAMBDAexp[A.CLAUSE([A.RECORDpat(pats,false)],NONE,
                                              A.RECORDexp exps)])
                 else NONE
             end
         | ty2Exp(A.TYVARty(A.VARtv id)) = SOME(ID("param"^id))
         | ty2Exp t = error("Can't handle type "^PP.text(AstPP.ty t))

       (* Generate the traversal function  name'id  for one datatype. *)
       fun genOneRule(A.DATATYPEbind{id,tyvars,cbs, ...}, rules) =
       let val prefix = []

           (* Set by appTrans whenever some field needed a recursive call. *)
           val subTerm = ref false

           (* Transform one constructor field e of type ty. *)
           fun appTrans(ty,e) =
               case ty2Exp ty of
                 NONE   => unit e
               | SOME f => (subTerm := true; A.APPexp(f,e))

           (* Non-factored mode: one clause per constructor, built by the
            * hook's gen; if no field needed recursion the clause yields
            * unit redex instead. *)
           fun genNonFactoredTraversal cons =
           let val _   = subTerm := false
               val exp = gen(appTrans,cons)
               val exp = if !subTerm then exp else unit redex
           in  T.mapConsToClause{prefix=prefix,pat=fn p => p,exp=exp} cons
           end

           (* Raised whenever the rules cannot be split by top-level
            * constructor; the caller then falls back to non-factored mode. *)
           exception Can'tFactor

           (* Factored mode: one clause per constructor.  The user's rules for
            * that constructor are matched against its transformed argument;
            * if none matches, the constructor is rebuilt from that argument. *)
           fun genFactoredTraversal(cons as A.CONSbind{id, ty, ...},rules) =
           let val _ = subTerm := false
               (* Factoring needs gen to return a constructor application;
                * anything else is handled by the non-factored fallback
                * rather than by an uncaught Bind. *)
               val (resultExp, caseExp) =
                   case gen(appTrans,cons) of
                     e as A.CONSexp(_, SOME operand) => (e, operand)
                   | e as A.CONSexp(_, NONE)         => (e, A.TUPLEexp [])
                   | _ => raise Can'tFactor
               val body =
                   case (rules, !subTerm) of
                     ([], false) => redex
                   | ([], true)  => resultExp
                   | (_, _) =>
                       A.CASEexp
                         (caseExp,
                          rules @
                          [A.CLAUSE([argPat],NONE,
                                    case ty of
                                      SOME _ => A.CONSexp(A.IDENT([],id), SOME argExp)
                                    | NONE   => redex
                                   )
                          ]
                         )
           in  T.mapConsToClause{prefix=prefix,pat=fn p => p, exp=body} cons
           end

           (* Group the rules by their top-level constructor, returning each
            * constructor of cbs paired with its rules in source order.
            * Raises Can'tFactor for rules whose single pattern is not a
            * constructor pattern.  The table's not-found exception is also
            * Can'tFactor, so an identifier that is not a constructor of this
            * datatype aborts factoring too. *)
           fun factorRules rules =
           let val tbl = H.mkTable(HashString.hashString,op=)(32,Can'tFactor)
               val _ = app (fn A.CONSbind{id,...} => H.insert tbl (id,[]))
                           cbs
               fun factor(r,A.CONSpat(A.IDENT([],id),arg),g,e) =
                     enterRule(r, id, arg, g, e)
                 | factor(r,A.IDpat id, g, e) =
                     enterRule(r, id, NONE, g, e)
                   (* NOTE(review): the alias name of an as-pattern is
                    * discarded here, so a guard or body mentioning it is not
                    * bound by the generated clause -- confirm intended. *)
                 | factor(r,A.ASpat(_,p), g, e) = factor(r, p, g, e)
                 | factor _ = raise Can'tFactor
               and factorRule(r as A.CLAUSE([p],g,e)) = factor(r,p,g,e)
                 | factorRule _ = raise Can'tFactor
               and enterRule(r, consName, arg, g, e) =
                   let val rs = H.lookup tbl consName
                       (* a nullary constructor's clause matches anything *)
                       val r = A.CLAUSE([case arg of NONE   => A.WILDpat
                                                   | SOME p => p],g,e)
                   in  H.insert tbl (consName,r::rs)
                   end
               val _ = app factorRule rules
           in  map (fn c as A.CONSbind{id,...} => (c,rev(H.lookup tbl id))) cbs
           end

           fun factoredBody rules =
               A.CASEexp(redex,map genFactoredTraversal (factorRules rules))

           (* Bind the rebuilt term to ret, then match the rules against the
            * tuple of all arguments. *)
           fun nonfactoredBody rules =
               A.LETexp([A.VALdecl[A.VALbind(A.IDpat ret,
                          A.CASEexp(redex,map genNonFactoredTraversal cbs))]],
                        [A.CASEexp(A.TUPLEexp(map ID args), rules)]
                       )

           val body =
               if factor then (factoredBody rules handle Can'tFactor =>
                                  nonfactoredBody rules)
               else nonfactoredBody rules

           (* One transformer parameter per type variable of the datatype. *)
           fun curriedArg(A.VARtv id) = A.IDpat("param"^id)
             | curriedArg _ = bug "curriedArg"

           (* The function takes the tuple of hook arguments, preceded by a
            * curried tuple of type-variable transformers when the datatype
            * is polymorphic. *)
           val args = [A.TUPLEpat(map A.IDpat args)]
           val args = case tyvars of
                        [] => args
                      | vs => A.TUPLEpat(map curriedArg vs)::args
       in  A.FUNbind(name^"'"^id,[A.CLAUSE(args,NONE,body)])
       end
         | genOneRule _ = bug "genOneRule"
   in  A.FUNdecl(map genOneRule ruleSet)
   end
end
|