1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231
|
-----------------------------------------------------------------------------------------
{-| Module : Database.HSQL.MySQL
Copyright : (c) Krasimir Angelov 2003
License : BSD-style
Maintainer : ka2_mail@yahoo.com
Stability : provisional
Portability : portable
The module provides interface to MySQL database
-}
-----------------------------------------------------------------------------------------
#include <config.h>
module Database.HSQL.MySQL(connect, module Database.HSQL) where
import Database.HSQL
import Database.HSQL.Types
import Data.Dynamic
import Data.Bits
import Data.Char
import Foreign
import Foreign.C
import Control.Monad(when,unless)
import Control.Exception (throwDyn, finally)
import Control.Concurrent.MVar
import System.Time
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import Text.Read
#include <HsMySQL.h>
type MYSQL = Ptr ()
type MYSQL_RES = Ptr ()
type MYSQL_FIELD = Ptr ()
type MYSQL_ROW = Ptr CString
type MYSQL_LENGTHS = Ptr CULong
#if defined(_WIN32_)
#let CALLCONV = "stdcall"
#else
#let CALLCONV = "ccall"
#endif
foreign import #{CALLCONV} "HsMySQL.h mysql_init" mysql_init :: MYSQL -> IO MYSQL
foreign import #{CALLCONV} "HsMySQL.h mysql_real_connect" mysql_real_connect :: MYSQL -> CString -> CString -> CString -> CString -> CInt -> CString -> CInt -> IO MYSQL
foreign import #{CALLCONV} "HsMySQL.h mysql_close" mysql_close :: MYSQL -> IO ()
foreign import #{CALLCONV} "HsMySQL.h mysql_errno" mysql_errno :: MYSQL -> IO CInt
foreign import #{CALLCONV} "HsMySQL.h mysql_error" mysql_error :: MYSQL -> IO CString
foreign import #{CALLCONV} "HsMySQL.h mysql_query" mysql_query :: MYSQL -> CString -> IO CInt
foreign import #{CALLCONV} "HsMySQL.h mysql_use_result" mysql_use_result :: MYSQL -> IO MYSQL_RES
foreign import #{CALLCONV} "HsMySQL.h mysql_fetch_field" mysql_fetch_field :: MYSQL_RES -> IO MYSQL_FIELD
foreign import #{CALLCONV} "HsMySQL.h mysql_free_result" mysql_free_result :: MYSQL_RES -> IO ()
foreign import #{CALLCONV} "HsMySQL.h mysql_fetch_row" mysql_fetch_row :: MYSQL_RES -> IO MYSQL_ROW
foreign import #{CALLCONV} "HsMySQL.h mysql_fetch_lengths" mysql_fetch_lengths :: MYSQL_RES -> IO MYSQL_LENGTHS
foreign import #{CALLCONV} "HsMySQL.h mysql_list_tables" mysql_list_tables :: MYSQL -> CString -> IO MYSQL_RES
foreign import #{CALLCONV} "HsMySQL.h mysql_list_fields" mysql_list_fields :: MYSQL -> CString -> CString -> IO MYSQL_RES
foreign import #{CALLCONV} "HsMySQL.h mysql_next_result" mysql_next_result :: MYSQL -> IO CInt
-----------------------------------------------------------------------------------------
-- routines for handling exceptions
-----------------------------------------------------------------------------------------
handleSqlError :: MYSQL -> IO a
handleSqlError pMYSQL = do
errno <- mysql_errno pMYSQL
errMsg <- mysql_error pMYSQL >>= peekCString
throwDyn (SqlError "" (fromIntegral errno) errMsg)
-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------
-- | Makes a new connection to the database server.
connect :: String -- ^ Server name
-> String -- ^ Database name
-> String -- ^ User identifier
-> String -- ^ Authentication string (password)
-> IO Connection
connect server database user authentication = do
pMYSQL <- mysql_init nullPtr
pServer <- newCString server
pDatabase <- newCString database
pUser <- newCString user
pAuthentication <- newCString authentication
res <- mysql_real_connect pMYSQL pServer pUser pAuthentication pDatabase 0 nullPtr (#const CLIENT_MULTI_STATEMENTS)
free pServer
free pDatabase
free pUser
free pAuthentication
when (res == nullPtr) (handleSqlError pMYSQL)
refFalse <- newMVar False
let connection = Connection
{ connDisconnect = mysql_close pMYSQL
, connExecute = execute pMYSQL
, connQuery = query connection pMYSQL
, connTables = tables connection pMYSQL
, connDescribe = describe connection pMYSQL
, connBeginTransaction = execute pMYSQL "begin"
, connCommitTransaction = execute pMYSQL "commit"
, connRollbackTransaction = execute pMYSQL "rollback"
, connClosed = refFalse
}
return connection
where
execute :: MYSQL -> String -> IO ()
execute pMYSQL query = do
res <- withCString query (mysql_query pMYSQL)
when (res /= 0) (handleSqlError pMYSQL)
withStatement :: Connection -> MYSQL -> MYSQL_RES -> IO Statement
withStatement conn pMYSQL pRes = do
currRow <- newMVar (nullPtr, nullPtr)
refFalse <- newMVar False
if (pRes == nullPtr)
then do
errno <- mysql_errno pMYSQL
when (errno /= 0) (handleSqlError pMYSQL)
return (Statement
{ stmtConn = conn
, stmtClose = return ()
, stmtFetch = fetch pRes currRow
, stmtGetCol = getColValue currRow
, stmtFields = []
, stmtClosed = refFalse
})
else do
fieldDefs <- getFieldDefs pRes
return (Statement
{ stmtConn = conn
, stmtClose = mysql_free_result pRes
, stmtFetch = fetch pRes currRow
, stmtGetCol = getColValue currRow
, stmtFields = fieldDefs
, stmtClosed = refFalse
})
where
getFieldDefs pRes = do
pField <- mysql_fetch_field pRes
if pField == nullPtr
then return []
else do
name <- (#peek MYSQL_FIELD, name) pField >>= peekCString
dataType <- (#peek MYSQL_FIELD, type) pField
columnSize <- (#peek MYSQL_FIELD, length) pField
flags <- (#peek MYSQL_FIELD, flags) pField
decimalDigits <- (#peek MYSQL_FIELD, decimals) pField
let sqlType = mkSqlType dataType columnSize decimalDigits
defs <- getFieldDefs pRes
return ((name,sqlType,((flags :: Int) .&. (#const NOT_NULL_FLAG)) == 0):defs)
mkSqlType :: Int -> Int -> Int -> SqlType
mkSqlType (#const FIELD_TYPE_STRING) size _ = SqlChar size
mkSqlType (#const FIELD_TYPE_VAR_STRING) size _ = SqlVarChar size
mkSqlType (#const FIELD_TYPE_DECIMAL) size prec = SqlNumeric size prec
mkSqlType (#const FIELD_TYPE_SHORT) _ _ = SqlSmallInt
mkSqlType (#const FIELD_TYPE_INT24) _ _ = SqlMedInt
mkSqlType (#const FIELD_TYPE_LONG) _ _ = SqlInteger
mkSqlType (#const FIELD_TYPE_FLOAT) _ _ = SqlReal
mkSqlType (#const FIELD_TYPE_DOUBLE) _ _ = SqlDouble
mkSqlType (#const FIELD_TYPE_TINY) _ _ = SqlTinyInt
mkSqlType (#const FIELD_TYPE_LONGLONG) _ _ = SqlBigInt
mkSqlType (#const FIELD_TYPE_DATE) _ _ = SqlDate
mkSqlType (#const FIELD_TYPE_TIME) _ _ = SqlTime
mkSqlType (#const FIELD_TYPE_TIMESTAMP) _ _ = SqlTimeStamp
mkSqlType (#const FIELD_TYPE_DATETIME) _ _ = SqlDateTime
mkSqlType (#const FIELD_TYPE_YEAR) _ _ = SqlYear
mkSqlType (#const FIELD_TYPE_BLOB) _ _ = SqlBLOB
mkSqlType (#const FIELD_TYPE_SET) _ _ = SqlSET
mkSqlType (#const FIELD_TYPE_ENUM) _ _ = SqlENUM
mkSqlType tp _ _ = SqlUnknown tp
query :: Connection -> MYSQL -> String -> IO Statement
query conn pMYSQL query = do
res <- withCString query (mysql_query pMYSQL)
when (res /= 0) (handleSqlError pMYSQL)
pRes <- getFirstResult pMYSQL
withStatement conn pMYSQL pRes
where
getFirstResult :: MYSQL -> IO MYSQL_RES
getFirstResult pMYSQL = do
pRes <- mysql_use_result pMYSQL
if pRes == nullPtr
then do
res <- mysql_next_result pMYSQL
if res == 0
then getFirstResult pMYSQL
else return nullPtr
else return pRes
fetch :: MYSQL_RES -> MVar (MYSQL_ROW, MYSQL_LENGTHS) -> IO Bool
fetch pRes currRow
| pRes == nullPtr = return False
| otherwise = modifyMVar currRow $ \(pRow, pLengths) -> do
pRow <- mysql_fetch_row pRes
pLengths <- mysql_fetch_lengths pRes
return ((pRow, pLengths), pRow /= nullPtr)
getColValue :: MVar (MYSQL_ROW, MYSQL_LENGTHS) -> Int -> FieldDef -> (SqlType -> CString -> Int -> IO (Maybe a)) -> IO (Maybe a)
getColValue currRow colNumber (name,sqlType,nullable) f = do
(row, lengths) <- readMVar currRow
pValue <- peekElemOff row colNumber
len <- fmap fromIntegral (peekElemOff lengths colNumber)
if pValue == nullPtr
then return Nothing
else do
mv <- f sqlType pValue len
case mv of
Just v -> return (Just v)
Nothing -> throwDyn (SqlBadTypeCast name sqlType)
tables :: Connection -> MYSQL -> IO [String]
tables conn pMYSQL = do
pRes <- mysql_list_tables pMYSQL nullPtr
stmt <- withStatement conn pMYSQL pRes
-- SQLTables returns:
-- Column name # Type
-- Tables_in_xx 0 VARCHAR
collectRows (\stmt -> do
mb_v <- stmtGetCol stmt 0 ("Tables", SqlVarChar 0, False) fromNonNullSqlCStringLen
return (case mb_v of { Nothing -> ""; Just a -> a })) stmt
describe :: Connection -> MYSQL -> String -> IO [FieldDef]
describe conn pMYSQL table = do
pRes <- withCString table (\table -> mysql_list_fields pMYSQL table nullPtr)
stmt <- withStatement conn pMYSQL pRes
return (getFieldsTypes stmt)
|