File: Internal.hs

package info (click to toggle)
haskell-persistent-postgresql 2.13.6.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 272 kB
  • sloc: haskell: 3,281; makefile: 2
file content (284 lines) | stat: -rw-r--r-- 11,983 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ViewPatterns #-}

module Database.Persist.Postgresql.Internal
    ( P(..)
    , PgInterval(..)
    , getGetter
    ) where

import qualified Database.PostgreSQL.Simple as PG
import qualified Database.PostgreSQL.Simple.FromField as PGFF
import qualified Database.PostgreSQL.Simple.Internal as PG
import qualified Database.PostgreSQL.Simple.ToField as PGTF
import qualified Database.PostgreSQL.Simple.TypeInfo.Static as PS
import qualified Database.PostgreSQL.Simple.Types as PG

import qualified Blaze.ByteString.Builder.Char8 as BBB
import qualified Data.Attoparsec.ByteString.Char8 as P
import Data.Bits ((.&.))
import Data.ByteString (ByteString)
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Char8 as B8
import Data.Char (ord)
import Data.Data (Typeable)
import Data.Fixed (Fixed(..), Pico)
import Data.Int (Int64)
import qualified Data.IntMap as I
import Data.Maybe (fromMaybe)
import Data.String.Conversions.Monomorphic (toStrictByteString)
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.Time (NominalDiffTime, localTimeToUTC, utc)

import Database.Persist.Sql

-- | Wrapper around a 'PersistValue' that carries the @postgresql-simple@
-- 'PGTF.ToField' / 'PGFF.FromField' instances defined in this module, so
-- they do not have to be orphan instances on 'PersistValue' itself.
--
-- @since 2.13.2.0
newtype P = P
    { unP :: PersistValue
    }

-- | Serialize a wrapped 'PersistValue' into a @postgresql-simple@ 'PGTF.Action'.
--
-- Most constructors defer to the existing 'PGTF.ToField' instance of their
-- payload. Lists and maps are sent as JSON. Literals are sent either escaped
-- ('Unknown') or verbatim ('UnknownLiteral'). Arrays are sent as a
-- 'PG.PGArray' of wrapped elements. 'PersistObjectId' is rejected with a
-- runtime 'error'.
instance PGTF.ToField P where
    toField (P value) = case value of
        PersistText t        -> PGTF.toField t
        PersistByteString bs -> PGTF.toField (PG.Binary bs)
        PersistInt64 i       -> PGTF.toField i
        PersistDouble d      -> PGTF.toField d
        -- Rationals are rendered via 'Pico', i.e. with 12 fractional digits.
        -- FIXME: Too ambiguous; the precision cannot be chosen without
        -- information about the target field.
        PersistRational r    ->
            PGTF.Plain (BBB.fromString (show (fromRational r :: Pico)))
        PersistBool b        -> PGTF.toField b
        PersistDay d         -> PGTF.toField d
        PersistTimeOfDay t   -> PGTF.toField t
        PersistUTCTime t     -> PGTF.toField t
        PersistNull          -> PGTF.toField PG.Null
        PersistList l        -> PGTF.toField (listToJSON l)
        PersistMap m         -> PGTF.toField (mapToJSON m)
        PersistLiteral_ DbSpecific s -> PGTF.toField (Unknown s)
        PersistLiteral_ Unescaped l  -> PGTF.toField (UnknownLiteral l)
        PersistLiteral_ Escaped e    -> PGTF.toField (Unknown e)
        PersistArray a       -> PGTF.toField (PG.PGArray (map P a))
        PersistObjectId _    ->
            error "Refusing to serialize a PersistObjectId to a PostgreSQL value"

-- | Decode a column into a wrapped 'PersistValue', choosing the parser from
-- the column's type OID via 'getGetter'.
instance PGFF.FromField P where
    -- A SQL NULL is mapped to 'PersistNull' before the OID is consulted.
    -- Otherwise the type-specific parser would reject it as an unexpected
    -- NULL.
    fromField _ Nothing = pure (P PersistNull)
    fromField field mdata = P <$> getGetter (PGFF.typeOid field) field mdata

-- | Raw column bytes taken as-is from the database. They are sent back as an
-- escaped (quoted) SQL literal.
newtype Unknown = Unknown { unUnknown :: ByteString }
  deriving (Eq, Show, Read, Ord)

-- | Accepts any non-NULL field regardless of its type OID. NULL is reported
-- as 'PGFF.UnexpectedNull'.
instance PGFF.FromField Unknown where
    fromField f =
        maybe (PGFF.returnError PGFF.UnexpectedNull f "Database.Persist.Postgresql/PGFF.FromField Unknown")
              (return . Unknown)

