1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321
|
{-# LANGUAGE CPP #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
#ifndef NO_OVERLAP
{-# LANGUAGE OverlappingInstances #-}
#endif
module Database.Persist.Sql.Class
( MonadSqlPersist (..)
, RawSql (..)
, PersistFieldSql (..)
) where
import Control.Applicative ((<$>), (<*>))
import Database.Persist
import Data.Monoid ((<>))
import Database.Persist.Sql.Types
import Control.Arrow ((&&&))
import Data.Text (Text, intercalate, pack)
import Data.Maybe (fromMaybe)
import Data.Fixed
import Data.Monoid (Monoid)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Logger (LoggingT)
import Control.Monad.Trans.Identity ( IdentityT)
import Control.Monad.Trans.List ( ListT )
import Control.Monad.Trans.Maybe ( MaybeT )
import Control.Monad.Trans.Error ( ErrorT, Error)
import Control.Monad.Trans.Cont ( ContT )
import Control.Monad.Trans.State ( StateT )
import Control.Monad.Trans.Writer ( WriterT )
import Control.Monad.Trans.RWS ( RWST )
import Control.Monad.Trans.Reader ( ReaderT, ask )
import Control.Monad.Trans.Resource ( ResourceT )
import Data.Conduit.Internal (Pipe, ConduitM)
import qualified Control.Monad.Trans.RWS.Strict as Strict ( RWST )
import qualified Control.Monad.Trans.State.Strict as Strict ( StateT )
import qualified Control.Monad.Trans.Writer.Strict as Strict ( WriterT )
import Control.Monad.IO.Class (MonadIO)
import Control.Monad.Trans.Class (MonadTrans)
import Control.Monad.Logger (MonadLogger)
import qualified Data.Text as T
import qualified Data.Text.Lazy as TL
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Time (ZonedTime, UTCTime, TimeOfDay, Day)
import Data.Int
import Data.Word
import Data.ByteString (ByteString)
import Text.Blaze.Html (Html)
import Data.Bits (bitSize)
-- | Monads that can supply the current SQL 'Connection'.
--
-- Every instance must also be 'MonadIO' and 'MonadLogger'.  The default
-- implementation of 'askSqlConn' lifts the inner monad's connection through
-- a 'MonadTrans' layer, so transformer instances can be declared with an
-- empty body.
class (MonadIO m, MonadLogger m) => MonadSqlPersist m where
    -- | Retrieve the connection for the current context.
    askSqlConn :: m Connection
    default askSqlConn :: (MonadSqlPersist m, MonadTrans t, MonadLogger (t m))
                       => t m Connection
    askSqlConn = lift askSqlConn
-- | Base case: 'SqlPersistT' wraps a 'ReaderT' computation, so the
-- connection is simply the reader environment obtained with 'ask'.
instance (MonadIO m, MonadLogger m) => MonadSqlPersist (SqlPersistT m) where
    askSqlConn = SqlPersistT ask
-- Instances of MonadSqlPersist for the common monad transformers.
-- GO(T) declares an instance for transformer T over any MonadSqlPersist m;
-- GOX(X, T) does the same with the extra context X (for example Monoid w
-- for the writer-based transformers). None of the generated instances has
-- a body, so askSqlConn falls back to the class default, which lifts the
-- connection of the inner monad. Both macros are undefined at the end.
#define GO(T) instance (MonadSqlPersist m) => MonadSqlPersist (T m)
#define GOX(X, T) instance (X, MonadSqlPersist m) => MonadSqlPersist (T m)
GO(LoggingT)
GO(IdentityT)
GO(ListT)
GO(MaybeT)
GOX(Error e, ErrorT e)
GO(ReaderT r)
GO(ContT r)
GO(StateT s)
GO(ResourceT)
GO(Pipe l i o u)
GO(ConduitM i o)
GOX(Monoid w, WriterT w)
GOX(Monoid w, RWST r w s)
GOX(Monoid w, Strict.RWST r w s)
GO(Strict.StateT s)
GOX(Monoid w, Strict.WriterT w)
#undef GO
#undef GOX
-- | Class for data types that may be retrieved from a 'rawSql'
-- query.
--
-- The @a@ argument of 'rawSqlCols' and 'rawSqlColCountReason' may be a
-- placeholder built with 'error' (the tuple instances pass one), so
-- instances must only use it for its type and never force it.
class RawSql a where
    -- | Number of columns that this data type needs and the list
    -- of substitutions for @SELECT@ placeholders @??@.
    -- The function argument escapes database identifiers (table and
    -- column names).
    rawSqlCols :: (DBName -> Text) -> a -> (Int, [Text])

    -- | A string telling the user why the column count is what
    -- it is.
    rawSqlColCountReason :: a -> String

    -- | Transform a row of the result into the data type.
    -- Returns 'Left' with a message when the row cannot be decoded.
    rawSqlProcessRow :: [PersistValue] -> Either Text a
-- | A lone column decoded with the 'PersistField' instance of @a@.
-- It contributes no @??@ substitutions.
instance PersistField a => RawSql (Single a) where
    rawSqlCols _escape _proxy = (1, [])
    rawSqlColCountReason _ = "one column for a 'Single' data type"
    rawSqlProcessRow row = case row of
        [value] -> fmap Single (fromPersistValue value)
        _       -> Left (pack "RawSql (Single a): wrong number of columns.")
-- | An entity occupies its key column followed by one column per field.
-- The single @??@ substitution lists all of those columns, each prefixed
-- with the escaped table name.
--
-- The entity argument is never forced: it is only handed (lazily) to
-- 'entityDef' to pick the entity definition by type.
instance PersistEntity a => RawSql (Entity a) where
    rawSqlCols escape ent = (length fields + 1, [intercalate ", " columns])
      where
        edef = entityDef (Just (entityVal ent))
        fields = entityFields edef
        prefix = escape (entityDB edef) <> "."
        columns = [prefix <> escape col | col <- entityID edef : map fieldDB fields]

    rawSqlColCountReason ent
        | count == 1 = "one column for an 'Entity' data type without fields"
        | otherwise  = show count ++ " columns for an 'Entity' data type"
      where
        -- Only the count is needed, so the escaping function is never called.
        count = fst (rawSqlCols (error "RawSql") ent)

    rawSqlProcessRow row = case row of
        idCol : vals -> Entity <$> fromPersistValue idCol <*> fromPersistValues vals
        []           -> Left "RawSql (Entity a): wrong number of columns."
-- | Since 1.0.1.
--
-- Makes the whole inner value optional: a row whose columns are all
-- 'PersistNull' decodes to 'Nothing'; otherwise the inner instance must
-- succeed, and its failure message is wrapped in an explanation.
instance RawSql a => RawSql (Maybe a) where
    rawSqlCols escape wrapped = rawSqlCols escape (extractMaybe wrapped)
    rawSqlColCountReason wrapped = rawSqlColCountReason (extractMaybe wrapped)
    rawSqlProcessRow cols
        | not (any isPresent cols) = Right Nothing
        | otherwise = either (Left . explain) (Right . Just) (rawSqlProcessRow cols)
      where
        isPresent PersistNull = False
        isPresent _           = True
        explain msg = "RawSql (Maybe a): not all columns were Null " <>
                      "but the inner parser has failed. Its message " <>
                      "was \"" <> msg <> "\". Did you apply Maybe " <>
                      "to a tuple, perhaps? The main use case for " <>
                      "Maybe is to allow OUTER JOINs to be written, " <>
                      "in which case 'Maybe (Entity v)' is used."
-- | Pairs are the building block for every larger tuple instance: the
-- columns of @a@ come first, followed by the columns of @b@.
instance (RawSql a, RawSql b) => RawSql (a, b) where
    -- Column counts are added and the @??@ substitutions concatenated.
    -- 'fst' / 'snd' keep the (possibly placeholder) pair unforced.
    rawSqlCols e x = rawSqlCols e (fst x) # rawSqlCols e (snd x)
        where (cnta, lsta) # (cntb, lstb) = (cnta + cntb, lsta ++ lstb)
    rawSqlColCountReason x = rawSqlColCountReason (fst x) ++ ", " ++
                             rawSqlColCountReason (snd x)
    -- The row is split after the first component's column count and each
    -- half is decoded by that component's own instance.
    rawSqlProcessRow =
        let x = getType processRow
            -- Placeholder that only fixes the result type; the component
            -- instances are expected never to force it.
            getType :: (z -> Either y x) -> x
            getType = error "RawSql.getType"
            colCountFst = fst $ rawSqlCols (error "RawSql.getType2") (fst x)
            processRow row =
                let (rowFst, rowSnd) = splitAt colCountFst row
                in (,) <$> rawSqlProcessRow rowFst
                       <*> rawSqlProcessRow rowSnd
        -- Avoids recalculating 'colCountFst'.
        in colCountFst `seq` processRow
-- | Triples delegate to the pair instance through 'from3' / 'to3'.
instance (RawSql a, RawSql b, RawSql c) => RawSql (a, b, c) where
    rawSqlCols escape vals = rawSqlCols escape (from3 vals)
    rawSqlColCountReason vals = rawSqlColCountReason (from3 vals)
    -- Kept point-free so the nested instance's row decoder is shared.
    rawSqlProcessRow = (to3 <$>) . rawSqlProcessRow
-- | Regroup a triple as a pair whose first component is itself a pair.
from3 :: (a,b,c) -> ((a,b),c)
from3 triple = case triple of
    (x, y, z) -> ((x, y), z)

-- | Inverse of 'from3'.
to3 :: ((a,b),c) -> (a,b,c)
to3 nested = case nested of
    ((x, y), z) -> (x, y, z)
-- | 4-tuples delegate to the pair instance through 'from4' / 'to4'.
instance (RawSql a, RawSql b, RawSql c, RawSql d) => RawSql (a, b, c, d) where
    rawSqlCols escape vals = rawSqlCols escape (from4 vals)
    rawSqlColCountReason vals = rawSqlColCountReason (from4 vals)
    -- Kept point-free so the nested instance's row decoder is shared.
    rawSqlProcessRow = (to4 <$>) . rawSqlProcessRow
-- | Regroup a 4-tuple as a pair of pairs.
from4 :: (a,b,c,d) -> ((a,b),(c,d))
from4 = \(w, x, y, z) -> ((w, x), (y, z))

-- | Inverse of 'from4'.
to4 :: ((a,b),(c,d)) -> (a,b,c,d)
to4 = \((w, x), (y, z)) -> (w, x, y, z)
-- | 5-tuples delegate to the triple instance through 'from5' / 'to5'.
instance (RawSql a, RawSql b, RawSql c,
          RawSql d, RawSql e)
      => RawSql (a, b, c, d, e) where
    rawSqlCols escape vals = rawSqlCols escape (from5 vals)
    rawSqlColCountReason vals = rawSqlColCountReason (from5 vals)
    -- Kept point-free so the nested instance's row decoder is shared.
    rawSqlProcessRow = (to5 <$>) . rawSqlProcessRow
-- | Regroup a 5-tuple as two pairs followed by the last element.
from5 :: (a,b,c,d,e) -> ((a,b),(c,d),e)
from5 tuple = case tuple of
    (v, w, x, y, z) -> ((v, w), (x, y), z)

-- | Inverse of 'from5'.
to5 :: ((a,b),(c,d),e) -> (a,b,c,d,e)
to5 nested = case nested of
    ((v, w), (x, y), z) -> (v, w, x, y, z)
-- | 6-tuples delegate to the triple instance through 'from6' / 'to6'.
instance (RawSql a, RawSql b, RawSql c,
          RawSql d, RawSql e, RawSql f)
      => RawSql (a, b, c, d, e, f) where
    rawSqlCols escape vals = rawSqlCols escape (from6 vals)
    rawSqlColCountReason vals = rawSqlColCountReason (from6 vals)
    -- Kept point-free so the nested instance's row decoder is shared.
    rawSqlProcessRow = (to6 <$>) . rawSqlProcessRow
-- | Regroup a 6-tuple as a triple of pairs.
from6 :: (a,b,c,d,e,f) -> ((a,b),(c,d),(e,f))
from6 = \(u, v, w, x, y, z) -> ((u, v), (w, x), (y, z))

-- | Inverse of 'from6'.
to6 :: ((a,b),(c,d),(e,f)) -> (a,b,c,d,e,f)
to6 = \((u, v), (w, x), (y, z)) -> (u, v, w, x, y, z)
-- | 7-tuples delegate to the 4-tuple instance through 'from7' / 'to7'.
instance (RawSql a, RawSql b, RawSql c,
          RawSql d, RawSql e, RawSql f,
          RawSql g)
      => RawSql (a, b, c, d, e, f, g) where
    rawSqlCols escape vals = rawSqlCols escape (from7 vals)
    rawSqlColCountReason vals = rawSqlColCountReason (from7 vals)
    -- Kept point-free so the nested instance's row decoder is shared.
    rawSqlProcessRow = (to7 <$>) . rawSqlProcessRow
-- | Regroup a 7-tuple as three pairs followed by the last element.
from7 :: (a,b,c,d,e,f,g) -> ((a,b),(c,d),(e,f),g)
from7 tuple = case tuple of
    (t, u, v, w, x, y, z) -> ((t, u), (v, w), (x, y), z)

-- | Inverse of 'from7'.
to7 :: ((a,b),(c,d),(e,f),g) -> (a,b,c,d,e,f,g)
to7 nested = case nested of
    ((t, u), (v, w), (x, y), z) -> (t, u, v, w, x, y, z)
-- | 8-tuples delegate to the 4-tuple instance through 'from8' / 'to8'.
instance (RawSql a, RawSql b, RawSql c,
          RawSql d, RawSql e, RawSql f,
          RawSql g, RawSql h)
      => RawSql (a, b, c, d, e, f, g, h) where
    rawSqlCols escape vals = rawSqlCols escape (from8 vals)
    rawSqlColCountReason vals = rawSqlColCountReason (from8 vals)
    -- Kept point-free so the nested instance's row decoder is shared.
    rawSqlProcessRow = (to8 <$>) . rawSqlProcessRow
-- | Regroup an 8-tuple as four pairs.
from8 :: (a,b,c,d,e,f,g,h) -> ((a,b),(c,d),(e,f),(g,h))
from8 = \(s, t, u, v, w, x, y, z) -> ((s, t), (u, v), (w, x), (y, z))

-- | Inverse of 'from8'.
to8 :: ((a,b),(c,d),(e,f),(g,h)) -> (a,b,c,d,e,f,g,h)
to8 = \((s, t), (u, v), (w, x), (y, z)) -> (s, t, u, v, w, x, y, z)
-- | Unwrap a 'Maybe' that is used to reach the 'RawSql' instance of the
-- wrapped type.
--
-- The 'RawSql' instance for 'Maybe' passes the result on lazily, so a
-- 'Nothing' (or placeholder) argument only causes an error if the wrapped
-- value is actually forced.  The error names this module and function.
extractMaybe :: Maybe a -> a
extractMaybe = fromMaybe (error "Database.Persist.Sql.Class.extractMaybe")
-- | 'PersistField' types that know which SQL column type stores them.
class PersistField a => PersistFieldSql a where
    -- | The SQL column type for values of type @a@.  The argument has type
    -- @m a@ for any 'Monad' @m@ and only selects the instance; the
    -- instances in this module never inspect it.
    sqlType :: Monad m => m a -> SqlType
-- 'String' is @[Char]@ and therefore overlaps the generic list instance
-- further down; this specific instance is only compiled when overlapping
-- instances are enabled, i.e. when NO_OVERLAP is not defined.
#ifndef NO_OVERLAP
instance PersistFieldSql String where
    sqlType _ = SqlString
#endif
-- Raw bytes are stored in a blob column; every textual representation
-- (strict Text, lazy Text, Html) is stored in a string column.
instance PersistFieldSql ByteString where
    sqlType = const SqlBlob
instance PersistFieldSql T.Text where
    sqlType = const SqlString
instance PersistFieldSql TL.Text where
    sqlType = const SqlString
instance PersistFieldSql Html where
    sqlType = const SqlString
-- | 'Int' follows the platform word size: 'SqlInt32' when 'Int' has at
-- most 32 bits, 'SqlInt64' otherwise.
instance PersistFieldSql Int where
    sqlType _ = if bitSize (0 :: Int) > 32 then SqlInt64 else SqlInt32

-- Fixed-width signed integers: up to 32 bits use 'SqlInt32'.
instance PersistFieldSql Int8 where
    sqlType = const SqlInt32
instance PersistFieldSql Int16 where
    sqlType = const SqlInt32
instance PersistFieldSql Int32 where
    sqlType = const SqlInt32
instance PersistFieldSql Int64 where
    sqlType = const SqlInt64

-- Unsigned integers: Word8 and Word16 use 'SqlInt32'; Word, Word32 and
-- Word64 use 'SqlInt64'.
-- NOTE(review): Word64 (and Word on 64-bit platforms) can exceed the
-- maximum of a signed 64-bit integer; presumably SqlInt64 is signed, so
-- such values may not round-trip -- confirm.
instance PersistFieldSql Word where
    sqlType = const SqlInt64
instance PersistFieldSql Word8 where
    sqlType = const SqlInt32
instance PersistFieldSql Word16 where
    sqlType = const SqlInt32
instance PersistFieldSql Word32 where
    sqlType = const SqlInt64
instance PersistFieldSql Word64 where
    sqlType = const SqlInt64
-- Scalar types with a dedicated SQL column type.
instance PersistFieldSql Double where
    sqlType = const SqlReal
instance PersistFieldSql Bool where
    sqlType = const SqlBool
instance PersistFieldSql Day where
    sqlType = const SqlDay
instance PersistFieldSql TimeOfDay where
    sqlType = const SqlTime
instance PersistFieldSql UTCTime where
    sqlType = const SqlDayTime
instance PersistFieldSql ZonedTime where
    sqlType = const SqlDayTimeZoned

-- Compound values (lists, sets, pairs and text-keyed maps) are stored in a
-- single string column.
instance PersistFieldSql a => PersistFieldSql [a] where
    sqlType = const SqlString
instance (Ord a, PersistFieldSql a) => PersistFieldSql (S.Set a) where
    sqlType = const SqlString
instance (PersistFieldSql a, PersistFieldSql b) => PersistFieldSql (a,b) where
    sqlType = const SqlString
instance PersistFieldSql v => PersistFieldSql (M.Map T.Text v) where
    sqlType = const SqlString

instance PersistFieldSql PersistValue where
    -- since PersistValue should only be used like this for keys, which in SQL are Int64
    sqlType = const SqlInt64
instance PersistFieldSql Checkmark where
    sqlType = const SqlBool
-- | Fixed-precision decimals map to a NUMERIC column.  The scale is the
-- rounded base-10 logarithm of the type's resolution (k for a resolution
-- of 10^k) and the precision is the scale plus 10.
instance (HasResolution a) => PersistFieldSql (Fixed a) where
    sqlType proxy = SqlNumeric (scale + 10) scale -- FIXME: Is this enough ?
      where
        -- 'zero' only carries the type @Fixed a@ to 'resolution';
        -- '_sameType' ties its type to that of the proxy argument.
        zero = 0
        _sameType = return zero `asTypeOf` proxy
        -- FIXME: May lead to problems with big numbers
        scale = round (log (fromIntegral (resolution zero)) / (log 10 :: Double))
-- The column needs to be big enough to handle the Rational to Number
-- string conversion for ODBC.
instance PersistFieldSql Rational where
    sqlType = const (SqlNumeric 32 20)
-- Embedded entities have no natural SQL column type. Perhaps a SQL user
-- can figure this sqlType out? It is really intended for MongoDB though.
instance PersistField entity => PersistFieldSql (Entity entity) where
    sqlType = const (SqlOther "embedded entity, hard to type")
-- | Keys of the SQL backend are stored in an 'SqlInt64' column.
instance PersistFieldSql (KeyBackend SqlBackend a) where
    sqlType = const SqlInt64
|