File: Functions.hsc

package info (click to toggle)
haskell-hsql-mysql 1.8.1-4
  • links: PTS, VCS
  • area: main
  • in suites: wheezy
  • size: 84 kB
  • sloc: haskell: 78; ansic: 12; makefile: 3
file content (140 lines) | stat: -rw-r--r-- 4,635 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
module DB.HSQL.MySQL.Functions where

import Control.Concurrent.MVar(MVar,newMVar,modifyMVar,readMVar)
import Control.Exception (throw, throwIO)
import Control.Monad(when)
import Foreign((.&.),peekByteOff,nullPtr,peekElemOff)
import Foreign.C(CInt,CString,peekCString)

import Database.HSQL.Types(FieldDef,Statement(..),Connection(..),SqlError(..))
import DB.HSQL.MySQL.Type(MYSQL,MYSQL_RES,MYSQL_FIELD,MYSQL_ROW,MYSQL_LENGTHS
                         ,mkSqlType)

#include <HsMySQL.h>

#ifdef mingw32_HOST_OS
#let CALLCONV = "stdcall"
#else
#let CALLCONV = "ccall"
#endif

-- | FFI binding to the MySQL client function @mysql_init@, declared in
-- HsMySQL.h.  The calling convention is @stdcall@ on Windows (mingw32) and
-- @ccall@ elsewhere, as selected by the CALLCONV macro.
foreign import #{CALLCONV} "HsMySQL.h mysql_init"
  mysql_init :: MYSQL -> IO MYSQL

-- | FFI binding to @mysql_real_connect@.  The argument layout is the handle,
-- four C strings, an int, a C string and an int.  Per the MySQL C API docs
-- (verify) these are host, user, password, database, port, unix socket and
-- client flags.
--
-- NOTE(review): the MySQL C API declares the port as @unsigned int@ and the
-- client flags as @unsigned long@, but this binding passes both as 'CInt'.
-- On LP64 targets the flags argument is narrower than the C parameter --
-- confirm HsMySQL.h / the callers do not depend on the upper bits.
foreign import #{CALLCONV} "HsMySQL.h mysql_real_connect"
  mysql_real_connect :: MYSQL -> CString -> CString -> CString -> CString -> CInt -> CString -> CInt -> IO MYSQL

-- | FFI binding to @mysql_close@.
foreign import #{CALLCONV} "HsMySQL.h mysql_close"
  mysql_close :: MYSQL -> IO ()

-- | FFI binding to @mysql_errno@.  In this module a non-zero value is
-- treated as an error ('withStatement', 'handleSqlError').
foreign import #{CALLCONV} "HsMySQL.h mysql_errno"
  mysql_errno :: MYSQL -> IO CInt

-- | FFI binding to @mysql_error@.  The returned C string is copied into a
-- Haskell 'String' with 'peekCString' by 'handleSqlError'.
foreign import #{CALLCONV} "HsMySQL.h mysql_error"
  mysql_error :: MYSQL -> IO CString

-- | FFI binding to @mysql_query@; the 'CInt' result is the C status code.
foreign import #{CALLCONV} "HsMySQL.h mysql_query"
  mysql_query :: MYSQL -> CString -> IO CInt

-- | FFI binding to @mysql_use_result@.  Presumably its result handle is what
-- gets passed to 'withStatement' (confirm at call sites outside this chunk).
foreign import #{CALLCONV} "HsMySQL.h mysql_use_result"
  mysql_use_result :: MYSQL -> IO MYSQL_RES

-- | FFI binding to @mysql_fetch_field@.  'getFieldDefs' calls it repeatedly
-- and stops when it yields a null pointer.
foreign import #{CALLCONV} "HsMySQL.h mysql_fetch_field"
  mysql_fetch_field :: MYSQL_RES -> IO MYSQL_FIELD

-- | FFI binding to @mysql_free_result@; used as the close action of
-- statements built from a non-null result in 'withStatement'.
foreign import #{CALLCONV} "HsMySQL.h mysql_free_result"
  mysql_free_result :: MYSQL_RES -> IO ()