-- | Emits the bytes through 'PGTF.Escape', so they are quoted and escaped.
instance PGTF.ToField Unknown where
    toField = PGTF.Escape . unUnknown

-- | Raw bytes that are spliced into the SQL text verbatim, without quoting or
-- escaping.
newtype UnknownLiteral = UnknownLiteral { unUnknownLiteral :: ByteString }
  deriving (Eq, Show, Read, Ord, Typeable)

-- | Accepts any non-NULL field regardless of its type OID. NULL is reported
-- as 'PGFF.UnexpectedNull'.
instance PGFF.FromField UnknownLiteral where
    fromField f =
        maybe (PGFF.returnError PGFF.UnexpectedNull f "Database.Persist.Postgresql/PGFF.FromField UnknownLiteral")
              (return . UnknownLiteral)

-- | Emits the bytes as a 'PGTF.Plain' builder, i.e. unescaped.
instance PGTF.ToField UnknownLiteral where
    toField = PGTF.Plain . BB.byteString . unUnknownLiteral

-- | A @postgresql-simple@ field parser producing an @a@.
type Getter a = PGFF.FieldParser a

-- | Turn a post-processing function into a 'Getter'. The field is decoded
-- with the 'PGFF.FromField' instance of @a@, and @f@ is then applied to the
-- decoded value.
convertPV :: PGFF.FromField a => (a -> b) -> Getter b
convertPV f field mdata = f <$> PGFF.fromField field mdata

