File: poly-gen.sml

package info (click to toggle)
mlton 20210117+dfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 58,464 kB
  • sloc: ansic: 27,682; sh: 4,455; asm: 3,569; lisp: 2,879; makefile: 2,347; perl: 1,169; python: 191; pascal: 68; javascript: 7
file content (180 lines) | stat: -rw-r--r-- 7,982 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
functor PolyGen
   (structure AstPP    : MDL_AST_PRETTY_PRINTER
    structure AstTrans : MDL_AST_TRANSLATION
       sharing AstPP.Ast = AstTrans.Ast
   ) : POLY_GEN =
struct
   structure Ast = AstPP.Ast
   structure A   = Ast
   structure T   = AstTrans
   structure H   = HashTable

   (* Internal errors: delegated to MLRiscErrorMsg.error. *)
   fun bug msg = MLRiscErrorMsg.error("PolyGen",msg)

   (* User-level errors: reported through MDLError, then PolyGen is raised. *)
   exception PolyGen
   fun error msg = (MDLError.error msg; raise PolyGen)

   (*
    * Describes the family of traversal functions that gen produces.
    *   name   -- base name; the function for datatype t is named name^"'"^t
    *   factor -- try to group the rules by the top-level constructor of
    *             their pattern; the unfactored form is used whenever the
    *             rules cannot be grouped that way
    *   args   -- argument names of every generated function; the first
    *             one names the term being traversed (the redex)
    *   ret    -- unfactored form only: the name bound to the result of the
    *             default traversal, in scope in the rules' right-hand sides
    *   unit   -- wraps an expression that needs no transformation
    *   gen    -- builds the traversal expression for one constructor, given
    *             a function mapping (type of a subterm, subterm expression)
    *             to the expression for the processed subterm
    *)
   datatype hook =
      HOOK of
      { name   : string,      (* name of function *)
        factor : bool,        (* factor rules by constructor? *)
        args   : string list, (* function arguments *)
        ret    : string,      (* return argument name *)
        unit   : Ast.exp -> Ast.exp,
        gen    : (Ast.ty * Ast.exp -> Ast.exp) * Ast.consbind -> Ast.exp
      }

   (* An unqualified identifier expression. *)
   fun ID id = A.IDexp(A.IDENT([],id))
   val argExp = ID "arg"
   val argPat = A.IDpat "arg"

   (*
    * gen hook isNonTerm ruleSet
    *
    * Returns one A.FUNdecl holding a FUNbind per element of ruleSet.  Each
    * element pairs a datatype binding with the clauses (rules) written for
    * that datatype.  isNonTerm id tells whether type id has a generated
    * traversal function; when it does, that function is named name^"'"^id.
    * Raises PolyGen (after reporting) on a type that cannot be traversed.
    *)
   fun gen(HOOK{name, args, ret, unit, factor, gen, ...}) isNonTerm ruleSet =
   let (* the redex must be the first argument *)
       val redex = case args of
                     first :: _ => ID first
                   | []         => bug "gen: hook has no arguments"

       (* Applies an optional transformer to an expression. *)
       fun applyTrans(NONE,   e) = e
         | applyTrans(SOME f, e) = A.APPexp(f,e)

       (* fn arg => arg; passed for type arguments that need no work. *)
       val nullTrans = A.LAMBDAexp[A.CLAUSE([argPat],NONE,argExp)]

       (*
        * Given a type, returns SOME f, where f is an expression for the
        * function that transforms values of that type, or NONE when values
        * of that type are left untouched.
        *)
       fun ty2Exp(A.IDty(A.IDENT(_,id))) =
              if isNonTerm id then SOME(ID(name^"'"^id)) else NONE
         | ty2Exp(A.APPty(A.IDENT(_,id),tyArgs)) =
              if isNonTerm id then
                 let val trans = map ty2Exp tyArgs
                 in  if List.exists Option.isSome trans then
                        SOME(A.APPexp(ID(name^"'"^id),
                          A.TUPLEexp(map (fn SOME f => f | NONE => nullTrans)
                                     trans)))
                     else NONE
                 end
              else NONE
         | ty2Exp(A.TUPLEty tys) =
              (* a functional over the components v_0, v_1, ...; components
                 whose type needs no work are copied unchanged *)
              let val trans = map ty2Exp tys
                  val vars  = List.tabulate(length trans,
                                            fn i => "v_"^Int.toString i)
                  val exps  = ListPair.map applyTrans (trans, map ID vars)
              in  if List.exists Option.isSome trans then
                     SOME(A.LAMBDAexp[A.CLAUSE([A.TUPLEpat(map A.IDpat vars)],
                                                NONE, A.TUPLEexp exps)])
                  else NONE
              end
         | ty2Exp(A.RECORDty ltys) =
              (* same as the tuple case, binding each field to its label *)
              let val trans = map (fn (l,t) => (l, ty2Exp t)) ltys
                  val pats  = map (fn (l,_) => (l, A.IDpat l)) trans
                  val exps  = map (fn (l,f) => (l, applyTrans(f, ID l))) trans
              in  if List.exists (fn (_,f) => Option.isSome f) trans then
                     SOME(A.LAMBDAexp[A.CLAUSE([A.RECORDpat(pats,false)],NONE,
                                                A.RECORDexp exps)])
                  else NONE
              end
         | ty2Exp(A.TYVARty(A.VARtv id)) = SOME(ID("param"^id))
         | ty2Exp t = error("Can't handle type "^PP.text(AstPP.ty t))

       (*
        * Builds the traversal function name^"'"^id for one datatype.
        *
        * Unfactored form:
        *    let val ret = case redex of <one clause per constructor>
        *    in  case <tuple of args> of <rules> end
        *
        * Factored form (hook factor = true, and every rule has a single
        * pattern headed by a constructor of this datatype):
        *    case redex of <one clause per constructor>, where a constructor
        *    that has rules matches them against its processed argument,
        *    followed by a catch-all clause that rebuilds the constructor
        *    (or returns the redex for a nullary constructor).
        *)
       fun genOneRule(A.DATATYPEbind{id,tyvars,cbs, ...}, rules) =
       let val prefix = []

           (* set by appTrans whenever some subterm really is transformed *)
           val subTerm = ref false
           fun appTrans(ty,e) =
               case ty2Exp ty of
                 NONE => unit e
               | SOME f => (subTerm := true; A.APPexp(f,e))

           (* Unfactored traversal clause for one constructor. *)
           fun genNonFactoredTraversal cons =
           let val _   = subTerm := false
               val exp = gen(appTrans,cons)
               (* no subterm needed work: just wrap the redex *)
               val exp = if !subTerm then exp else unit redex
           in  T.mapConsToClause{prefix=prefix,pat=fn p => p,exp=exp} cons
           end

           (* raised whenever the rules cannot be grouped by constructor *)
           exception Can'tFactor

           (* Factored traversal clause for one constructor and its rules. *)
           fun genFactoredTraversal(cons as A.CONSbind{id=consId, ty, ...},
                                    consRules) =
           let val _ = subTerm := false
               (* Factoring needs the hook to rebuild the constructor; any
                  other result makes us fall back to the unfactored form
                  instead of failing with Bind. *)
               val (resultExp, caseExp) =
                   case gen(appTrans,cons) of
                     e as A.CONSexp(_, SOME arg) => (e, arg)
                   | e as A.CONSexp(_, NONE)     => (e, A.TUPLEexp [])
                   | _                           => raise Can'tFactor
               val body =
                   case (consRules, !subTerm) of
                     ([], false) => redex
                   | ([], true)  => resultExp
                   | (_, _) =>
                     A.CASEexp
                      (caseExp,
                       consRules @
                       [A.CLAUSE([argPat],NONE,
                           case ty of
                             SOME _ => A.CONSexp(A.IDENT([],consId), SOME argExp)
                           | NONE   => redex
                         )
                       ]
                      )
           in  T.mapConsToClause{prefix=prefix,pat=fn p => p,exp=body} cons
           end

           (*
            * Groups the rules by the top-level constructor of their single
            * pattern.  Returns every constructor of the datatype paired with
            * its rules (in their original order), each rule's pattern being
            * replaced by the constructor's argument pattern (or _ if none).
            * Raises Can'tFactor if some rule does not fit that shape; an
            * IDpat that is not one of the constructors is rejected by the
            * table lookup.
            *)
           fun factorRules rules =
           let val tbl = H.mkTable(HashString.hashString,op=)(32,Can'tFactor)
               val _   = app (fn A.CONSbind{id,...} => H.insert tbl (id,[]))
                             cbs
               fun enterRule(consName, arg, g, e) =
                  let val rs  = H.lookup tbl consName
                      val pat = case arg of NONE => A.WILDpat
                                          | SOME p => p
                  in  H.insert tbl (consName, A.CLAUSE([pat],g,e)::rs)
                  end
               fun factorPat(A.CONSpat(A.IDENT([],id),arg),g,e) =
                   enterRule(id, arg, g, e)
                 | factorPat(A.IDpat id, g, e) =
                   enterRule(id, NONE, g, e)
                   (* NOTE(review): the name bound by the as-pattern is
                      dropped here; a rule whose right-hand side uses it may
                      not work once factored -- confirm *)
                 | factorPat(A.ASpat(_,p), g, e) = factorPat(p, g, e)
                 | factorPat _ = raise Can'tFactor
               fun factorRule(A.CLAUSE([p],g,e)) = factorPat(p,g,e)
                 | factorRule _ = raise Can'tFactor
               val _ = app factorRule rules
           in  map (fn c as A.CONSbind{id,...} => (c,rev(H.lookup tbl id))) cbs
           end

           fun factoredBody rules =
                A.CASEexp(redex,map genFactoredTraversal (factorRules rules))
           fun nonfactoredBody rules =
                A.LETexp([A.VALdecl[A.VALbind(A.IDpat ret,
                          A.CASEexp(redex,map genNonFactoredTraversal cbs))]],
                         [A.CASEexp(A.TUPLEexp(map ID args), rules)]
                        )

           val body =
               if factor then (factoredBody rules handle Can'tFactor =>
                               nonfactoredBody rules)
               else nonfactoredBody rules

           (* A polymorphic datatype's function first takes a tuple of
              transformers, one per type variable, named param<tyvar>. *)
           fun curriedArg(A.VARtv id) = A.IDpat("param"^id)
             | curriedArg _ = bug "curriedArg"
           val params = [A.TUPLEpat(map A.IDpat args)]
           val params = case tyvars of
                          []  => params
                        | vs  => A.TUPLEpat(map curriedArg vs)::params
       in  A.FUNbind(name^"'"^id,[A.CLAUSE(params,NONE,body)])
       end
         | genOneRule _ = bug "genOneRule"
   in  A.FUNdecl(map genOneRule ruleSet)
   end
end