1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
module DB.HSQL.MySQL.Functions where
import Foreign((.&.),peekByteOff,nullPtr,peekElemOff)
import Foreign.C(CInt,CUInt,CString,peekCString)
import Control.Concurrent.MVar(MVar,newMVar,modifyMVar,readMVar)
import Control.Exception (throw, throwIO)
import Control.Monad(when)
import Database.HSQL.Types(FieldDef,Statement(..),Connection(..),SqlError(..))
import DB.HSQL.MySQL.Type(MYSQL,MYSQL_RES,MYSQL_FIELD,MYSQL_ROW,MYSQL_LENGTHS
                        ,mkSqlType)
#include <HsMySQL.h>
#ifdef mingw32_HOST_OS
#let CALLCONV = "stdcall"
#else
#let CALLCONV = "ccall"
#endif
-- Raw FFI bindings to the MySQL C client API, imported through the
-- HsMySQL.h wrapper header. CALLCONV is selected by the #ifdef above:
-- "stdcall" when mingw32_HOST_OS is defined, "ccall" otherwise. Each
-- Haskell name is bound to the C symbol of the same name.

-- | Binding to C @mysql_init@.
foreign import #{CALLCONV} "HsMySQL.h mysql_init"
    mysql_init :: MYSQL -> IO MYSQL
-- | Binding to C @mysql_real_connect@. The 'CInt' arguments are,
-- presumably, the port and the client flags (see
-- 'mysqlDefaultConnectFlags') — confirm against the MySQL C API.
foreign import #{CALLCONV} "HsMySQL.h mysql_real_connect"
    mysql_real_connect :: MYSQL -> CString -> CString -> CString -> CString -> CInt -> CString -> CInt -> IO MYSQL
-- | Binding to C @mysql_close@.
foreign import #{CALLCONV} "HsMySQL.h mysql_close"
    mysql_close :: MYSQL -> IO ()
-- | Binding to C @mysql_errno@. This module treats a non-zero result as
-- "an error is pending" (see 'withStatement' and 'handleSqlError').
foreign import #{CALLCONV} "HsMySQL.h mysql_errno"
    mysql_errno :: MYSQL -> IO CInt
-- | Binding to C @mysql_error@; the returned C string is copied with
-- 'peekCString' by 'handleSqlError'.
foreign import #{CALLCONV} "HsMySQL.h mysql_error"
    mysql_error :: MYSQL -> IO CString
-- | Binding to C @mysql_query@.
foreign import #{CALLCONV} "HsMySQL.h mysql_query"
    mysql_query :: MYSQL -> CString -> IO CInt
-- | Binding to C @mysql_use_result@.
foreign import #{CALLCONV} "HsMySQL.h mysql_use_result"
    mysql_use_result :: MYSQL -> IO MYSQL_RES
-- | Binding to C @mysql_fetch_field@; 'getFieldDefs' stops iterating
-- when it returns a null pointer.
foreign import #{CALLCONV} "HsMySQL.h mysql_fetch_field"
    mysql_fetch_field :: MYSQL_RES -> IO MYSQL_FIELD
-- | Binding to C @mysql_free_result@; used as the statement's close
-- action in 'withStatement'.
foreign import #{CALLCONV} "HsMySQL.h mysql_free_result"
    mysql_free_result :: MYSQL_RES -> IO ()
-- | Binding to C @mysql_fetch_row@; 'fetch' treats a null result as
-- "no more rows".
foreign import #{CALLCONV} "HsMySQL.h mysql_fetch_row"
    mysql_fetch_row :: MYSQL_RES -> IO MYSQL_ROW
-- | Binding to C @mysql_fetch_lengths@ (per-column byte lengths of the
-- current row, as consumed by 'getColValue').
foreign import #{CALLCONV} "HsMySQL.h mysql_fetch_lengths"
    mysql_fetch_lengths :: MYSQL_RES -> IO MYSQL_LENGTHS
-- | Binding to C @mysql_list_tables@.
foreign import #{CALLCONV} "HsMySQL.h mysql_list_tables"
    mysql_list_tables :: MYSQL -> CString -> IO MYSQL_RES
-- | Binding to C @mysql_list_fields@.
foreign import #{CALLCONV} "HsMySQL.h mysql_list_fields"
    mysql_list_fields :: MYSQL -> CString -> CString -> IO MYSQL_RES
-- | Binding to C @mysql_next_result@.
foreign import #{CALLCONV} "HsMySQL.h mysql_next_result"
    mysql_next_result :: MYSQL -> IO CInt
-- | Wrap a MySQL result handle in an HSQL 'Statement'.
--
-- When @pRes@ is null, the connection's error number is checked first
-- and 'handleSqlError' raises a 'SqlError' if it is non-zero. Otherwise
-- the statement has no fields, a no-op close action, and a fetch that
-- always yields 'False'.
--
-- When @pRes@ is non-null, its field definitions are read eagerly and
-- closing the statement frees the result with @mysql_free_result@.
withStatement :: Connection -> MYSQL -> MYSQL_RES -> IO Statement
withStatement conn pMYSQL pRes = do
    rowRef    <- newMVar (nullPtr, nullPtr)
    closedRef <- newMVar False
    -- Only the close action and the field list depend on whether a
    -- result set is present; everything else is shared.
    (closer, defs) <-
        if pRes /= nullPtr
            then do
                fds <- getFieldDefs pRes
                return (mysql_free_result pRes, fds)
            else do
                code <- mysql_errno pMYSQL
                when (code /= 0) (handleSqlError pMYSQL)
                return (return (), [])
    return Statement
        { stmtConn   = conn
        , stmtClose  = closer
        , stmtFetch  = fetch pRes rowRef
        , stmtGetCol = getColValue rowRef
        , stmtFields = defs
        , stmtClosed = closedRef
        }
