File: Internal.hs

package info (click to toggle)
haskell-ghc-tcplugins-extra 0.4.6-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 316 kB
  • sloc: haskell: 901; makefile: 6
file content (143 lines) | stat: -rw-r--r-- 5,137 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternSynonyms #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module Internal
  ( -- * Create new constraints
    newWanted
  , newGiven
  , newDerived
    -- * Creating evidence
  , evByFiat
    -- * Lookup
  , lookupModule
  , lookupName
    -- * Trace state of the plugin
  , tracePlugin
    -- * Substitutions
  , flattenGivens
  , mkSubst
  , mkSubst'
  , substType
  , substCt
  )
where

import GHC.Tc.Plugin (TcPluginM, lookupOrig, tcPluginTrace)
import qualified GHC.Tc.Plugin as TcPlugin
    (newDerived, newWanted, getTopEnv, tcPluginIO, findImportedModule)
import GHC.Tc.Types (TcPlugin(..), TcPluginResult(..))
import Control.Arrow (first, second)
import Data.Function (on)
import Data.List (groupBy, partition, sortOn)
import GHC.Tc.Utils.TcType (TcType)
import Data.Maybe (mapMaybe)

import GhcApi.Constraint (Ct(..), CtEvidence(..), CtLoc)
import GhcApi.GhcPlugins

import Internal.Type (substType)
import Internal.Constraint (newGiven, flatToCt, mkSubst, overEvidencePredType)
import Internal.Evidence (evByFiat)

