1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
|
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
module Database.Persist.GenericSql.Migration
( Migration
, parseMigration
, parseMigration'
, printMigration
, getMigration
, runMigration
, runMigrationSilent
, runMigrationUnsafe
, migrate
, commit
, rollback
) where
import Database.Persist.GenericSql.Internal
import Database.Persist.EntityDef
import qualified Database.Persist.GenericSql.Raw as R
import Database.Persist.Store
import Database.Persist.GenericSql.Raw (SqlPersist (..))
#if MIN_VERSION_monad_control(0, 3, 0)
import Control.Monad.Trans.Control (MonadBaseControl)
#define MBCIO MonadBaseControl IO
#else
import Control.Monad.IO.Control (MonadControlIO)
#define MBCIO MonadControlIO
#endif
import Control.Monad.Trans.Class (MonadTrans (..))
import Control.Monad.IO.Class
import Control.Monad.Trans.Reader
import Control.Monad.Trans.Writer
import Control.Monad (liftM, unless)
import Data.Text (Text, unpack, snoc)
import qualified Data.Text.IO
import System.IO
-- | Execute a raw SQL statement with the given positional parameters.
-- This is a local alias for 'R.execute', pinned to the 'SqlPersist' type.
execute' :: MonadIO m => Text -> [PersistValue] -> SqlPersist m ()
execute' sql params = R.execute sql params
-- | A single SQL statement.
type Sql = Text

-- | SQL statements produced by a migration, each paired with a flag.
--
-- The 'Bool' is 'True' when the statement is considered /unsafe/: the
-- checked migration runner refuses to execute it and asks for manual
-- intervention instead. 'False' marks a safe statement. (This is the
-- convention 'unsafeSql' and 'safeSql' below rely on; an earlier comment
-- here stated the opposite.)
type CautiousMigration = [(Bool, Sql)]

-- | Every statement, safe or unsafe, in its original order.
allSql :: CautiousMigration -> [Sql]
allSql = map snd

-- | Only the statements flagged 'True' (unsafe), in their original order.
unsafeSql :: CautiousMigration -> [Sql]
unsafeSql = allSql . filter fst

-- | Only the statements flagged 'False' (safe), in their original order.
safeSql :: CautiousMigration -> [Sql]
safeSql = allSql . filter (not . fst)
-- | A migration script, built from two stacked 'WriterT' layers.
--
-- The outer 'WriterT' accumulates error messages. 'parseMigration' treats a
-- non-empty list here as failure. The inner 'WriterT' accumulates the
-- generated SQL, each statement tagged with its unsafe flag (see
-- 'CautiousMigration').
type Migration m = WriterT [Text] (WriterT CautiousMigration m) ()
-- | Run a migration script without executing anything. Returns 'Left' with
-- the collected error messages if there were any. Otherwise returns 'Right'
-- with the generated statements.
parseMigration :: Monad m => Migration m -> m (Either [Text] CautiousMigration)
parseMigration mig = liftM classify (runWriterT (execWriterT mig))
  where
    classify (errs, sql)
        | null errs = Right sql
        | otherwise = Left errs
-- | Like 'parseMigration', but returns the 'CautiousMigration' directly.
-- If the migration reported any errors, it calls 'error' instead, with one
-- error message per line.
parseMigration' :: Monad m => Migration m -> m CautiousMigration
parseMigration' m = parseMigration m >>= either failWith return
  where
    failWith errs = error (unlines (map unpack errs))
-- | Print every statement of a migration, safe and unsafe, to stdout.
-- Each statement is terminated with a @;@. Nothing is executed.
-- Calls 'error' if the migration reported errors.
printMigration :: (MBCIO m, MonadIO m) => Migration (SqlPersist m) -> SqlPersist m ()
printMigration m = parseMigration' m >>= mapM_ printStatement . allSql
  where
    printStatement sql = liftIO (Data.Text.IO.putStrLn (snoc sql ';'))
