1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
|
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE TupleSections #-}
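-- | Adds cost centres to call sites selected with the @-fprof-caller=...@
-- flag: every occurrence of a matching variable is wrapped in a 'ProfNote'
-- tick whose cost centre is named after the enclosing definitions.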
module GHC.Core.Opt.CallerCC
( addCallerCostCentres
, CallerCcFilter(..)
, NamePattern(..)
, parseCallerCcFilter
) where
import Data.Word (Word8)
import Data.Maybe
import Control.Applicative
import GHC.Utils.Monad.State.Strict
import Data.Either
import Control.Monad
import qualified Text.ParserCombinators.ReadP as P
import GHC.Prelude
import GHC.Utils.Outputable as Outputable
import GHC.Driver.Session
import GHC.Types.CostCentre
import GHC.Types.CostCentre.State
import GHC.Types.Name hiding (varName)
import GHC.Types.Tickish
import GHC.Unit.Module.ModGuts
import GHC.Types.SrcLoc
import GHC.Types.Var
import GHC.Unit.Types
import GHC.Data.FastString
import GHC.Core
import GHC.Core.Opt.Monad
import GHC.Utils.Panic
import qualified GHC.Utils.Binary as B
import Data.Char
import Language.Haskell.Syntax.Module.Name
-- | Core-to-Core pass entry point. Every call site whose callee matches one
-- of the @-fprof-caller@ filters ('callerCcFilters' in the 'DynFlags') is
-- wrapped in a 'ProfNote' cost-centre tick. Only 'mg_binds' is rewritten;
-- every other field of the 'ModGuts' is passed through unchanged.
addCallerCostCentres :: ModGuts -> CoreM ModGuts
addCallerCostCentres guts = do
  dflags <- getDynFlags
  let initialEnv = Env
        { thisModule   = mg_module guts
        , countEntries = gopt Opt_ProfCountEntries dflags
        , ccState      = newCostCentreState
        , revParents   = []
        , filters      = callerCcFilters dflags
        }
  pure guts { mg_binds = doCoreProgram initialEnv (mg_binds guts) }
-- | Instrument every top-level binding in order. A single
-- 'CostCentreState', starting from 'newCostCentreState', is threaded
-- through the whole program.
doCoreProgram :: Env -> CoreProgram -> CoreProgram
doCoreProgram env binds =
  evalState (traverse (doBind env) binds) newCostCentreState
-- | Instrument the right-hand sides of a binding group. Each RHS is walked
-- with its own binder pushed onto the parent stack ('addParent'). Generated
-- cost-centre names therefore mention the enclosing definition.
doBind :: Env -> CoreBind -> M CoreBind
doBind env bind = case bind of
    NonRec b rhs -> NonRec b <$> instrumentRhs b rhs
    Rec pairs    ->
      Rec <$> traverse (\(b, rhs) -> (,) b <$> instrumentRhs b rhs) pairs
  where
    instrumentRhs b = doExpr (addParent b env)
-- | Walk a Core expression and wrap every 'Var' occurrence selected by
-- 'needsCallSiteCostCentre' in a 'ProfNote' tick.
--
-- The cost centre is named by the dot-separated chain of enclosing binders,
-- outermost first, followed by @(calling:<callee>)@. Its source span is
-- that of the innermost enclosing binder, or 'noSrcSpan' when there is
-- none. All other constructors are traversed structurally.
doExpr :: Env -> CoreExpr -> M CoreExpr
doExpr env expr = case expr of
    Var v
      | needsCallSiteCostCentre env v -> (`Tick` expr) <$> callSiteTick v
      | otherwise                     -> pure expr
    Lit _                -> pure expr
    Type _               -> pure expr
    Coercion _           -> pure expr
    App f x              -> App <$> go f <*> go x
    Lam b body           -> Lam b <$> go body
    Let bind body        -> Let <$> doBind env bind <*> go body
    Case scrut b ty alts ->
      Case <$> go scrut <*> pure b <*> pure ty <*> traverse goAlt alts
    Cast e co            -> (`Cast` co) <$> go e
    Tick t e             -> Tick t <$> go e
  where
    go = doExpr env
    goAlt (Alt con bs rhs) = Alt con bs <$> go rhs

    -- Build the profiling tick for a call to the given callee.
    callSiteTick :: Id -> M CoreTickish
    callSiteTick v = do
        idx <- getCCIndex' ccName
        let cc = NormalCC (mkExprCCFlavour idx) ccName (thisModule env) srcSpan
        pure (ProfNote cc (countEntries env) True)
      where
        nameDoc :: SDoc
        nameDoc = withUserStyle alwaysQualify DefaultDepth $
          hcat (punctuate dot (map ppr (parents env)))
            <> parens (text "calling:" <> ppr v)

        ccName :: CcName
        ccName = mkFastString (renderWithContext defaultSDocContext nameDoc)

        srcSpan = case revParents env of
          top : _ -> nameSrcSpan (varName top)
          []      -> noSrcSpan
-- | The instrumentation monad: a strict 'State' over the 'CostCentreState'
-- from which 'getCCIndex' hands out 'CostCentreIndex'es.
type M = State CostCentreState
-- | Obtain the 'CostCentreIndex' for a cost-centre name, updating the
-- threaded 'CostCentreState' via 'getCCIndex'.
getCCIndex' :: FastString -> M CostCentreIndex
getCCIndex' = state . getCCIndex
-- | Environment passed down the traversal; 'addParent' extends it as the
-- walk enters each binding.
data Env = Env
  { thisModule :: Module
    -- ^ Module being compiled; recorded in every generated cost centre.
  , countEntries :: !Bool
    -- ^ Whether generated 'ProfNote' ticks count entries
    -- (from @Opt_ProfCountEntries@).
  , ccState :: CostCentreState
    -- ^ NOTE(review): set by 'addCallerCostCentres' but not read anywhere in
    -- this module; 'doCoreProgram' starts from 'newCostCentreState' directly.
  , revParents :: [Id]
    -- ^ Enclosing binders, innermost first.
  , filters :: [CallerCcFilter]
    -- ^ Call-site filters supplied via @-fprof-caller@.
  }
-- | Push a binder onto the stack of enclosing definitions.
addParent :: Id -> Env -> Env
addParent i env@Env{revParents} = env { revParents = i : revParents }
-- | Enclosing binders, outermost first.
parents :: Env -> [Id]
parents = reverse . revParents
-- | Does a reference to the given 'Id' match at least one @-fprof-caller@
-- filter?
--
-- A filter matches when both of the following hold:
--
-- * Its module part is @*@ ('Nothing'), or it equals the name of the module
--   that defines the 'Id'. Names with no module never match a
--   module-qualified filter.
-- * Its function pattern matches the 'OccName' of the 'Id'.
needsCallSiteCostCentre :: Env -> Id -> Bool
needsCallSiteCostCentre env i = any matches (filters env)
  where
    matches :: CallerCcFilter -> Bool
    matches ccf =
      moduleOk (ccfModuleName ccf)
        && occNameMatches (ccfFuncName ccf) (getOccName i)

    moduleOk :: Maybe ModuleName -> Bool
    moduleOk Nothing       = True
    moduleOk (Just wanted) =
      maybe False ((== wanted) . moduleName) (nameModule_maybe (varName i))
-- | A glob-like pattern over an occurrence name. It is the function-name part
-- of an @-fprof-caller@ filter. A literal @*@ is written as @\\*@.
data NamePattern
  = PChar Char NamePattern
    -- ^ A literal character, followed by the rest of the pattern.
  | PWildcard NamePattern
    -- ^ @*@: any run of characters (possibly empty), then the rest.
  | PEnd
    -- ^ End of the pattern; matches only the end of the name.
-- | Render a pattern in the concrete syntax accepted by 'parseNamePattern'.
-- A wildcard prints as @*@. A literal @*@ character is escaped as @\\*@,
-- so the output parses back to the same pattern instead of turning the
-- literal into a wildcard.
instance Outputable NamePattern where
  ppr (PChar '*' rest) = text "\\*" <> ppr rest
  ppr (PChar c rest)   = char c <> ppr rest
  ppr (PWildcard rest) = char '*' <> ppr rest
  ppr PEnd             = Outputable.empty
-- | Serialisation format: a 'Word8' tag (0 = 'PChar', 1 = 'PWildcard',
-- 2 = 'PEnd') followed by the constructor's fields in order.
instance B.Binary NamePattern where
  put_ bh pat = case pat of
      PChar c rest   -> putTag 0 >> B.put_ bh c >> B.put_ bh rest
      PWildcard rest -> putTag 1 >> B.put_ bh rest
      PEnd           -> putTag 2
    where
      putTag t = B.put_ bh (t :: Word8)

  get bh = B.get bh >>= decode
    where
      decode :: Word8 -> IO NamePattern
      decode 0 = PChar <$> B.get bh <*> B.get bh
      decode 1 = PWildcard <$> B.get bh
      decode 2 = pure PEnd
      decode _ = panic "Binary(NamePattern): Invalid tag"
-- | Does the pattern match the /whole/ string of the occurrence name?
-- A wildcard may absorb any number of characters, including none.
occNameMatches :: NamePattern -> OccName -> Bool
occNameMatches pat = go pat . occNameString
  where
    go :: NamePattern -> String -> Bool
    go PEnd s = null s
    go (PChar c rest) (d:s) = d == c && go rest s
    go (PChar _ _) [] = False
    go (PWildcard rest) s =
      -- Either the wildcard absorbs nothing here, or it absorbs one more
      -- character. Matching on the list (rather than calling 'tail') avoids
      -- crashing once the input is exhausted, e.g. pattern @*a@ vs @b@.
      go rest s || case s of
        []      -> False
        (_:s')  -> go (PWildcard rest) s'
-- | Parsers in this module are plain 'P.ReadP' parsers.
type Parser = P.ReadP
-- | Parse the function-name part of a filter, up to end of input.
-- The alternatives are tried left-biased, in this order:
--
-- 1. @\\*@ is a literal @*@.
-- 2. A bare @*@ is a wildcard.
-- 3. Any other character is a literal.
-- 4. End of input finishes the pattern.
parseNamePattern :: Parser NamePattern
parseNamePattern = go
  where
    go = escapedStar P.<++ wildcard P.<++ literal P.<++ end

    escapedStar = do
      _ <- P.string "\\*"
      PChar '*' <$> go

    wildcard = void (P.char '*') *> (PWildcard <$> go)

    literal = do
      c <- P.get
      PChar c <$> go

    end = PEnd <$ P.eof
-- | One @-fprof-caller@ filter. A call site gets a cost centre when the
-- callee's defining module and occurrence name both match.
data CallerCcFilter
  = CallerCcFilter { ccfModuleName :: Maybe ModuleName
                     -- ^ Required defining module; 'Nothing' (written @*@)
                     -- matches any module.
                   , ccfFuncName :: NamePattern
                     -- ^ Pattern that the callee's occurrence name must match.
                   }
-- | Render as @Module.pattern@, with @*@ standing for "any module".
instance Outputable CallerCcFilter where
  ppr ccf = modulePart <> char '.' <> ppr (ccfFuncName ccf)
    where
      modulePart = case ccfModuleName ccf of
        Nothing -> char '*'
        Just m  -> ppr m
-- | Serialised as the module part followed by the name pattern.
instance B.Binary CallerCcFilter where
  put_ bh (CallerCcFilter mbMod pat) = do
    B.put_ bh mbMod
    B.put_ bh pat

  get bh = do
    mbMod <- B.get bh
    pat <- B.get bh
    pure (CallerCcFilter mbMod pat)
-- | Parse the argument of one @-fprof-caller=...@ flag. Succeeds only when
-- the first parse result reported by 'P.readP_to_S' consumed the whole
-- input. Otherwise returns a message quoting the input.
parseCallerCcFilter :: String -> Either String CallerCcFilter
parseCallerCcFilter inp =
  case listToMaybe (P.readP_to_S parseCallerCcFilter' inp) of
    Just (result, "") -> Right result
    _                 -> Left ("parse error on " ++ inp)
-- | Parser for a single @-fprof-caller@ filter. The input has three parts:
--
-- 1. A module part: either a dotted module name such as @Data.List@, or @*@
--    for any module.
-- 2. A @.@ separator.
-- 3. A 'NamePattern' for the function name.
parseCallerCcFilter' :: Parser CallerCcFilter
parseCallerCcFilter' =
  CallerCcFilter
    <$> moduleFilter
    <*  P.char '.'
    <*> parseNamePattern
  where
    moduleFilter :: Parser (Maybe ModuleName)
    moduleFilter =
      (Just . mkModuleName <$> moduleName)
      <|>
      (Nothing <$ P.char '*')

    -- One or more dot-separated components. Each component is an upper-case
    -- letter followed by zero or more letters, digits or underscores.
    -- 'P.munch' (not 'P.munch1') is used so that single-character
    -- components such as the module @A@ in @A.foo@ are accepted.
    moduleName :: Parser String
    moduleName = do
      c <- P.satisfy isUpper
      cs <- P.munch isModuleNameChar
      rest <- optional $ P.char '.' >> fmap ('.':) moduleName
      return $ c : (cs ++ fromMaybe "" rest)

    isModuleNameChar :: Char -> Bool
    isModuleNameChar x = isUpper x || isLower x || isDigit x || x == '_'
|