{-# ANN fr_mod "HLint: ignore Use camelCase" #-}

-- | Matches a successful module lookup ('Found') and binds only the found
-- 'Module', ignoring the other field of the constructor.
pattern FoundModule :: Module -> FindResult
pattern FoundModule m <- Found _location m

-- | Identity function applied to modules obtained via 'FoundModule'.
--
-- NOTE(review): presumably a compatibility shim for older GHC versions in
-- which the found module was read through an @fr_mod@ record field; confirm
-- before removing.
fr_mod :: a -> a
fr_mod x = x

-- | Create a new [W]anted constraint for the given predicate at the given
-- location. Thin wrapper around 'TcPlugin.newWanted'.
newWanted  :: CtLoc -> PredType -> TcPluginM CtEvidence
newWanted loc predTy = TcPlugin.newWanted loc predTy

-- | Create a new [D]erived constraint for the given predicate at the given
-- location. Thin wrapper around 'TcPlugin.newDerived'.
newDerived :: CtLoc -> PredType -> TcPluginM CtEvidence
newDerived loc predTy = TcPlugin.newDerived loc predTy

-- | Find a module.
--
-- The module is first searched for with 'findPluginModule'. If that fails,
-- it is looked up as an import qualified with the package name @\"this\"@
-- (presumably the unit currently being compiled). If both lookups fail,
-- compilation is aborted with 'panicDoc'.
lookupModule :: ModuleName -- ^ Name of the module
             -> FastString -- ^ Name of the package containing the module.
                           -- NOTE: This value is ignored.
             -> TcPluginM Module
lookupModule mod_nm _pkg = do
  env <- TcPlugin.getTopEnv
  viaPlugins <- TcPlugin.tcPluginIO (findPluginModule env mod_nm)
  case viaPlugins of
    FoundModule m -> pure (fr_mod m)
    _             -> viaImports
  where
    -- Second attempt: resolve the module as an import from package "this";
    -- give up with a panic when this lookup fails as well.
    viaImports = do
      res <- TcPlugin.findImportedModule mod_nm (Just (fsLit "this"))
      case res of
        FoundModule m -> pure (fr_mod m)
        _             -> panicDoc "Couldn't find module" (ppr mod_nm)

-- | Find the 'Name' with the given 'OccName' in the given 'Module'.
-- Delegates directly to 'lookupOrig'.
lookupName :: Module -> OccName -> TcPluginM Name
lookupName modl occ = lookupOrig modl occ

-- | Wrap a 'TcPlugin' so that its initialisation, its shutdown, and every
-- solver run are reported through 'tcPluginTrace'; the output is shown when
-- @-ddump-tc-trace@ is enabled.
--
-- The given @String@ is appended to every trace message, so output from
-- several wrapped plugins can be told apart. A solver run logs its given,
-- derived, and wanted constraints before the wrapped solver is called, and
-- the solved/new or contradicting constraints afterwards.
tracePlugin :: String -> TcPlugin -> TcPlugin
tracePlugin s TcPlugin{..} =
  TcPlugin
    { tcPluginInit  = traceInit
    , tcPluginSolve = traceSolve
    , tcPluginStop  = traceStop
    }
  where
    traceInit = do
      -- workaround for https://gitlab.haskell.org/ghc/ghc/-/issues/10301
      initializeStaticFlags
      tcPluginTrace ("tcPluginInit " ++ s) empty
      tcPluginInit

    traceStop z = do
      tcPluginTrace ("tcPluginStop " ++ s) empty
      tcPluginStop z

    traceSolve z given derived wanted = do
      tcPluginTrace ("tcPluginSolve start " ++ s) $
            text "given   =" <+> ppr given
        $$ text "derived =" <+> ppr derived
        $$ text "wanted  =" <+> ppr wanted
      result <- tcPluginSolve z given derived wanted
      traceResult result
      pure result

    -- Report the outcome of the wrapped solver.
    traceResult (TcPluginOk solved new) =
      tcPluginTrace ("tcPluginSolve ok " ++ s) $
            text "solved =" <+> ppr solved
        $$ text "new    =" <+> ppr new
    traceResult (TcPluginContradiction bad) =
      tcPluginTrace ("tcPluginSolve contradiction " ++ s) $
        text "bad =" <+> ppr bad

-- | Hook for the workaround for
-- https://gitlab.haskell.org/ghc/ghc/-/issues/10301 (formerly Trac #10301).
-- In this version it does nothing; it is still called at the start of the
-- plugin initialisation installed by 'tracePlugin'.
initializeStaticFlags :: TcPluginM ()
initializeStaticFlags = pure ()

-- | Flattens evidence of constraints by substituting each other's equalities.
--
-- The flattened substitution from 'mkSubst'' is grouped by the type variable
-- each entry binds:
--
--   * variables bound by two or more equalities are handed to 'flatToCt',
--     and every constraint it produces is returned (groups for which it
--     yields 'Nothing' are dropped);
--   * the bindings of variables bound exactly once form the substitution
--     that is applied to every given.
--
-- The constraints from 'flatToCt' come first in the result, followed by the
-- substituted givens, which appear in their original order.
--
-- __NB:__ Should only be used on /[G]iven/ constraints!
--
-- __NB:__ Doesn't flatten under binders
flattenGivens :: [Ct] -> [Ct]
flattenGivens givens =
  mapMaybe flatToCt dupGroups ++ map (substCt soloSubst) givens
 where
  -- Substitution entries sorted and grouped by the type variable they bind.
  byTyVar = groupBy ((==) `on` (fst . fst)) (sortOn (fst . fst) (mkSubst' givens))
  -- Variables bound more than once vs. variables bound exactly once.
  (dupGroups, soloGroups) = partition ((>= 2) . length) byTyVar
  -- Only the unambiguous bindings are used to rewrite the givens.
  soloSubst = [ binding | grp <- soloGroups, (binding, _) <- grp ]

-- | Create flattened substitutions from type equalities, i.e. the
-- substitutions have been applied to each other's right-hand sides.
--
-- Every constraint that 'mkSubst' turns into a binding @tv := t@ contributes
-- one entry, paired with that constraint. Entries keep the order of the
-- input list.
--
-- The list is built right to left. When the binding @tv := t@ is added in
-- front of the already-flattened bindings @s@:
--
--   * its right-hand side becomes @t' = t[s]@, and
--   * every right-hand side in @s@ is rewritten with @tv := t'@.
--
-- The rewritten @t'@, not the raw @t@, has to be pushed into @s@. With the
-- raw @t@, the equalities @x ~ y@, @z ~ x@, @y ~ Int@ yielded @z := y@ next
-- to @y := Int@, so substituting once over a type mentioning @z@ left @y@
-- behind.
--
-- NOTE(review): the example assumes 'substType' does not re-apply the
-- substitution to its own output; confirm in "Internal.Type".
--
-- __NB:__ right-hand sides are only guaranteed to be free of bound variables
-- when the equalities are acyclic and bind each type variable at most once.
mkSubst' :: [Ct] -> [((TcTyVar,TcType),Ct)]
mkSubst' = foldr substSubst [] . mapMaybe mkSubst
 where
  substSubst :: ((TcTyVar,TcType),Ct)
             -> [((TcTyVar,TcType),Ct)]
             -> [((TcTyVar,TcType),Ct)]
  substSubst ((tv,t),ct) s =
    ((tv,t'),ct) : map (first (second (substType [(tv,t')]))) s
   where
    -- Right-hand side of the new binding, already flattened w.r.t. @s@.
    t' = substType (map fst s) t

-- | Apply a substitution to the predicate type carried by the evidence of a
-- 'Ct', using 'substType' to rewrite the type and 'overEvidencePredType' to
-- put it back.
substCt :: [(TcTyVar, TcType)] -> Ct -> Ct
substCt = overEvidencePredType . substType