-- | Field parsers for the PostgreSQL types decoded specially, keyed by the
-- type's OID converted to an 'Int'. Any OID not listed here is handled by the
-- fallback in 'getGetter'.
builtinGetters :: I.IntMap (Getter PersistValue)
builtinGetters = I.fromList
    [ (k PS.bool,        convertPV PersistBool)
    , (k PS.bytea,       convertPV (PersistByteString . unBinary))
    , (k PS.char,        convertPV PersistText)
    , (k PS.name,        convertPV PersistText)
    , (k PS.int8,        convertPV PersistInt64)
    , (k PS.int2,        convertPV PersistInt64)
    , (k PS.int4,        convertPV PersistInt64)
    , (k PS.text,        convertPV PersistText)
    , (k PS.xml,         convertPV (PersistByteString . unUnknown))
    , (k PS.float4,      convertPV PersistDouble)
    , (k PS.float8,      convertPV PersistDouble)
    , (k PS.money,       convertPV PersistRational)
    , (k PS.bpchar,      convertPV PersistText)
    , (k PS.varchar,     convertPV PersistText)
    , (k PS.date,        convertPV PersistDay)
    , (k PS.time,        convertPV PersistTimeOfDay)
    -- @timestamp@ has no time zone. It is decoded as a 'LocalTime' and
    -- interpreted as UTC.
    , (k PS.timestamp,   convertPV (PersistUTCTime. localTimeToUTC utc))
    , (k PS.timestamptz, convertPV PersistUTCTime)
    , (k PS.interval,    convertPV (PersistLiteralEscaped . pgIntervalToBs))
    , (k PS.bit,         convertPV PersistInt64)
    , (k PS.varbit,      convertPV PersistInt64)
    , (k PS.numeric,     convertPV PersistRational)
    , (k PS.void,        \_ _ -> return PersistNull)
    , (k PS.json,        convertPV (PersistByteString . unUnknown))
    , (k PS.jsonb,       convertPV (PersistByteString . unUnknown))
    , (k PS.unknown,     convertPV (PersistByteString . unUnknown))

    -- Array types: same order as above.
    -- The OIDs were taken from pg_type.
    , (1000,             listOf PersistBool)
    , (1001,             listOf (PersistByteString . unBinary))
    , (1002,             listOf PersistText)
    , (1003,             listOf PersistText)
    , (1016,             listOf PersistInt64)
    , (1005,             listOf PersistInt64)
    , (1007,             listOf PersistInt64)
    , (1009,             listOf PersistText)
    , (143,              listOf (PersistByteString . unUnknown))
    , (1021,             listOf PersistDouble)
    , (1022,             listOf PersistDouble)
    -- NOTE(review): 1023/1024 are _abstime/_reltime in pg_type (both
    -- removed in PostgreSQL 12). They have no scalar counterpart above.
    -- Kept unchanged.
    , (1023,             listOf PersistUTCTime)
    , (1024,             listOf PersistUTCTime)
    , (791,              listOf PersistRational)
    , (1014,             listOf PersistText)
    , (1015,             listOf PersistText)
    , (1182,             listOf PersistDay)
    , (1183,             listOf PersistTimeOfDay)
    -- _timestamp: the elements are @timestamp@ values, which
    -- postgresql-simple decodes only as 'LocalTime' (its 'UTCTime' parser
    -- accepts @timestamptz@ only). Decode them the same way as the scalar
    -- @timestamp@ entry above, then interpret them as UTC.
    , (1115,             listOf (PersistUTCTime . localTimeToUTC utc))
    , (1185,             listOf PersistUTCTime)
    , (1187,             listOf (PersistLiteralEscaped . pgIntervalToBs))
    , (1561,             listOf PersistInt64)
    , (1563,             listOf PersistInt64)
    , (1231,             listOf PersistRational)
    -- no array(void) type
    , (2951,             listOf (PersistLiteralEscaped . unUnknown))
    , (199,              listOf (PersistByteString . unUnknown))
    , (3807,             listOf (PersistByteString . unUnknown))
    -- no array(unknown) either
    ]
    where
        -- The OID of a static TypeInfo, as an IntMap key.
        k (PGFF.typoid -> i) = PG.oid2int i
        -- A @listOf f@ will use a @PGArray (Maybe T)@ to convert
        -- the values to Haskell-land.  The @Maybe@ is important
        -- because the usual way of checking NULLs
        -- (c.f. withStmt') won't check for NULL inside
        -- arrays---or any other compound structure for that matter.
        -- NULL elements become 'PersistNull'; the others go through @f@.
        listOf f = convertPV (PersistList . map (nullable f) . PG.fromPGArray)
          where nullable = maybe PersistNull

-- | Get the field parser corresponding to the given 'PG.Oid'.
--
-- For example, pass in the 'PG.Oid' of 'PS.bool', and you will get back a
-- field parser which parses boolean values in the table into 'PersistBool's.
--
-- An OID with no dedicated parser falls back to reading the raw bytes as a
-- 'PersistLiteralEscaped'.
--
-- @since 2.13.2.0
getGetter :: PG.Oid -> Getter PersistValue
getGetter oid =
    case I.lookup (PG.oid2int oid) builtinGetters of
        Just getter -> getter
        Nothing     -> convertPV (PersistLiteralEscaped . unUnknown)

-- | Strip the 'PG.Binary' newtype wrapper.
unBinary :: PG.Binary a -> a
unBinary wrapped = case wrapped of
    PG.Binary payload -> payload

-- | Represent Postgres interval using NominalDiffTime
--
-- @since 2.11.0.0
newtype PgInterval = PgInterval { getPgInterval :: NominalDiffTime }
  deriving (Eq, Show)

-- | Render an interval as the 'show' output of its 'NominalDiffTime',
-- e.g. @3600s@, encoded as a strict 'ByteString'.
pgIntervalToBs :: PgInterval -> ByteString
pgIntervalToBs (PgInterval t) = toStrictByteString (show t)

-- | Delegates to the 'PGTF.ToField' instance of the wrapped 'NominalDiffTime'.
instance PGTF.ToField PgInterval where
    toField = PGTF.toField . getPgInterval

-- | Decode a PostgreSQL @interval@ column into a 'PgInterval'.
--
-- * A field whose type OID is not @interval@ fails with 'PGFF.Incompatible'.
-- * A NULL fails with 'PGFF.UnexpectedNull'.
-- * Text that does not match the @[-]H+:MM:SS[.F+]@ shape parsed by
--   @interval@ below fails with 'PGFF.ConversionFailed'.
--
-- NOTE(review): only that hours:minutes:seconds form is accepted. A value
-- containing day/month/year components (e.g. @1 day 02:00:00@) fails to
-- parse. Confirm the connection's IntervalStyle / callers never produce them.
instance PGFF.FromField PgInterval where
    fromField f mdata =
      if PGFF.typeOid f /= PS.typoid PS.interval
        then PGFF.returnError PGFF.Incompatible f ""
        else case mdata of
          Nothing  -> PGFF.returnError PGFF.UnexpectedNull f ""
          -- The whole payload must be consumed (endOfInput).
          Just dat -> case P.parseOnly (nominalDiffTime <* P.endOfInput) dat of
            Left msg  ->  PGFF.returnError PGFF.ConversionFailed f msg
            Right t   -> return $ PgInterval t

      where
        -- Interpret an Integer as a count of picoseconds ('Pico' has
        -- 10^-12 resolution).
        toPico :: Integer -> Pico
        toPico = MkFixed

        -- Taken from Database.PostgreSQL.Simple.Time.Internal.Parser
        -- Exactly two ASCII digits. @ord c .&. 15@ maps '0'..'9' to 0..9.
        twoDigits :: P.Parser Int
        twoDigits = do
          a <- P.digit
          b <- P.digit
          let c2d c = ord c .&. 15
          return $! c2d a * 10 + c2d b

        -- Taken from Database.PostgreSQL.Simple.Time.Internal.Parser
        -- Two-digit whole seconds, optionally followed by '.' and at least
        -- one fractional digit.
        seconds :: P.Parser Pico
        seconds = do
          real <- twoDigits
          mc <- P.peekChar
          case mc of
            Just '.' -> do
              t <- P.anyChar *> P.takeWhile1 P.isDigit
              return $! parsePicos (fromIntegral real) t
            _ -> return $! fromIntegral real
         where
          -- Combine whole seconds @a0@ with fractional digits @t@ into
          -- picoseconds. Digits beyond the 12th are truncated; shorter
          -- fractions are scaled up by 10^n to fill 12 digits.
          parsePicos :: Int64 -> B8.ByteString -> Pico
          parsePicos a0 t = toPico (fromIntegral (t' * 10^n))
            where n  = max 0 (12 - B8.length t)
                  t' = B8.foldl' (\a c -> 10 * a + fromIntegral (ord c .&. 15)) a0
                                 (B8.take 12 t)

        -- True when a leading '-' is present. Consumes nothing otherwise.
        parseSign :: P.Parser Bool
        parseSign = P.choice [P.char '-' >> return True, return False]

        -- Expected text: [-]H+:MM:SS[.F+], i.e. an optional sign, an
        -- unbounded decimal hour count, two-digit minutes and two-digit
        -- seconds with an optional fraction.
        -- For example, nominalDay is stored as 24:00:00
        -- Minutes must be < 60 and seconds <= 60, otherwise parsing fails.
        interval :: P.Parser (Bool, Int, Int, Pico)
        interval = do
            s  <- parseSign
            h  <- P.decimal <* P.char ':'
            m  <- twoDigits <* P.char ':'
            ss <- seconds
            if m < 60 && ss <= 60
                then return (s, h, m, ss)
                else fail "Invalid interval"

        -- Total seconds = ss + 60*m + 3600*|h|, negated when a '-' was read.
        -- (P.decimal is unsigned, so h is already non-negative.)
        nominalDiffTime :: P.Parser NominalDiffTime
        nominalDiffTime = do
          (s, h, m, ss) <- interval
          let pico   = ss + 60 * (fromIntegral m) + 60 * 60 * (fromIntegral (abs h))
          return . fromRational . toRational $ if s then (-pico) else pico

-- | Build the message reported when a 'PersistValue' read from the database
-- cannot be converted to the expected Haskell type.
fromPersistValueError :: Text -- ^ Haskell type, should match Haskell name exactly, e.g. "Int64"
                      -> Text -- ^ Database type(s), should appear different from Haskell name, e.g. "integer" or "INT", not "Int".
                      -> PersistValue -- ^ Incorrect value
                      -> Text -- ^ Error message
fromPersistValueError haskellType databaseType received =
    mconcat
        [ "Failed to parse Haskell type `", haskellType
        , "`; expected ", databaseType
        , " from database, but received: ", receivedText
        , ". Potential solution: Check that your database schema matches your Persistent model definitions."
        ]
  where
    receivedText = T.pack (show received)

-- | Intervals round-trip through the 'show' form of 'NominalDiffTime'
-- (e.g. @3600s@), stored as an escaped literal.
instance PersistField PgInterval where
    toPersistValue = PersistLiteralEscaped . pgIntervalToBs

    -- DbSpecific and Escaped literals are parsed the same way. A parse
    -- failure always reports the value in its Escaped form.
    fromPersistValue pv =
        case pv of
            PersistLiteral_ DbSpecific bs -> fromSecondsLiteral bs
            PersistLiteral_ Escaped bs    -> fromSecondsLiteral bs
            _                             -> mismatch pv
      where
        mismatch = Left . fromPersistValueError "PgInterval" "Interval"
        -- A signed decimal number of seconds followed by a literal 's'.
        fromSecondsLiteral bs =
            case P.parseOnly (P.signed P.rational <* P.char 's' <* P.endOfInput) bs of
                Right secs -> Right (PgInterval secs)
                Left _     -> mismatch (PersistLiteralEscaped bs)

-- | Columns holding a 'PgInterval' use the PostgreSQL @interval@ type.
instance PersistFieldSql PgInterval where
  sqlType = const (SqlOther "interval")