1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module Internal
( -- * Create new constraints
newWanted
, newGiven
, newDerived
-- * Creating evidence
, evByFiat
-- * Lookup
, lookupModule
, lookupName
-- * Trace state of the plugin
, tracePlugin
-- * Substitutions
, flattenGivens
, mkSubst
, mkSubst'
, substType
, substCt
)
where
import GHC.Tc.Plugin (TcPluginM, lookupOrig, tcPluginTrace)
import qualified GHC.Tc.Plugin as TcPlugin
(newDerived, newWanted, getTopEnv, tcPluginIO, findImportedModule)
import GHC.Tc.Types (TcPlugin(..), TcPluginResult(..))
import Control.Arrow (first, second)
import Data.Function (on)
import Data.List (groupBy, partition, sortOn)
import GHC.Tc.Utils.TcType (TcType)
import Data.Maybe (mapMaybe)
import GhcApi.Constraint (Ct(..), CtEvidence(..), CtLoc)
import GhcApi.GhcPlugins
import Internal.Type (substType)
import Internal.Constraint (newGiven, flatToCt, mkSubst, overEvidencePredType)
import Internal.Evidence (evByFiat)
{-# ANN fr_mod "HLint: ignore Use camelCase" #-}

-- | Matches a successful 'FindResult' (the 'Found' constructor), binding the
-- 'Module' that was found and ignoring the other field.
pattern FoundModule :: Module -> FindResult
pattern FoundModule m <- Found _ m

-- | Identity function; it is applied to the 'Module' bound by a
-- 'FoundModule' match before it is returned.
fr_mod :: a -> a
fr_mod x = x
-- | Create a new [W]anted constraint for the given predicate at the given
-- constraint location. This is a thin re-export of 'TcPlugin.newWanted'.
newWanted :: CtLoc -> PredType -> TcPluginM CtEvidence
newWanted loc predTy = TcPlugin.newWanted loc predTy
-- | Create a new [D]erived constraint for the given predicate at the given
-- constraint location. This is a thin re-export of 'TcPlugin.newDerived'.
newDerived :: CtLoc -> PredType -> TcPluginM CtEvidence
newDerived loc predTy = TcPlugin.newDerived loc predTy
-- | Find a module.
--
-- The module is first searched for with 'findPluginModule' using the
-- session's 'HscEnv'. If that does not find it, it is looked up as an
-- imported module with the package qualifier @"this"@. If neither lookup
-- succeeds, this panics.
lookupModule :: ModuleName -- ^ Name of the module
             -> FastString -- ^ Name of the package containing the module.
                           -- NOTE: This value is ignored on ghc>=8.0.
             -> TcPluginM Module
lookupModule mod_nm _pkg = do
    env       <- TcPlugin.getTopEnv
    viaPlugin <- TcPlugin.tcPluginIO (findPluginModule env mod_nm)
    case viaPlugin of
      FoundModule m -> pure (fr_mod m)
      _             -> TcPlugin.findImportedModule mod_nm (Just (fsLit "this"))
                         >>= orPanic
  where
    -- Second (and last) chance: accept a hit, otherwise abort.
    orPanic (FoundModule m) = pure (fr_mod m)
    orPanic _               = panicDoc "Couldn't find module" (ppr mod_nm)
-- | Find a 'Name' in a 'Module' given an 'OccName'.
-- This simply delegates to 'lookupOrig'.
lookupName :: Module -> OccName -> TcPluginM Name
lookupName md occ = lookupOrig md occ
-- | Print out extra information about the initialisation, stop, and every run
-- of the plugin when @-ddump-tc-trace@ is enabled.
--
-- Each trace message title is a fixed tag followed by the supplied name @s@.
-- The wrapped plugin's init, solve and stop actions are still run, and the
-- solver's result is returned unchanged.
tracePlugin :: String -> TcPlugin -> TcPlugin
tracePlugin s TcPlugin{..} =
    TcPlugin
      { tcPluginInit  = -- workaround for https://ghc.haskell.org/trac/ghc/ticket/10301
                        initializeStaticFlags
                          >> traceTagged "tcPluginInit " empty
                          >> tcPluginInit
      , tcPluginSolve = solveWithTrace
      , tcPluginStop  = \st -> traceTagged "tcPluginStop " empty >> tcPluginStop st
      }
  where
    -- Emit a trace entry titled with the given tag followed by the plugin name.
    traceTagged tag = tcPluginTrace (tag ++ s)

    -- Trace the incoming constraints, run the wrapped solver, then trace
    -- its outcome before passing it back.
    solveWithTrace st givens deriveds wanteds = do
      traceTagged "tcPluginSolve start " $
            text "given ="   <+> ppr givens
         $$ text "derived =" <+> ppr deriveds
         $$ text "wanted ="  <+> ppr wanteds
      outcome <- tcPluginSolve st givens deriveds wanteds
      case outcome of
        TcPluginOk solved new ->
          traceTagged "tcPluginSolve ok " $
            text "solved =" <+> ppr solved $$ text "new =" <+> ppr new
        TcPluginContradiction bad ->
          traceTagged "tcPluginSolve contradiction " $
            text "bad =" <+> ppr bad
      pure outcome
-- workaround for https://ghc.haskell.org/trac/ghc/ticket/10301
-- | Placeholder hook for the static-flags workaround linked above. In this
-- version it performs no action.
initializeStaticFlags :: TcPluginM ()
initializeStaticFlags = pure ()
-- | Flattens evidence of constraints by substituting each others equalities.
--
-- The flattened substitution produced by 'mkSubst'' is sorted and grouped by
-- the variable being substituted:
--
-- * variables with exactly one equation form the substitution that is
--   applied to every given via 'substCt';
--
-- * each group of two or more equations for the same variable is passed to
--   'flatToCt', and any constraints it yields are placed in front of the
--   rewritten givens.
--
-- __NB:__ Should only be used on /[G]iven/ constraints!
--
-- __NB:__ Doesn't flatten under binders
flattenGivens :: [Ct] -> [Ct]
flattenGivens givens = mapMaybe flatToCt shared ++ map (substCt unique) givens
  where
    -- The substituted variable of an entry.
    byVar = fst . fst
    perVar = groupBy ((==) `on` byVar) (sortOn byVar (mkSubst' givens))
    (shared, lonely) = partition ((>= 2) . length) perVar
    unique = concatMap (map fst) lonely
-- | Create flattened substitutions from type equalities, i.e. the substitutions
-- have been applied to each others right hand sides.
--
-- Every equality that 'mkSubst' recognises is folded, from the right, into
-- the substitution accumulated so far:
--
-- 1. the new right-hand side is rewritten with the accumulated substitution;
--
-- 2. that /rewritten/ right-hand side is substituted into the accumulated
--    entries.
--
-- Using the rewritten right-hand side in step 2 (rather than the raw one)
-- avoids leaving entries that still mention a variable which is itself being
-- substituted. For example, givens @a ~ g b, c ~ a, b ~ Int@ now give
-- @c := g Int@ instead of @c := g b@. This assumes 'substType' replaces every
-- occurrence of the variables in its substitution. Cyclic equalities can
-- still leave a variable on its own right-hand side.
mkSubst' :: [Ct] -> [((TcTyVar,TcType),Ct)]
mkSubst' = foldr substSubst [] . mapMaybe mkSubst
  where
    substSubst :: ((TcTyVar,TcType),Ct)
               -> [((TcTyVar,TcType),Ct)]
               -> [((TcTyVar,TcType),Ct)]
    substSubst ((tv,t),ct) s =
        ((tv,t'),ct) : map (first (second (substType [(tv,t')]))) s
      where
        -- The new right-hand side, flattened against the substitution so far.
        t' = substType (map fst s) t
-- | Apply a substitution to the predicate type in the evidence of a 'Ct',
-- using 'overEvidencePredType' to rewrite it with 'substType'.
substCt :: [(TcTyVar, TcType)] -> Ct -> Ct
substCt = overEvidencePredType . substType
|