# sqlrow.py
#
# Copyright 2005 Wichert Akkerman <wichert@wiggy.net>

class MetaParent(type):
	"""Metaclass for SQL row classes.

	When a class using this metaclass is created it:

	  - binds a class attribute dbc, unless the class body defines one
	    itself, by looking for a module-level dbc first in the module
	    that defines the class and then in __main__;
	  - creates a property for every name in the class's own _columns
	    attribute that does not already exist on the class. Reading the
	    property returns self._sqldata[col]; assigning to it calls
	    self._set(col, value).
	"""

	def __new__(cls, className, bases, d):
		import sys

		instance=type.__new__(cls, className, bases, d)

		if "dbc" not in d:
			for mod in [ instance.__module__, "__main__" ]:
				# A class may name a module that is not (or no longer)
				# registered in sys.modules, e.g. when it is created
				# dynamically; skip it instead of failing with KeyError.
				module=sys.modules.get(mod)
				if module is not None and hasattr(module, "dbc"):
					instance.dbc=module.dbc
					break

		for col in d.get("_columns", []):
			# Never shadow an attribute or method the class already has.
			if hasattr(instance, col):
				continue
			# col=col binds the current column name; a plain closure
			# would see only the final value of the loop variable.
			setattr(instance, col, property(
				fset=lambda self,value,col=col: self._set(col, value),
				fget=lambda self,col=col: self._sqldata[col],
				doc="SQL column %s" % col))

		return instance


class SQLRow(object):
	"""SQL row.

	A SQL row represents the data stored in a single row in a SQL table
	(or view).

	Subclasses describe the table through class attributes that this
	class reads: _table (table name), _columns (all column names) and
	_keys (the columns that identify a row). dbc is the connection
	object whose execute(query, values, paramstyle) method is used for
	all queries.
	"""

	__metaclass__	= MetaParent

	def __init__(self, **kwargs):
		"""Create a row, optionally loading it from the database.

		Every column starts out as None; keyword arguments are then
		assigned as attributes. If the columns changed by those
		assignments are exactly the key columns, the remaining columns
		are loaded with retrieve().
		"""
		self._sqldata={}
		self._changed=[]
		self._stored=False

		for col in self._columns:
			self._sqldata[col]=None

		for (key,value) in kwargs.items():
			setattr(self, key, value)

		# A simple comparison does not work since that also
		# compares the order
		if len(self._keys)!=len(self._changed):
			return

		for i in self._keys:
			if i not in self._changed:
				return

		self.retrieve()


	def _set(self, attr, value):
		"""Set column attr to value and record it as changed.

		Assigning the value the column already holds is a no-op. An attr
		that is not a column raises KeyError.

		@raise NotImplementedError: attr is a key column of a row that is
		  already stored.
		"""
		if self._stored and attr in self._keys:
			raise NotImplementedError("Changing keys from stored rows is not implemented")
		if self._sqldata[attr]==value:
			return
		if attr not in self._changed:
			self._changed.append(attr)
		self._sqldata[attr]=value


	def _genwhere(self):
		"""Generate data for a WHERE clause to identify this object.

		This method generates data which can be used to generate
		a WHERE clause to uniquely identify this object in a table. It
		returns a tuple containing a string with the SQL command and a
		tuple with the data values. This can be fed to the execute
		method for a database connection using the format paramstyle.

		@return: (command,values) tuple
		@raise KeyError: a key column is still None
		"""

		cmd=" AND ".join([x+"=%s" for x in self._keys])
		values=[self._sqldata[x] for x in self._keys]
		if None in values:
			raise KeyError("Not all keys set")

		return (cmd,tuple(values))


	def retrieve(self):
		"""Load the non-key columns of this row from the database.

		The row is looked up by its key columns. On success any pending
		changes are discarded, the loaded values replace the non-key
		columns and the row is marked as stored.

		@raise KeyError: not all keys are set, no row matches, or more
		  than one row matches
		"""
		(query,values)=self._genwhere()
		nonkeys=[x for x in self._columns if x not in self._keys]

		# If every column is a key there is nothing to load, but the
		# query must still confirm that the row exists. Select a constant
		# so the statement is not "SELECT  FROM ...", which is invalid.
		c=self.dbc.execute("SELECT %s FROM %s WHERE %s;" %
				(",".join(nonkeys) or "1", self._table, query), values, "format")

		try:
			data=c.fetchall()

			if not data:
				raise KeyError("No matching row found")
			elif len(data)>1:
				raise KeyError("multiple rows match key")

			self._changed=[]

			for (key,value) in zip(nonkeys, data[0]):
				self._sqldata[key]=value

			self._stored=True
		finally:
			c.close()


	def _sql_insert(self):
		"""INSERT this row, sending only the columns that are not None.

		Columns left as None are omitted from the statement (presumably
		so the database can apply its own column defaults).
		"""
		cols=[x for x in self._columns if self._sqldata[x] is not None]
		values=tuple([self._sqldata[x] for x in cols])

		self.dbc.execute("INSERT INTO %s (%s) VALUES (%s);" %
				(self._table, ",".join(cols), ",".join(["%s"]*len(cols))),
				values, "format")
		self._stored=True


	def _sql_update(self):
		"""UPDATE the changed columns of this already-stored row.

		Every changed column is written, including columns set to None,
		which are passed as parameters that DB-API drivers send as NULL.
		Does nothing when no column has changed.
		"""
		if not self._changed:
			return
		cols=list(self._changed)
		values=tuple([self._sqldata[x] for x in cols])
		(wquery,wvalues)=self._genwhere()

		self.dbc.execute("UPDATE %s SET %s WHERE %s;" %
				(self._table, ",".join([x+"=%s" for x in cols]), wquery),
				values+wvalues, "format")


	def store(self):
		"""Write this row to the database.

		A row that is not stored yet is INSERTed; a stored row has its
		changed columns UPDATEd. Afterwards no changes are pending.
		"""
		if self._stored:
			self._sql_update()
		else:
			self._sql_insert()

		self._changed=[]


	# Compatibility layer for SQLObject-expecting code
	def __getitem__(self, attr):
		return self._sqldata[attr]
	def __setitem__(self, attr, value):
		self._set(attr, value)
	def has_key(self, key):
		return key in self._sqldata
	def commit(self):
		self.store()
		self.store()


