File: Internal.hs

package info (click to toggle)
haskell-ghc-tcplugins-extra 0.4.6-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 316 kB
  • sloc: haskell: 901; makefile: 6
file content (133 lines) | stat: -rw-r--r-- 4,950 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
{-# LANGUAGE RecordWildCards #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module Internal
  ( -- * Create new constraints
    TcPlugin.newWanted
  , newGiven
    -- * Creating evidence
  , evByFiat
    -- * Lookup
  , lookupModule
  , lookupName
    -- * Trace state of the plugin
  , tracePlugin
    -- * Substitutions
  , flattenGivens
  , mkSubst
  , mkSubst'
  , substType
  , substCt
  )
where

import GHC.Driver.Config.Finder (initFinderOpts)
import GHC.Tc.Plugin (TcPluginM, lookupOrig, tcPluginTrace)
import qualified GHC.Tc.Plugin as TcPlugin
  (newWanted, getTopEnv, tcPluginIO, findImportedModule)
import GHC.Tc.Types (TcPlugin(..), TcPluginSolveResult(..))
import Control.Arrow (first, second)
import Data.Function (on)
import Data.List (groupBy, partition, sortOn)
import GHC.Tc.Utils.TcType (TcType)
import Data.Maybe (mapMaybe)

import GhcApi.Constraint (Ct(..))
import GhcApi.GhcPlugins

import Internal.Type (substType)
import Internal.Constraint (newGiven, flatToCt, mkSubst, overEvidencePredType)
import Internal.Evidence (evByFiat)

-- | Locate a 'Module' by name from inside a type-checker plugin.
--
-- Two lookups are tried in order:
--
--   1. 'findPluginModule', given the finder cache, the finder options derived
--      from the session 'DynFlags', the unit state and the (optional) home
--      unit of the current 'HscEnv';
--   2. if that does not return 'Found', 'TcPlugin.findImportedModule',
--      qualified with 'ThisPkg' of the home unit's id when a home unit exists,
--      and with 'NoPkgQual' otherwise.
--
-- If neither lookup reports 'Found', this aborts via 'panicDoc'.
lookupModule :: ModuleName -- ^ Name of the module
             -> FastString -- ^ Name of the package containing the module.
                           -- NOTE: This value is ignored on ghc>=8.0.
             -> TcPluginM Module
lookupModule mod_nm _pkg = do
  env <- TcPlugin.getTopEnv
  let homeUnit   = hsc_home_unit_maybe env
      finderOpts = initFinderOpts (hsc_dflags env)
  viaPluginFinder <- TcPlugin.tcPluginIO $
    findPluginModule (hsc_FC env) finderOpts (hsc_units env) homeUnit mod_nm
  case viaPluginFinder of
    Found _ m -> return m
    _         -> viaImportFinder homeUnit
 where
  -- Second attempt: the ordinary import finder, restricted to the home unit
  -- when there is one.
  viaImportFinder homeUnit = do
    result <- TcPlugin.findImportedModule mod_nm
                (maybe NoPkgQual (ThisPkg . homeUnitId) homeUnit)
    case result of
      Found _ m -> return m
      _         -> panicDoc "Couldn't find module" (ppr mod_nm)

-- | Resolve the 'Name' that an 'OccName' denotes inside the given 'Module'.
-- This delegates directly to 'lookupOrig'.
lookupName :: Module -> OccName -> TcPluginM Name
lookupName md occ = lookupOrig md occ

-- | Wrap a 'TcPlugin' so that its initialisation, its stop, and every solver
-- run emit trace messages through 'tcPluginTrace' (shown when
-- @-ddump-tc-trace@ is enabled). The 'String' argument is appended to every
-- message header to identify the plugin.
--
-- For each solver run, the given and wanted constraints are traced before the
-- wrapped solver runs, and its result is traced afterwards. The result is then
-- returned unchanged.
--
-- The rewriter ('tcPluginRewrite') is passed through untouched and is not
-- traced.
tracePlugin :: String -> TcPlugin -> TcPlugin
tracePlugin s TcPlugin { tcPluginInit    = initP
                       , tcPluginSolve   = solveP
                       , tcPluginRewrite = rewriteP
                       , tcPluginStop    = stopP
                       } =
  TcPlugin { tcPluginInit    = traceInit
           , tcPluginSolve   = traceSolve
           , tcPluginRewrite = rewriteP
           , tcPluginStop    = traceStop
           }
  where
    -- Emit a trace line whose header is the given tag followed by the
    -- plugin's identifying string.
    say tag = tcPluginTrace (tag ++ s)

    traceInit = say "tcPluginInit " empty >> initP

    traceStop st = say "tcPluginStop " empty >> stopP st

    traceSolve st ev given wanted = do
      say "tcPluginSolve start "
          (text "given   =" <+> ppr given $$ text "wanted  =" <+> ppr wanted)
      result <- solveP st ev given wanted
      reportResult result
      return result

    -- Clause order matters: 'TcPluginOk' and 'TcPluginContradiction' are
    -- tried before the catch-all 'TcPluginSolveResult' constructor.
    reportResult (TcPluginOk solved new) =
      say "tcPluginSolve ok "
          (text "solved =" <+> ppr solved $$ text "new    =" <+> ppr new)
    reportResult (TcPluginContradiction bad) =
      say "tcPluginSolve contradiction " (text "bad =" <+> ppr bad)
    reportResult (TcPluginSolveResult bad solved new) =
      say "tcPluginSolveResult "
          (text "solved =" <+> ppr solved
            $$ text "bad    =" <+> ppr bad
            $$ text "new    =" <+> ppr new)

-- | Flatten the evidence of constraints by substituting their type equalities
-- into one another.
--
-- The substitutions produced by 'mkSubst'' are grouped by type variable.
-- Variables with exactly one substitution form the substitution that is
-- applied, via 'substCt', to every input constraint. Groups holding two or
-- more substitutions for the same variable are instead passed to 'flatToCt'.
-- Each constraint it returns (for 'Just' results) is placed in front of the
-- substituted input constraints.
--
-- __NB:__ Should only be used on /[G]iven/ constraints!
--
-- __NB:__ Doesn't flatten under binders
flattenGivens :: [Ct] -> [Ct]
flattenGivens givens = extra ++ map (substCt unambiguous) givens
 where
  byVar           = fst . fst
  groups          = groupBy ((==) `on` byVar) (sortOn byVar (mkSubst' givens))
  (multi, single) = partition atLeastTwo groups
  extra           = mapMaybe flatToCt multi
  unambiguous     = concatMap (map fst) single

  -- Same as @(>= 2) . length@ on finite lists, without walking the group.
  atLeastTwo (_:_:_) = True
  atLeastTwo _       = False

-- | Build substitutions from the type equalities among the given constraints,
-- flattened so that the substitutions have been applied to each other's
-- right-hand sides.
--
-- 'mkSubst' is tried on every constraint, and those for which it returns
-- 'Nothing' are dropped. Each remaining entry pairs a
-- @(type variable, type)@ binding with the constraint it came from. The
-- entries are folded from the right, so each binding is combined with those
-- after it in the list.
mkSubst' :: [Ct] -> [((TcTyVar,TcType),Ct)]
mkSubst' cts = foldr insertSubst [] (mapMaybe mkSubst cts)
 where
  insertSubst :: ((TcTyVar,TcType),Ct)
              -> [((TcTyVar,TcType),Ct)]
              -> [((TcTyVar,TcType),Ct)]
  insertSubst ((tv,rhs),ct) acc =
    let -- the bindings accumulated so far, applied to the new right-hand side
        rhs'     = substType (map fst acc) rhs
        -- the new binding pushed into every accumulated right-hand side;
        -- note it uses the original 'rhs', not 'rhs''
        applyNew = first (second (substType [(tv,rhs)]))
    in  ((tv,rhs'),ct) : map applyNew acc

-- | Apply a type-variable substitution to the evidence predicate type of a
-- constraint: 'substType' of the substitution, lifted over the constraint by
-- 'overEvidencePredType'.
substCt :: [(TcTyVar, TcType)] -> Ct -> Ct
substCt = overEvidencePredType . substType