-- | Return every statement of a migration, safe and unsafe, without
-- executing anything. Calls 'error' if the migration reported errors.
getMigration :: (MBCIO m, MonadIO m) => Migration (SqlPersist m) -> SqlPersist m [Sql]
getMigration m = liftM allSql (parseMigration' m)
-- | Execute the safe statements of a migration, echoing each one to stderr.
--
-- Calls 'error' if the migration reported errors. It also calls 'error',
-- without executing anything, if any statement is flagged unsafe.
runMigration :: (MonadIO m, MBCIO m)
             => Migration (SqlPersist m)
             -> SqlPersist m ()
runMigration m = do
    _executed <- runMigration' m False
    return ()
-- | Same as 'runMigration', but returns a list of the SQL commands executed
-- instead of printing them to stderr.
runMigrationSilent :: (MBCIO m, MonadIO m)
                   => Migration (SqlPersist m)
                   -> SqlPersist m [Text]
runMigrationSilent = flip runMigration' True
-- | Shared implementation of 'runMigration' and 'runMigrationSilent'.
--
-- If any statement is flagged unsafe, no statement is executed. Instead,
-- 'error' is raised with a message listing every unsafe statement.
-- Otherwise the safe statements are executed in order and returned.
runMigration'
    :: (MBCIO m, MonadIO m)
    => Migration (SqlPersist m)
    -> Bool -- ^ 'True' to suppress echoing each statement to stderr
    -> SqlPersist m [Text]
runMigration' m silent = do
    mig <- parseMigration' m
    let unsafe = unsafeSql mig
    if null unsafe
        then mapM (executeMigrate silent) (safeSql mig)
        else error (manualInterventionMsg unsafe)
  where
    manualInterventionMsg stmts = concat
        [ "\n\nDatabase migration: manual intervention required.\n"
        , "The following actions are considered unsafe:\n\n"
        , unlines [" " ++ unpack s ++ ";" | s <- stmts]
        ]
-- | Execute every statement of a migration, including those flagged
-- unsafe, echoing each one to stderr. Calls 'error' if the migration
-- reported errors.
runMigrationUnsafe :: (MBCIO m, MonadIO m)
                   => Migration (SqlPersist m)
                   -> SqlPersist m ()
runMigrationUnsafe m = parseMigration' m >>= mapM_ (executeMigrate False) . allSql
-- | Execute one migration statement (with no parameters) and return it.
-- Unless @silent@ is 'True', the statement is first written to stderr,
-- prefixed with @Migrating: @.
executeMigrate :: MonadIO m => Bool -> Text -> SqlPersist m Text
executeMigrate silent sql = do
    if silent
        then return ()
        else liftIO . hPutStrLn stderr $ "Migrating: " ++ unpack sql
    execute' sql []
    return sql
-- | Add the migration for a single entity @val@ to a 'Migration'.
-- @allDefs@ holds the definitions of all entities.
--
-- The work is delegated to the connection's 'migrateSql'. A 'Left' result
-- is recorded as error messages (the outer writer). A 'Right' result is
-- recorded as generated statements (the inner writer).
migrate :: (MonadIO m, MBCIO m, PersistEntity val)
        => [EntityDef]
        -> val
        -> Migration (SqlPersist m)
migrate allDefs val = do
    conn <- lift (lift (SqlPersist ask))
    result <- liftIO (migrateSql conn allDefs (R.getStmt' conn) val)
    case result of
        Left errs -> tell errs
        Right sql -> lift (tell sql)
-- | Perform a database commit, then immediately begin a new transaction
-- on the same connection.
commit :: MonadIO m => SqlPersist m ()
commit = SqlPersist ask >>= \conn -> liftIO $ do
    let getter = R.getStmt' conn
    commitC conn getter
    begin conn getter
-- | Perform a database rollback, then immediately begin a new transaction
-- on the same connection.
rollback :: MonadIO m => SqlPersist m ()
rollback = do
    conn <- SqlPersist ask
    let getStmt = R.getStmt' conn
        restart = rollbackC conn getStmt >> begin conn getStmt
    liftIO restart
|