# mssql.py
"""MSSQL backend, thru either pymssq, adodbapi or pyodbc interfaces.
* ``IDENTITY`` columns are supported by using SA ``schema.Sequence()``
objects. In other words::
Table('test', mss_engine,
Column('id', Integer, Sequence('blah',100,10), primary_key=True),
Column('name', String(20))
).create()
would yield::
CREATE TABLE test (
id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
name VARCHAR(20)
)
Note that the start & increment values for sequences are optional
and will default to 1,1.
* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for
``INSERT`` s)
* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT``
* ``select.limit`` implemented as ``SELECT TOP n``
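* Optional dialect behavior can be requested through ``create_engine()``
URL query arguments (illustrative URL; host and credentials are
placeholders)::
create_engine('mssql://user:pass@host/db?text_as_varchar=1&use_scope_identity=1')
The recognized arguments are ``auto_identity_insert``, ``query_timeout``,
``text_as_varchar`` and ``use_scope_identity``, each popped from the
query string by ``create_connect_args()``.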
Known issues / TODO:
* No support for more than one ``IDENTITY`` column per table
* pymssql has problems with binary and unicode data that this module
does **not** work around
"""
import sys, StringIO, string, types, re, datetime, random
import sqlalchemy.sql as sql
import sqlalchemy.engine as engine
import sqlalchemy.engine.default as default
import sqlalchemy.schema as schema
import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions
class MSNumeric(sqltypes.Numeric):
def convert_result_value(self, value, dialect):
return value
def convert_bind_param(self, value, dialect):
if value is None:
# Not sure that this exception is needed
return value
else:
return str(value)
def get_col_spec(self):
if self.precision is None:
return "NUMERIC"
else:
return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
class MSFloat(sqltypes.Float):
def get_col_spec(self):
return "FLOAT(%(precision)s)" % {'precision': self.precision}
def convert_bind_param(self, value, dialect):
"""By converting to string, we can use Decimal types round-trip."""
if not value is None:
return str(value)
return None
class MSInteger(sqltypes.Integer):
def get_col_spec(self):
return "INTEGER"
class MSBigInteger(MSInteger):
def get_col_spec(self):
return "BIGINT"
class MSTinyInteger(MSInteger):
def get_col_spec(self):
return "TINYINT"
class MSSmallInteger(MSInteger):
def get_col_spec(self):
return "SMALLINT"
class MSDateTime(sqltypes.DateTime):
def __init__(self, *a, **kw):
super(MSDateTime, self).__init__(False)
def get_col_spec(self):
return "DATETIME"
class MSDate(sqltypes.Date):
def __init__(self, *a, **kw):
super(MSDate, self).__init__(False)
def get_col_spec(self):
return "SMALLDATETIME"
class MSTime(sqltypes.Time):
__zero_date = datetime.date(1900, 1, 1)
def __init__(self, *a, **kw):
super(MSTime, self).__init__(False)
def get_col_spec(self):
return "DATETIME"
def convert_bind_param(self, value, dialect):
if isinstance(value, datetime.datetime):
value = datetime.datetime.combine(self.__zero_date, value.time())
elif isinstance(value, datetime.time):
value = datetime.datetime.combine(self.__zero_date, value)
return value
def convert_result_value(self, value, dialect):
if isinstance(value, datetime.datetime):
return value.time()
elif isinstance(value, datetime.date):
return datetime.time(0, 0, 0)
return value
class MSDateTime_adodbapi(MSDateTime):
def convert_result_value(self, value, dialect):
# adodbapi will return datetimes with empty time values as datetime.date() objects.
# Promote them back to full datetime.datetime()
if value and not hasattr(value, 'second'):
return datetime.datetime(value.year, value.month, value.day)
return value
class MSDateTime_pyodbc(MSDateTime):
def convert_bind_param(self, value, dialect):
if value and not hasattr(value, 'second'):
return datetime.datetime(value.year, value.month, value.day)
else:
return value
class MSDate_pyodbc(MSDate):
def convert_bind_param(self, value, dialect):
if value and not hasattr(value, 'second'):
return datetime.datetime(value.year, value.month, value.day)
else:
return value
def convert_result_value(self, value, dialect):
# pyodbc returns SMALLDATETIME values as datetime.datetime(); truncate them back to datetime.date()
if value and hasattr(value, 'second'):
return value.date()
else:
return value
class MSDate_pymssql(MSDate):
def convert_result_value(self, value, dialect):
# pymssql returns SMALLDATETIME values as datetime.datetime(); truncate them back to datetime.date()
if value and hasattr(value, 'second'):
return value.date()
else:
return value
class MSText(sqltypes.TEXT):
def get_col_spec(self):
if self.dialect.text_as_varchar:
return "VARCHAR(max)"
else:
return "TEXT"
class MSString(sqltypes.String):
def get_col_spec(self):
return "VARCHAR(%(length)s)" % {'length' : self.length}
class MSNVarchar(sqltypes.Unicode):
def get_col_spec(self):
if self.length:
return "NVARCHAR(%(length)s)" % {'length' : self.length}
elif self.dialect.text_as_varchar:
return "NVARCHAR(max)"
else:
return "NTEXT"
class AdoMSNVarchar(MSNVarchar):
"""overrides bindparam/result processing to not convert any unicode strings"""
def convert_bind_param(self, value, dialect):
return value
def convert_result_value(self, value, dialect):
return value
class MSChar(sqltypes.CHAR):
def get_col_spec(self):
return "CHAR(%(length)s)" % {'length' : self.length}
class MSNChar(sqltypes.NCHAR):
def get_col_spec(self):
return "NCHAR(%(length)s)" % {'length' : self.length}
class MSBinary(sqltypes.Binary):
def get_col_spec(self):
return "IMAGE"
class MSBoolean(sqltypes.Boolean):
def get_col_spec(self):
return "BIT"
def convert_result_value(self, value, dialect):
if value is None:
return None
return value and True or False
def convert_bind_param(self, value, dialect):
if value is True:
return 1
elif value is False:
return 0
elif value is None:
return None
else:
return value and True or False
class MSTimeStamp(sqltypes.TIMESTAMP):
def get_col_spec(self):
return "TIMESTAMP"
class MSMoney(sqltypes.TypeEngine):
def get_col_spec(self):
return "MONEY"
class MSSmallMoney(MSMoney):
def get_col_spec(self):
return "SMALLMONEY"
class MSUniqueIdentifier(sqltypes.TypeEngine):
def get_col_spec(self):
return "UNIQUEIDENTIFIER"
class MSVariant(sqltypes.TypeEngine):
def get_col_spec(self):
return "SQL_VARIANT"
def descriptor():
return {'name':'mssql',
'description':'MSSQL',
'arguments':[
('user',"Database Username",None),
('password',"Database Password",None),
('db',"Database Name",None),
('host',"Hostname", None),
]}
class MSSQLExecutionContext(default.DefaultExecutionContext):
def __init__(self, *args, **kwargs):
self.IINSERT = self.HASIDENT = False
super(MSSQLExecutionContext, self).__init__(*args, **kwargs)
def _has_implicit_sequence(self, column):
if column.primary_key and column.autoincrement:
if isinstance(column.type, sqltypes.Integer) and not column.foreign_key:
if column.default is None or (isinstance(column.default, schema.Sequence) and \
column.default.optional):
return True
return False
def pre_exec(self):
"""MS-SQL has a special mode for inserting non-NULL values
into IDENTITY columns.
Activate it if the feature is turned on and needed.
"""
if self.compiled.isinsert:
tbl = self.compiled.statement.table
if not hasattr(tbl, 'has_sequence'):
tbl.has_sequence = None
for column in tbl.c:
if getattr(column, 'sequence', False) or self._has_implicit_sequence(column):
tbl.has_sequence = column
break
self.HASIDENT = bool(tbl.has_sequence)
if self.dialect.auto_identity_insert and self.HASIDENT:
if isinstance(self.compiled_parameters, list):
self.IINSERT = tbl.has_sequence.key in self.compiled_parameters[0]
else:
self.IINSERT = tbl.has_sequence.key in self.compiled_parameters
else:
self.IINSERT = False
if self.IINSERT:
self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.dialect.preparer().format_table(self.compiled.statement.table))
super(MSSQLExecutionContext, self).pre_exec()
def post_exec(self):
"""Turn off the INDENTITY_INSERT mode if it's been activated,
and fetch recently inserted IDENTIFY values (works only for
one column).
"""
if self.compiled.isinsert:
if self.IINSERT:
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.preparer().format_table(self.compiled.statement.table))
self.IINSERT = False
elif self.HASIDENT:
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
if self.dialect.use_scope_identity:
self.cursor.execute("SELECT scope_identity() AS lastrowid")
else:
self.cursor.execute("SELECT @@identity AS lastrowid")
row = self.cursor.fetchone()
self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
# print "LAST ROW ID", self._last_inserted_ids
self.HASIDENT = False
super(MSSQLExecutionContext, self).post_exec()
class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
def pre_exec(self):
"""execute "set nocount on" on all connections, as a partial
workaround for multiple result set issues."""
if not getattr(self.connection, 'pyodbc_done_nocount', False):
self.connection.execute('SET nocount ON')
self.connection.pyodbc_done_nocount = True
super(MSSQLExecutionContext_pyodbc, self).pre_exec()
# where appropriate, issue "select scope_identity()" in the same statement
if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity:
self.statement += "; select scope_identity()"
def post_exec(self):
if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity:
# do nothing - id was fetched in dialect.do_execute()
self.HASIDENT = False
else:
super(MSSQLExecutionContext_pyodbc, self).post_exec()
class MSSQLDialect(ansisql.ANSIDialect):
colspecs = {
sqltypes.Unicode : MSNVarchar,
sqltypes.Integer : MSInteger,
sqltypes.Smallinteger: MSSmallInteger,
sqltypes.Numeric : MSNumeric,
sqltypes.Float : MSFloat,
sqltypes.DateTime : MSDateTime,
sqltypes.Date : MSDate,
sqltypes.Time : MSTime,
sqltypes.String : MSString,
sqltypes.Binary : MSBinary,
sqltypes.Boolean : MSBoolean,
sqltypes.TEXT : MSText,
sqltypes.CHAR: MSChar,
sqltypes.NCHAR: MSNChar,
sqltypes.TIMESTAMP: MSTimeStamp,
}
ischema_names = {
'int' : MSInteger,
'bigint': MSBigInteger,
'smallint' : MSSmallInteger,
'tinyint' : MSTinyInteger,
'varchar' : MSString,
'nvarchar' : MSNVarchar,
'char' : MSChar,
'nchar' : MSNChar,
'text' : MSText,
'ntext' : MSText,
'decimal' : MSNumeric,
'numeric' : MSNumeric,
'float' : MSFloat,
'datetime' : MSDateTime,
'smalldatetime' : MSDate,
'binary' : MSBinary,
'varbinary' : MSBinary,
'bit': MSBoolean,
'real' : MSFloat,
'image' : MSBinary,
'timestamp': MSTimeStamp,
'money': MSMoney,
'smallmoney': MSSmallMoney,
'uniqueidentifier': MSUniqueIdentifier,
'sql_variant': MSVariant,
}
def __new__(cls, dbapi=None, *args, **kwargs):
if cls != MSSQLDialect:
return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs)
if dbapi:
dialect = dialect_mapping.get(dbapi.__name__)
return dialect(*args, **kwargs)
else:
return object.__new__(cls, *args, **kwargs)
def __init__(self, auto_identity_insert=True, **params):
super(MSSQLDialect, self).__init__(**params)
self.auto_identity_insert = auto_identity_insert
self.text_as_varchar = False
self.use_scope_identity = False
self.set_default_schema_name("dbo")
def dbapi(cls, module_name=None):
if module_name:
try:
dialect_cls = dialect_mapping[module_name]
return dialect_cls.import_dbapi()
except KeyError:
raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbapi, pymssql or pyodbc)" % module_name)
else:
for dialect_cls in [MSSQLDialect_pyodbc, MSSQLDialect_pymssql, MSSQLDialect_adodbapi]:
try:
return dialect_cls.import_dbapi()
except ImportError, e:
pass
else:
raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi')
dbapi = classmethod(dbapi)
def create_connect_args(self, url):
opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
opts.update(url.query)
if opts.has_key('auto_identity_insert'):
self.auto_identity_insert = bool(int(opts.pop('auto_identity_insert')))
if opts.has_key('query_timeout'):
self.query_timeout = int(opts.pop('query_timeout'))
if opts.has_key('text_as_varchar'):
self.text_as_varchar = bool(int(opts.pop('text_as_varchar')))
if opts.has_key('use_scope_identity'):
self.use_scope_identity = bool(int(opts.pop('use_scope_identity')))
return self.make_connect_string(opts)
def create_execution_context(self, *args, **kwargs):
return MSSQLExecutionContext(self, *args, **kwargs)
def type_descriptor(self, typeobj):
newobj = sqltypes.adapt_type(typeobj, self.colspecs)
# Some types need to know about the dialect
if isinstance(newobj, (MSText, MSNVarchar)):
newobj.dialect = self
return newobj
def last_inserted_ids(self):
return self.context.last_inserted_ids
# this is only implemented in the dbapi-specific subclasses
def supports_sane_rowcount(self):
raise NotImplementedError()
def compiler(self, statement, bindparams, **kwargs):
return MSSQLCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
return MSSQLSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
return MSSQLSchemaDropper(self, *args, **kwargs)
def defaultrunner(self, connection, **kwargs):
return MSSQLDefaultRunner(connection, **kwargs)
def preparer(self):
return MSSQLIdentifierPreparer(self)
def get_default_schema_name(self, connection):
return self.schema_name
def set_default_schema_name(self, schema_name):
self.schema_name = schema_name
def last_inserted_ids(self):
return self.context.last_inserted_ids
def do_execute(self, cursor, statement, params, **kwargs):
if params == {}:
params = ()
super(MSSQLDialect, self).do_execute(cursor, statement, params, **kwargs)
def _execute(self, c, statement, parameters):
try:
if parameters == {}:
parameters = ()
c.execute(statement, parameters)
self.context.rowcount = c.rowcount
c.DBPROP_COMMITPRESERVE = "Y"
except Exception, e:
raise exceptions.SQLError(statement, parameters, e)
def raw_connection(self, connection):
"""Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes"""
try:
# TODO: probably want to move this to individual dialect subclasses to
# save on the exception throw + simplify
return connection.connection.__dict__['_pymssqlCnx__cnx']
except:
return connection.connection.adoConn
def uppercase_table(self, t):
# convert all names to uppercase -- fixes refs to INFORMATION_SCHEMA for case-sensitive DBs, and won't matter for case-insensitive ones
t.name = t.name.upper()
if t.schema:
t.schema = t.schema.upper()
for c in t.columns:
c.name = c.name.upper()
return t
def has_table(self, connection, tablename, schema=None):
import sqlalchemy.databases.information_schema as ischema
current_schema = schema or self.get_default_schema_name(connection)
columns = self.uppercase_table(ischema.columns)
s = sql.select([columns],
current_schema
and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema)
or columns.c.table_name==tablename,
)
c = connection.execute(s)
row = c.fetchone()
return row is not None
def reflecttable(self, connection, table):
import sqlalchemy.databases.information_schema as ischema
# Get base columns
if table.schema is not None:
current_schema = table.schema
else:
current_schema = self.get_default_schema_name(connection)
columns = self.uppercase_table(ischema.columns)
s = sql.select([columns],
current_schema
and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema)
or columns.c.table_name==table.name,
order_by=[columns.c.ordinal_position])
c = connection.execute(s)
found_table = False
while True:
row = c.fetchone()
if row is None:
break
found_table = True
(name, type, nullable, charlen, numericprec, numericscale, default) = (
row[columns.c.column_name],
row[columns.c.data_type],
row[columns.c.is_nullable] == 'YES',
row[columns.c.character_maximum_length],
row[columns.c.numeric_precision],
row[columns.c.numeric_scale],
row[columns.c.column_default]
)
args = []
for a in (charlen, numericprec, numericscale):
if a is not None:
args.append(a)
coltype = self.ischema_names[type]
if coltype == MSString and charlen == -1:
coltype = MSText()
else:
if coltype == MSNVarchar and charlen == -1:
charlen = None
coltype = coltype(*args)
colargs= []
if default is not None:
colargs.append(schema.PassiveDefault(sql.text(default)))
table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
if not found_table:
raise exceptions.NoSuchTableError(table.name)
# We also run an sp_columns to check for identity columns:
cursor = connection.execute("sp_columns " + self.preparer().format_table(table))
ic = None
while True:
row = cursor.fetchone()
if row is None:
break
col_name, type_name = row[3], row[5]
if type_name.endswith("identity"):
ic = table.c[col_name]
# set up a pseudo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute
ic.sequence = schema.Sequence(ic.name + '_identity')
# MSSQL: only one identity per table allowed
cursor.close()
break
if not ic is None:
try:
cursor = connection.execute("select ident_seed(?), ident_incr(?)", table.fullname, table.fullname)
row = cursor.fetchone()
cursor.close()
if not row is None:
ic.sequence.start=int(row[0])
ic.sequence.increment=int(row[1])
except:
# ignoring it, works just like before
pass
# Add constraints
RR = self.uppercase_table(ischema.ref_constraints) #information_schema.referential_constraints
TC = self.uppercase_table(ischema.constraints) #information_schema.table_constraints
C = self.uppercase_table(ischema.pg_key_constraints).alias('C') #information_schema.constraint_column_usage: the constrained column
R = self.uppercase_table(ischema.pg_key_constraints).alias('R') #information_schema.constraint_column_usage: the referenced column
# Primary key constraints
s = sql.select([C.c.column_name, TC.c.constraint_type], sql.and_(TC.c.constraint_name == C.c.constraint_name,
C.c.table_name == table.name))
c = connection.execute(s)
for row in c:
if 'PRIMARY' in row[TC.c.constraint_type.name]:
table.primary_key.add(table.c[row[0]])
# Foreign key constraints
s = sql.select([C.c.column_name,
R.c.table_schema, R.c.table_name, R.c.column_name,
RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule],
sql.and_(C.c.table_name == table.name,
C.c.table_schema == current_schema,
C.c.constraint_name == RR.c.constraint_name,
R.c.constraint_name == RR.c.unique_constraint_name,
C.c.ordinal_position == R.c.ordinal_position
),
order_by = [RR.c.constraint_name, R.c.ordinal_position])
rows = connection.execute(s).fetchall()
# group rows by constraint ID, to handle multi-column FKs
fknm, scols, rcols = (None, [], [])
for r in rows:
scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r
if rfknm != fknm:
if fknm:
table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm))
fknm, scols, rcols = (rfknm, [], [])
if (not scol in scols): scols.append(scol)
if (not (rschema, rtbl, rcol) in rcols): rcols.append((rschema, rtbl, rcol))
if fknm and scols:
table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm))
class MSSQLDialect_pymssql(MSSQLDialect):
def import_dbapi(cls):
import pymssql as module
# pymssql doesn't have a Binary method; we use str
# TODO: monkeypatching here is less than ideal
module.Binary = lambda st: str(st)
return module
import_dbapi = classmethod(import_dbapi)
colspecs = MSSQLDialect.colspecs.copy()
colspecs[sqltypes.Date] = MSDate_pymssql
ischema_names = MSSQLDialect.ischema_names.copy()
ischema_names['smalldatetime'] = MSDate_pymssql
def __init__(self, **params):
super(MSSQLDialect_pymssql, self).__init__(**params)
self.use_scope_identity = True
def supports_sane_rowcount(self):
return False
def max_identifier_length(self):
return 30
def do_rollback(self, connection):
# pymssql throws an error on repeated rollbacks. Ignore it.
# TODO: this is normal behavior for most DBs. are we sure we want to ignore it ?
try:
connection.rollback()
except:
pass
def create_connect_args(self, url):
r = super(MSSQLDialect_pymssql, self).create_connect_args(url)
if hasattr(self, 'query_timeout'):
self.dbapi._mssql.set_query_timeout(self.query_timeout)
return r
def make_connect_string(self, keys):
if keys.get('port'):
# pymssql expects port as host:port, not a separate arg
keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
del keys['port']
return [[], keys]
def is_disconnect(self, e):
return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)
## This code is leftover from the initial implementation, for reference
## def do_begin(self, connection):
## """implementations might want to put logic here for turning autocommit on/off, etc."""
## pass
## def do_rollback(self, connection):
## """implementations might want to put logic here for turning autocommit on/off, etc."""
## try:
## # connection.rollback() for pymssql failed sometimes--the begin tran doesn't show up
## # this is a workaround that seems to handle it.
## r = self.raw_connection(connection)
## r.query("if @@trancount > 0 rollback tran")
## r.fetch_array()
## r.query("begin tran")
## r.fetch_array()
## except:
## pass
## def do_commit(self, connection):
## """implementations might want to put logic here for turning autocommit on/off, etc.
## do_commit is set for pymssql connections--ADO seems to handle transactions without any issue
## """
## # ADO Uses Implicit Transactions.
## # This is very pymssql specific. We use this instead of its commit, because it hangs on failed rollbacks.
## # By using the "if" we don't assume an open transaction--much better.
## r = self.raw_connection(connection)
## r.query("if @@trancount > 0 commit tran")
## r.fetch_array()
## r.query("begin tran")
## r.fetch_array()
class MSSQLDialect_pyodbc(MSSQLDialect):
def __init__(self, **params):
super(MSSQLDialect_pyodbc, self).__init__(**params)
# whether use_scope_identity will work depends on the version of pyodbc
try:
import pyodbc
self.use_scope_identity = hasattr(pyodbc.Cursor, 'nextset')
except:
pass
def import_dbapi(cls):
import pyodbc as module
return module
import_dbapi = classmethod(import_dbapi)
colspecs = MSSQLDialect.colspecs.copy()
colspecs[sqltypes.Unicode] = AdoMSNVarchar
colspecs[sqltypes.Date] = MSDate_pyodbc
colspecs[sqltypes.DateTime] = MSDateTime_pyodbc
ischema_names = MSSQLDialect.ischema_names.copy()
ischema_names['nvarchar'] = AdoMSNVarchar
ischema_names['smalldatetime'] = MSDate_pyodbc
ischema_names['datetime'] = MSDateTime_pyodbc
def supports_sane_rowcount(self):
return False
def supports_unicode_statements(self):
"""indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
# PyODBC unicode is broken on UCS-4 builds
return sys.maxunicode == 65535
def make_connect_string(self, keys):
if 'dsn' in keys:
connectors = ['dsn=%s' % keys['dsn']]
else:
connectors = ["Driver={SQL Server}"]
if 'port' in keys:
connectors.append('Server=%s,%d' % (keys.get('host'), keys.get('port')))
else:
connectors.append('Server=%s' % keys.get('host'))
connectors.append("Database=%s" % keys.get("database"))
user = keys.get("user")
if user:
connectors.append("UID=%s" % user)
connectors.append("PWD=%s" % keys.get("password", ""))
else:
connectors.append ("TrustedConnection=Yes")
return [[";".join (connectors)], {}]
def is_disconnect(self, e):
return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e)
def create_execution_context(self, *args, **kwargs):
return MSSQLExecutionContext_pyodbc(self, *args, **kwargs)
def do_execute(self, cursor, statement, parameters, context=None, **kwargs):
super(MSSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs)
if context and context.HASIDENT and (not context.IINSERT) and context.dialect.use_scope_identity:
import pyodbc
# fetch the last inserted id from the manipulated statement (pre_exec).
try:
row = cursor.fetchone()
except pyodbc.Error, e:
# if nocount OFF fetchone throws an exception and we have to jump over
# the rowcount to the resultset
cursor.nextset()
row = cursor.fetchone()
context._last_inserted_ids = [int(row[0])]
class MSSQLDialect_adodbapi(MSSQLDialect):
def import_dbapi(cls):
import adodbapi as module
return module
import_dbapi = classmethod(import_dbapi)
colspecs = MSSQLDialect.colspecs.copy()
colspecs[sqltypes.Unicode] = AdoMSNVarchar
colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
ischema_names = MSSQLDialect.ischema_names.copy()
ischema_names['nvarchar'] = AdoMSNVarchar
ischema_names['datetime'] = MSDateTime_adodbapi
def supports_sane_rowcount(self):
return True
def supports_unicode_statements(self):
"""indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
return True
def make_connect_string(self, keys):
connectors = ["Provider=SQLOLEDB"]
if 'port' in keys:
connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
else:
connectors.append ("Data Source=%s" % keys.get("host"))
connectors.append ("Initial Catalog=%s" % keys.get("database"))
user = keys.get("user")
if user:
connectors.append("User Id=%s" % user)
connectors.append("Password=%s" % keys.get("password", ""))
else:
connectors.append("Integrated Security=SSPI")
return [[";".join (connectors)], {}]
def is_disconnect(self, e):
return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
dialect_mapping = {
'pymssql': MSSQLDialect_pymssql,
'pyodbc': MSSQLDialect_pyodbc,
'adodbapi': MSSQLDialect_adodbapi
}
class MSSQLCompiler(ansisql.ANSICompiler):
def __init__(self, dialect, statement, parameters, **kwargs):
super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
self.tablealiases = {}
def visit_select_precolumns(self, select):
""" MS-SQL puts TOP, it's version of LIMIT here """
s = select.distinct and "DISTINCT " or ""
if select.limit:
s += "TOP %s " % (select.limit,)
if select.offset:
raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset')
return s
def limit_clause(self, select):
# Limit in mssql is after the select keyword
return ""
def visit_table(self, table):
# alias schema-qualified tables
if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table):
alias = table.alias()
self.tablealiases[table] = alias
self.traverse(alias)
self.froms[('alias', table)] = self.froms[table]
for c in alias.c:
self.traverse(c)
self.traverse(alias.oid_column)
self.tablealiases[alias] = self.froms[table]
self.froms[table] = self.froms[alias]
else:
super(MSSQLCompiler, self).visit_table(table)
def visit_alias(self, alias):
# translate for schema-qualified table aliases
if self.froms.has_key(('alias', alias.original)):
self.froms[alias] = self.froms[('alias', alias.original)] + " AS " + alias.name
self.strings[alias] = ""
else:
super(MSSQLCompiler, self).visit_alias(alias)
def visit_column(self, column):
# translate for schema-qualified table aliases
super(MSSQLCompiler, self).visit_column(column)
if column.table is not None and self.tablealiases.has_key(column.table):
self.strings[column] = \
self.strings[self.tablealiases[column.table].corresponding_column(column)]
def visit_binary(self, binary):
"""Move bind parameters to the right-hand side of an operator, where possible."""
if isinstance(binary.left, sql._BindParamClause) and binary.operator == '=':
binary.left, binary.right = binary.right, binary.left
super(MSSQLCompiler, self).visit_binary(binary)
def visit_select(self, select):
# label function calls, so they return a name in cursor.description
for i,c in enumerate(select._raw_columns):
if isinstance(c, sql._Function):
select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:])
super(MSSQLCompiler, self).visit_select(select)
function_rewrites = {'current_date': 'getdate',
'length': 'len',
}
def visit_function(self, func):
func.name = self.function_rewrites.get(func.name, func.name)
super(MSSQLCompiler, self).visit_function(func)
def for_update_clause(self, select):
# "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
return ''
def order_by_clause(self, select):
order_by = self.get_str(select.order_by_clause)
# MSSQL only allows ORDER BY in subqueries if there is a LIMIT
if order_by and (not select.is_subquery or select.limit):
return " ORDER BY " + order_by
else:
return ""
class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
column.autoincrement and isinstance(column.type, sqltypes.Integer) and not column.foreign_key:
if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
column.sequence = schema.Sequence(column.name + '_seq')
if not column.nullable:
colspec += " NOT NULL"
if hasattr(column, 'sequence'):
column.table.has_sequence = column
colspec += " IDENTITY(%s,%s)" % (column.sequence.start or 1, column.sequence.increment or 1)
else:
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
return colspec
class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX %s.%s" % (
self.preparer.quote_identifier(index.table.name),
self.preparer.quote_identifier(index.name)))
self.execute()
class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
# TODO: does ms-sql have standalone sequences ?
pass
class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
def __init__(self, dialect):
super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
def _escape_identifier(self, value):
#TODO: determine MSSQL's escaping rules
return value
def _fold_identifier_case(self, value):
#TODO: determine MSSQL's case folding rules
return value
dialect = MSSQLDialect