-- | FFI binding to @mysql_fetch_row@.  'fetch' treats a null pointer as
-- "no more rows".
foreign import #{CALLCONV} "HsMySQL.h mysql_fetch_row"
  mysql_fetch_row :: MYSQL_RES -> IO MYSQL_ROW

-- | FFI binding to @mysql_fetch_lengths@; 'getColValue' reads per-column
-- byte lengths from the array it returns.
foreign import #{CALLCONV} "HsMySQL.h mysql_fetch_lengths"
  mysql_fetch_lengths :: MYSQL_RES -> IO MYSQL_LENGTHS

-- | FFI binding to @mysql_list_tables@ (second argument: C string pattern,
-- per the MySQL C API -- verify).
foreign import #{CALLCONV} "HsMySQL.h mysql_list_tables"
  mysql_list_tables :: MYSQL -> CString -> IO MYSQL_RES

-- | FFI binding to @mysql_list_fields@ (string arguments: table name and
-- column pattern, per the MySQL C API -- verify).
foreign import #{CALLCONV} "HsMySQL.h mysql_list_fields"
  mysql_list_fields :: MYSQL -> CString -> CString -> IO MYSQL_RES

-- | FFI binding to @mysql_next_result@; the 'CInt' result is the C status
-- code.
foreign import #{CALLCONV} "HsMySQL.h mysql_next_result"
  mysql_next_result :: MYSQL -> IO CInt

-- | Build an HSQL 'Statement' around a MySQL result-set handle.
--
-- If @pRes@ is null, the connection's error number is checked first:
-- 'handleSqlError' raises when it is non-zero; otherwise the statement has no
-- fields and closing it does nothing.  If @pRes@ is non-null, its field
-- definitions are read up front and closing the statement calls
-- 'mysql_free_result' on it.  In both cases fetching and column access go
-- through 'fetch' and 'getColValue' over a fresh, initially-null row cell.
withStatement :: Connection -> MYSQL -> MYSQL_RES -> IO Statement
withStatement conn pMYSQL pRes = do
  rowCell   <- newMVar (nullPtr, nullPtr)
  closedRef <- newMVar False
  (closer, defs) <- describeResult
  return Statement
    { stmtConn   = conn
    , stmtClose  = closer
    , stmtFetch  = fetch pRes rowCell
    , stmtGetCol = getColValue rowCell
    , stmtFields = defs
    , stmtClosed = closedRef
    }
  where
    -- Pick the close action and field list for this result handle.
    describeResult
      | pRes == nullPtr = do
          code <- mysql_errno pMYSQL
          when (code /= 0) (handleSqlError pMYSQL)
          return (return (), [])
      | otherwise = do
          fieldDefs <- getFieldDefs pRes
          return (mysql_free_result pRes, fieldDefs)

-- | Decode column @colNumber@ of the current row with the callback @f@.
--
-- The current row and its length array are read from @currRow@ (as stored
-- by 'fetch').  @f@ receives the field definition, the raw column pointer
-- taken from the row array, and the column's byte length converted to 'Int'.
-- The column pointer is passed through unchanged, so @f@ must handle a null
-- pointer itself (the MySQL C API uses one for SQL NULL values).
--
-- Throws an 'SqlError' when there is no current row, i.e. before the first
-- successful 'fetch' or after 'fetch' has returned 'False', instead of
-- dereferencing a null pointer.
getColValue :: MVar (MYSQL_ROW, MYSQL_LENGTHS)
            -> Int
            -> FieldDef
            -> (FieldDef -> CString -> Int -> IO a)
            -> IO a