-- | Decode one column of the current row.
--
-- Reads the row pointer and column-length array last stored by 'fetch'.
-- Looks up entry @colNumber@ in both, then passes the field definition,
-- the raw column C string and its byte length to the decoder @f@.
--
-- NOTE(review): there is no guard against a null row. @currRow@ starts
-- as @(nullPtr, nullPtr)@ and holds a null row after the last fetch.
-- Calling this in either state dereferences a null pointer.
getColValue :: MVar (MYSQL_ROW, MYSQL_LENGTHS)
            -> Int
            -> FieldDef
            -> (FieldDef -> CString -> Int -> IO a)
            -> IO a
getColValue currRow colNumber fieldDef f =
    readMVar currRow >>= \(rowPtr, lenPtr) -> do
        cstr  <- peekElemOff rowPtr colNumber
        bytes <- peekElemOff lenPtr colNumber
        f fieldDef cstr (fromIntegral bytes)
-- | Read the definitions of all remaining fields of a result set.
--
-- Calls @mysql_fetch_field@ until it returns a null pointer and builds one
-- 'FieldDef' per field. Each entry holds:
--
-- * the column name;
-- * the SQL type built by 'mkSqlType' from the field's type, length and
--   decimals;
-- * a flag that is 'True' when the column's @NOT_NULL_FLAG@ bit is clear,
--   i.e. the column is nullable.
getFieldDefs :: MYSQL_RES -> IO [FieldDef]
getFieldDefs pRes = do
    pField <- mysql_fetch_field pRes
    if pField == nullPtr
        then return []
        else do
            name          <- (#peek MYSQL_FIELD, name) pField >>= peekCString
            dataType      <- (#peek MYSQL_FIELD, type) pField
            columnSize    <- (#peek MYSQL_FIELD, length) pField
            -- MYSQL_FIELD.flags is a C `unsigned int`. Peeking it as a
            -- Haskell Int would read 8 bytes on 64-bit targets, spilling
            -- into the next struct member (and testing the wrong bit on
            -- big-endian), so it is read at its real width.
            flags         <- (#peek MYSQL_FIELD, flags) pField :: IO CUInt
            decimalDigits <- (#peek MYSQL_FIELD, decimals) pField
            let sqlType  = mkSqlType dataType columnSize decimalDigits
                nullable = (flags .&. (#const NOT_NULL_FLAG)) == 0
            defs <- getFieldDefs pRes
            return ((name, sqlType, nullable) : defs)
-- | Advance a result set to its next row.
--
-- On success, the new row pointer and its column-length array are stored
-- in @currRow@ (where 'getColValue' reads them) and the result is 'True'.
-- When @mysql_fetch_row@ returns null, @(nullPtr, nullPtr)@ is stored
-- and the result is 'False'. A null @pRes@ yields 'False' without
-- touching @currRow@.
--
-- NOTE(review): a null row is not distinguished from an error here (no
-- connection handle is available to query @mysql_errno@). Confirm whether
-- callers need to check for errors after the last row.
fetch :: MYSQL_RES
      -> MVar (MYSQL_ROW, MYSQL_LENGTHS)
      -> IO Bool
fetch pRes currRow
    | pRes == nullPtr = return False
    -- The previous row is discarded, so the old state is ignored
    -- explicitly instead of being bound and shadowed.
    | otherwise = modifyMVar currRow $ \_ -> do
        pRow <- mysql_fetch_row pRes
        if pRow == nullPtr
            -- No current row: skip the pointless mysql_fetch_lengths call.
            then return ((nullPtr, nullPtr), False)
            else do
                pLengths <- mysql_fetch_lengths pRes
                return ((pRow, pLengths), True)
-- | Default client-flag value, taken at build time (via hsc2hs) from the
-- C constant @MYSQL_DEFAULT_CONNECT_FLAGS@, which is visible through the
-- HsMySQL.h include. Presumably intended for the final 'CInt' argument of
-- 'mysql_real_connect' — confirm at the call site.
mysqlDefaultConnectFlags :: CInt
mysqlDefaultConnectFlags = #{const MYSQL_DEFAULT_CONNECT_FLAGS}
------------------------------------------------------------------------------
-- routines for handling exceptions
------------------------------------------------------------------------------
-- | Raise the connection's most recent MySQL error as a 'SqlError'.
--
-- Reads the error number (@mysql_errno@) and message (@mysql_error@) from
-- the connection handle and throws them as a 'SqlError'. The first field
-- is empty, presumably the SQL state. This action never returns normally,
-- hence its fully polymorphic result type.
handleSqlError :: MYSQL -> IO a
handleSqlError pMYSQL = do
    errno  <- mysql_errno pMYSQL
    errMsg <- mysql_error pMYSQL >>= peekCString
    -- throwIO rather than throw: the exception is raised exactly when this
    -- action is executed, not whenever some thunk happens to be forced.
    throwIO (SqlError "" (fromIntegral errno) errMsg)
|