File: Traversal.hs

package info (click to toggle)
haskell-derive 2.5.16-1
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 460 kB
  • sloc: haskell: 3,686; makefile: 5
file content (210 lines) | stat: -rw-r--r-- 9,702 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
{-# LANGUAGE CPP #-}
{-
    This module is not written/maintained by the usual Data.Derive author.

    MAINTAINER: Twan van Laarhoven 
    EMAIL: "twanvl" ++ "@" ++ "gmail" ++ "." ++ "com"

    Please send all patches to this module to Neil (ndmitchell -at- gmail),
    and CC Twan.
-}

-- NOTE: Cannot be guessed as it relies on type information

-- | Derives 'Functor' and similar classes, as discussed on the Haskell-prime mailing list:
-- <http://www.mail-archive.com/haskell-prime@haskell.org/msg02116.html>.
module Data.Derive.Internal.Traversal(
        TraveralType(..), defaultTraversalType,
        traversalDerivation1,
        traversalInstance, traversalInstance1,
        deriveTraversal
    ) where

import Language.Haskell
import Data.Derive.Internal.Derivation
import Data.List
import qualified Data.Set as S
import Control.Monad.Trans.Writer
import Control.Applicative
import Data.Generics.Uniplate.DataOnly
import Data.Maybe

---------------------------------------------------------------------------------
-- Information datatype, public interface

-- | An expression representing a traversal of a subpart of the data.
--   This is a plain alias for 'Exp': traversals are ordinary syntax-tree
--   expressions, and the deriving code compares them with '==' against
--   'traversalId' to detect "nothing to do here".
type Trav = Exp

-- | Description of the traversal being derived: which type argument is
--   traversed and how the pieces of generated code are assembled.
data TraveralType = TraveralType
        { -- | Position of the type argument being traversed, using the
          --   numbering of 'argPositions' (the last type variable is 1).
          traversalArg    :: Int
        , -- | Variance flag.  It is negated while deriving the argument side
          --   of a function type; while it is 'True', a direct occurrence of
          --   the traversed type variable is rejected.
          traversalCo     :: Bool
        , -- | Name of the traversal function; a numeric suffix is appended
          --   for argument positions greater than 1 (see 'traversalNameN').
          traversalName   :: QName
        , -- | The identity traversal, used for parts that need no work.
          traversalId     :: Trav
        , -- | The traversal applied to the traversed type variable itself.
          traversalDirect :: Trav
        , -- | Lift a sub-traversal through the named traversal function.
          traversalFunc   :: QName -> Trav -> Trav
        , -- | Combine two non-identity traversals in sequence.
          traversalPlus   :: Trav -> Trav -> Trav
        , -- | How to traverse a function type, given the traversals of its
          --   argument and result; 'Nothing' means arrows are unsupported.
          traverseArrow   :: Maybe (Trav -> Trav -> Trav)
        , -- | Build a tuple from the already-applied component traversals.
          traverseTuple   :: [Exp] -> Exp
        , -- | Build a constructor application from the constructor name and
          --   the already-applied field traversals.
          traverseCtor    :: String -> [Exp] -> Exp
        , -- | Build one clause of the traversal function from the
          --   constructor pattern and the right-hand side.
          traverseFunc    :: Pat -> Exp -> Match
        }

-- | Baseline 'TraveralType' intended to be customised with record-update
--   syntax.  It traverses argument position 1, uses @id@ as the identity
--   traversal and @_f@ as the direct traversal, and does not support
--   function types.
--
--   'traversalName' and 'traverseFunc' have no sensible default: they are
--   filled with 'error' thunks (so the record has no missing fields) and
--   must be overridden before the value is used for deriving.
defaultTraversalType :: TraveralType
defaultTraversalType = TraveralType
        { traversalArg    = 1
        , traversalCo     = False
        , traversalName   = error "Data.Derive.Internal.Traversal.defaultTraversalType: traversalName must be overridden"
        , traversalId     = var "id"
        , traversalDirect = var "_f"
        , traversalFunc   = \x y -> appP (Var x) y
        , traversalPlus   = \x y -> apps (Con $ Special Cons) [paren x, paren y]
        , traverseArrow   = Nothing
        , traverseTuple   = Tuple Boxed
        , traverseCtor    = \x y -> apps (con x) (map paren y)
        , traverseFunc    = error "Data.Derive.Internal.Traversal.defaultTraversalType: traverseFunc must be overridden"
        }

-- | A class constraint the derived instance needs on one of the data type's
--   own type variables (collected while deriving, emitted as instance context).
data RequiredInstance = RequiredInstance
    { -- | Name of the type variable of the current data type that is
      --   applied like a type constructor.
      _requiredDataArg  :: String
    , -- | Argument position of that variable through which we traverse.
      _requiredPosition :: Int
    }
    deriving (Eq, Ord)

-- | Monad that collects required instances: a 'Writer' whose log is the
--   set of 'RequiredInstance's discovered while deriving.
type WithInstances a = Writer (S.Set RequiredInstance) a


-- | Build @n@ numbered names from a prefix character and wrap each one with
--   the given constructor function, e.g. @vars id 'a' 3 == ["a1","a2","a3"]@.
--   Yields the empty list when @n@ is less than 1.
vars :: (Show i, Num i, Enum i) => (String -> a) -> Char -> i -> [a]
vars wrap prefix count = map (wrap . (prefix :) . show) [1 .. count]


---------------------------------------------------------------------------------
-- Deriving traversals


-- | Derivation for a Traversable-like class with a single method.
--   The derivation is registered under the class name @nm@, with the
--   traversed argument position appended when that position is above 1.
traversalDerivation1 :: TraveralType -> String -> Derivation
traversalDerivation1 tt nm = derivationCustom derivName (traversalInstance1 tt nm)
    where
        pos = traversalArg tt
        derivName
            | pos > 1   = nm ++ show pos
            | otherwise = nm


-- | Instance for a Traversable-like class with just one method.
--
--   Returns 'Left' with a message when the instance cannot be derived:
--
--   * the traversal has no 'traverseArrow' but the data type mentions a
--     function type;
--   * the data type has no type parameters;
--   * the data type has fewer type parameters than the traversed position
--     ('traversalArg'), in which case there is no variable to traverse and
--     the instance head could not be built correctly.
--
--   Otherwise returns the single derived instance declaration.
traversalInstance1 :: TraveralType -> String -> FullDataDecl -> Either String [Decl]
traversalInstance1 tt nm (_,dat)
    | isNothing (traverseArrow tt) && any isTyFun (universeBi dat) = Left $ "Can't derive " ++ prettyPrint (traversalName tt) ++ " for types with arrow"
    | arity == 0 = Left "Cannot derive class for data type arity == 0"
    -- Without this check 'traversalInstance' would call 'splitAt' with a
    -- negative index and silently produce a malformed instance head.
    | arity < traversalArg tt = Left $ "Cannot derive class for data type arity < " ++ show (traversalArg tt)
    | otherwise = Right $ traversalInstance tt nm dat [deriveTraversal tt dat]
    where arity = dataDeclArity dat


-- | Assemble the instance declaration for a Traversable-like class.
--
--   The method bodies are run in the 'WithInstances' writer; every
--   'RequiredInstance' they report becomes one constraint in the instance
--   context.  The instance head applies the data type to its leading type
--   variables, drops the traversed variable, and passes any variables after
--   it as further class arguments.
traversalInstance :: TraveralType -> String -> DataDecl -> [WithInstances Decl] -> [Decl]
traversalInstance tt nameBase dat bodyM =
    [simplify $ InstDecl sl context (qname $ classFor $ traversalArg tt) headArgs (map InsDecl decls)]
    where
        (decls, needed) = runWriter (sequence bodyM)

        -- class name, numbered when traversing a position other than the last
        classFor k
            | k > 1     = nameBase ++ show k
            | otherwise = nameBase

        context =
            [ ClassA (qname $ classFor p) (tyVar v : vars tyVar 's' (p - 1))
            | RequiredInstance v p <- S.toList needed
            ]

        tyVars = vars tyVar 't' (dataDeclArity dat)
        -- the variable between the two halves is the traversed one
        (kept, _ : trailing) = splitAt (length tyVars - traversalArg tt) tyVars
        headArgs = TyParen (tyApps (tyCon $ dataDeclName dat) kept) : trailing


-- | Derive a 'traverse'-like function: one clause per constructor of the
--   data type, all bound under the (unqualified) traversal function name.
deriveTraversal :: TraveralType -> DataDecl -> WithInstances Decl
deriveTraversal tt dat = bindClauses <$> clauses
    where
        clauses = mapM (deriveTraversalCtor tt (argPositions dat)) (dataDeclCtors dat)

        -- re-tag every clause with the function name and a blank location
        bindClauses ms = FunBind [Match sl funName a b c d | Match _ _ a b c d <- ms]

        -- the binding itself must use an unqualified name
        funName = case traversalNameN tt (traversalArg tt) of
            Qual _ x -> x
            UnQual x -> x


-- | Derive the clause of a 'traverse'-like function that handles one
--   constructor.  The constructor's fields are bound to @a1 .. aN@ and each
--   is passed through the traversal derived for its type.
deriveTraversalCtor :: TraveralType -> ArgPositions -> CtorDecl -> WithInstances Match
deriveTraversalCtor tt ap ctor = do
    fieldTravs <- mapM (deriveTraversalType tt ap . fromBangType . snd) (ctorDeclFields ctor)
    let lhs = PParen $ PApp (qname cName) (vars pVar 'a' n)
        rhs = traverseCtor tt cName (zipWith App fieldTravs (vars var 'a' n))
    return $ traverseFunc tt lhs rhs
  where
    cName = ctorDeclName ctor
    n     = ctorDeclArity ctor



-- | Derive a traversal for a type.
--
--   Parentheses are looked through; lists and tuples are rewritten into
--   applications of their special type constructors; applications are
--   delegated to 'deriveTraversalApp'.  A bare type constructor, or a type
--   variable other than the traversed one, gets the identity traversal.
--   Function types rely on 'traverseArrow' being present.
deriveTraversalType :: TraveralType -> ArgPositions -> Type -> WithInstances Trav
deriveTraversalType tt ap ty = case ty of
    TyParen inner  -> deriveTraversalType tt ap inner
    TyForall{}     -> fail "forall not supported in traversal deriving"
    TyFun dom cod  -> fromJust (traverseArrow tt)
                          <$> deriveTraversalType flipped ap dom
                          <*> deriveTraversalType tt      ap cod
    TyApp hd arg   -> deriveTraversalApp tt ap hd [arg] -- T a b c ...
    TyList el      -> deriveTraversalType tt ap $ TyApp (TyCon $ Special ListCon) el
    TyTuple bx els -> deriveTraversalType tt ap $ tyApps (TyCon $ Special $ TupleCon bx $ length els) els
    TyCon _        -> return identity -- T
    TyVar (Ident v) -- a
        | ap v /= traversalArg tt -> return identity
        | traversalCo tt          -> fail "tyvar used in covariant position"
        | otherwise               -> return $ traversalDirect tt
  where
    identity = traversalId tt
    -- the argument side of an arrow is derived with the variance flag negated
    flipped  = tt{traversalCo = not $ traversalCo tt}


-- | Collect all arguments of a type application, then derive a traversal
--   for the applied type.
--
--   Arguments are numbered from the right (the last argument is 1), the
--   same numbering 'argPositions' uses.  Arguments whose traversal is the
--   identity are ignored; the remaining ones are lifted with 'traverseArg'
--   and combined with 'traversalPlus'.  When the head is one of the data
--   type's own type variables, a 'RequiredInstance' is recorded for every
--   non-identity argument position.
deriveTraversalApp :: TraveralType -> ArgPositions -> Type -> [Type] -> WithInstances Trav
-- peel one more argument off the application spine
deriveTraversalApp tt ap (TyApp a b) args = deriveTraversalApp tt ap a (b : args)
-- NOTE(review): 'deriveTraversalType' rewrites tuple types into applications
-- headed by @TyCon (Special TupleCon ..)@, so it looks like this equation is
-- only reached for a literal 'TyTuple' in head position -- confirm.
deriveTraversalApp tt ap TyTuple{} args = do -- (a,b,c)
    tArgs <- mapM (deriveTraversalType tt ap) args
    let width   = length args
        unpack  = PTuple Boxed (vars pVar 't' width)
        rebuilt = traverseTuple tt $ zipWith App tArgs (vars var 't' width)
    return $
        if all (== traversalId tt) tArgs
            then traversalId tt
            else Lambda sl [unpack] rebuilt
deriveTraversalApp tt ap tycon args = do -- T a b c
    -- the head's traversal is not used, but deriving it keeps its effects
    _     <- deriveTraversalType tt ap tycon
    tArgs <- mapM (deriveTraversalType tt ap) args
    -- (position, traversal) for every argument that actually needs work
    let interesting =
            [ (i, t)
            | (t, i) <- zip (reverse tArgs) [1..]
            , t /= traversalId tt
            ]
    -- need instances?
    case tycon of
        TyVar (Ident n)
            | ap n == traversalArg tt -> fail "kind error: type used type constructor"
            | otherwise               -> tell $ S.fromList [ RequiredInstance n i | (i, _) <- interesting ]
        _ -> return ()
    -- combine non-id traversals
    return $ case [ traverseArg tt i t | (i, t) <- interesting ] of
        []    -> traversalId tt -- no interesting arguments to type con
        lifts -> foldl1 (traversalPlus tt) lifts


-- | Lift a traversal to the @n@-th argument of a type constructor by
--   wrapping it with the traversal function named for that position.
traverseArg :: TraveralType -> Int -> Trav -> Trav
traverseArg tt n e = lift e
  where
    lift = traversalFunc tt (traversalNameN tt n)

-- | Name of the traversal function for argument position @n@: the plain
--   'traversalName' for positions up to 1, otherwise that name with the
--   position number appended to its identifier.
traversalNameN :: TraveralType -> Int -> QName
traversalNameN tt n
    | n > 1     = numbered (traversalName tt)
    | otherwise = traversalName tt
  where
    suffix = show n
    numbered (Qual m x) = Qual m (extend x)
    numbered (UnQual x) = UnQual (extend x)
    extend (Ident x) = Ident (x ++ suffix)

-- | Information on argument positions: maps the name of a type variable
--   to its position (see 'argPositions' for the numbering)
type ArgPositions = String -> Int

-- | Position of a type variable among the data type's arguments, counted
--   from the right starting at 1.
--   For @data X a b c@ the positions are: a -> 3, b -> 2, c -> 1.
--   Looking up a name that is not one of the data type's variables is an
--   error.
argPositions :: DataDecl -> String -> Int
argPositions dat = position
  where
    tyvars = dataDeclVars dat
    total  = length tyvars
    position nm = maybe (error "impossible: tyvar not in scope") (total -) (elemIndex nm tyvars)