getColValue currRow colNumber fieldDef f = do
  (row, lengths) <- readMVar currRow
  -- The cell starts as (nullPtr, nullPtr) and 'fetch' stores a null row once
  -- the result set is exhausted; indexing either array would then read NULL.
  when (row == nullPtr || lengths == nullPtr) $
    throwIO (SqlError "" 0 "getColValue: no current row (fetch not called or returned False)")
  pValue <- peekElemOff row colNumber
  len <- fmap fromIntegral (peekElemOff lengths colNumber)
  f fieldDef pValue len

-- | Read the definitions of all remaining fields of a result set.
--
-- Fields are pulled with 'mysql_fetch_field' until it returns a null
-- pointer.  Each entry is @(name, sqlType, nullable)@:
--
-- * @name@ is copied from the C string in the struct's @name@ member;
-- * @sqlType@ is built by 'mkSqlType' from the @type@, @length@ and
--   @decimals@ members;
-- * @nullable@ is 'True' when @NOT_NULL_FLAG@ is clear in @flags@.
getFieldDefs :: MYSQL_RES -> IO [FieldDef]
getFieldDefs pRes = do
  pField <- mysql_fetch_field pRes
  if pField == nullPtr
    then return []
    else do
      name <- (#peek MYSQL_FIELD, name) pField >>= peekCString
      dataType <- (#peek MYSQL_FIELD, type) pField
      columnSize <- (#peek MYSQL_FIELD, length) pField
      -- @flags@ is a C @unsigned int@.  Peek it at C-int width ('CInt' has the
      -- same size); peeking a Haskell 'Int' reads 8 bytes on 64-bit targets
      -- and over-reads into the next struct member.
      flags <- (#peek MYSQL_FIELD, flags) pField :: IO CInt
      decimalDigits <- (#peek MYSQL_FIELD, decimals) pField
      let sqlType  = mkSqlType dataType columnSize decimalDigits
          nullable = (flags .&. (#const NOT_NULL_FLAG)) == 0
      defs <- getFieldDefs pRes
      return ((name, sqlType, nullable) : defs)

-- | Advance the result set to its next row.
--
-- On a null result handle this returns 'False' without touching @currRow@.
-- Otherwise it calls 'mysql_fetch_row' and stores the new row pointer and its
-- column-length array in @currRow@.  It returns 'True' when a row was
-- fetched and 'False' when 'mysql_fetch_row' returned null.
fetch :: MYSQL_RES
      -> MVar (MYSQL_ROW, MYSQL_LENGTHS)
      -> IO Bool
fetch pRes currRow
  | pRes == nullPtr = return False
  | otherwise = modifyMVar currRow $ \_ -> do
      -- The previous (row, lengths) pair is simply replaced.
      pRow <- mysql_fetch_row pRes
      -- With no row there are no lengths; store null so the pair stays
      -- consistent instead of asking the C library for them.
      pLengths <- if pRow == nullPtr
                    then return nullPtr
                    else mysql_fetch_lengths pRes
      return ((pRow, pLengths), pRow /= nullPtr)

-- | The C constant @MYSQL_DEFAULT_CONNECT_FLAGS@, evaluated by hsc2hs at
-- build time against HsMySQL.h.  Presumably meant as the client-flags
-- argument of 'mysql_real_connect' (confirm at call sites).
mysqlDefaultConnectFlags :: CInt
mysqlDefaultConnectFlags = (#{const MYSQL_DEFAULT_CONNECT_FLAGS})

------------------------------------------------------------------------------
-- routines for handling exceptions
------------------------------------------------------------------------------
-- | Raise the connection's current MySQL error as an 'SqlError'.
--
-- Reads the error number ('mysql_errno') and message ('mysql_error') from
-- the handle and throws @SqlError "" errno message@.  It never returns
-- normally, which is why the result type is polymorphic.
handleSqlError :: MYSQL -> IO a
handleSqlError pMYSQL = do
  errno <- mysql_errno pMYSQL
  errMsg <- mysql_error pMYSQL >>= peekCString
  -- throwIO (not throw): raise exactly when this action is executed, not
  -- whenever some lazily-held IO value happens to be forced.
  throwIO (SqlError "" (fromIntegral errno) errMsg)