#!/usr/bin/env python # # odb.py # # Object Database Api # # Written by David Jeske , 2001/07. # Inspired by eGroups' sqldb.py originally written by Scott Hassan circa 1998. # # Copyright (C) 2001, by David Jeske and Neotonic # # Goals: # - a simple object-like interface to database data # - database independent (someday) # - relational-style "rigid schema definition" # - object style easy-access # # Example: # # import odb # # # define table # class AgentsTable(odb.Table): # def _defineRows(self): # self.d_addColumn("agent_id",kInteger,None,primarykey = 1,autoincrement = 1) # self.d_addColumn("login",kVarString,200,notnull=1) # self.d_addColumn("ticket_count",kIncInteger,None) # # if __name__ == "__main__": # # open database # ndb = MySQLdb.connect(host = 'localhost', # user='username', # passwd = 'password', # db='testdb') # db = Database(ndb) # tbl = AgentsTable(db,"agents") # # # create row # agent_row = tbl.newRow() # agent_row.login = "foo" # agent_row.save() # # # fetch row (must use primary key) # try: # get_row = tbl.fetchRow( ('agent_id', agent_row.agent_id) ) # except odb.eNoMatchingRows: # print "this is bad, we should have found the row" # # # fetch rows (can return empty list) # list_rows = tbl.fetchRows( ('login', "foo") ) # import string import sys, zlib from log import * import handle_error eNoSuchColumn = "odb.eNoSuchColumn" eNonUniqueMatchSpec = "odb.eNonUniqueMatchSpec" eNoMatchingRows = "odb.eNoMatchingRows" eInternalError = "odb.eInternalError" eInvalidMatchSpec = "odb.eInvalidMatchSpec" eInvalidData = "odb.eInvalidData" eUnsavedObjectLost = "odb.eUnsavedObjectLost" eDuplicateKey = "odb.eDuplicateKey" ##################################### # COLUMN TYPES ################ ###################### # typename ####################### size data means: # # # kInteger = "kInteger" # - kFixedString = "kFixedString" # size kVarString = "kVarString" # maxsize kBigString = "kBigString" # - kIncInteger = "kIncInteger" # - kDateTime = "kDateTime" kTimeStamp = "kTimeStamp" kReal = "kReal" DEBUG = 0 ############## # Database # # this will ultimately turn into a mostly abstract base class for # the DB adaptors for different database types.... # class Database: def __init__(self, db, debug=0): self._tables = {} self.db = db self._cursor = None self.compression_enabled = 0 self.debug = debug self.SQLError = None self.__defaultRowClass = self.defaultRowClass() self.__defaultRowListClass = self.defaultRowListClass() def defaultCursor(self): if self._cursor is None: self._cursor = self.db.cursor() return self._cursor def escape(self,str): raise "Unimplemented Error" def getDefaultRowClass(self): return self.__defaultRowClass def setDefaultRowClass(self, clss): self.__defaultRowClass = clss def getDefaultRowListClass(self): return self.__defaultRowListClass def setDefaultRowListClass(self, clss): self.__defaultRowListClass = clss def defaultRowClass(self): return Row def defaultRowListClass(self): # base type is list... return list def addTable(self, attrname, tblname, tblclass, rowClass = None, check = 0, create = 0, rowListClass = None): tbl = tblclass(self, tblname, rowClass=rowClass, check=check, create=create, rowListClass=rowListClass) self._tables[attrname] = tbl return tbl def close(self): for name, tbl in self._tables.items(): tbl.db = None self._tables = {} if self.db is not None: self.db.close() self.db = None def __getattr__(self, key): if key == "_tables": raise AttributeError, "odb.Database: not initialized properly, self._tables does not exist" try: table_dict = getattr(self,"_tables") return table_dict[key] except KeyError: raise AttributeError, "odb.Database: unknown attribute %s" % (key) def beginTransaction(self, cursor=None): if cursor is None: cursor = self.defaultCursor() dlog(DEV_UPDATE,"begin") cursor.execute("begin") def commitTransaction(self, cursor=None): if cursor is None: cursor = self.defaultCursor() dlog(DEV_UPDATE,"commit") cursor.execute("commit") def rollbackTransaction(self, cursor=None): if cursor is None: cursor = self.defaultCursor() dlog(DEV_UPDATE,"rollback") cursor.execute("rollback") ## ## schema creation code ## def createTables(self): tables = self.listTables() for attrname, tbl in self._tables.items(): tblname = tbl.getTableName() if tblname not in tables: print "table %s does not exist" % tblname tbl.createTable() else: invalidAppCols, invalidDBCols = tbl.checkTable() ## self.alterTableToMatch(tbl) def createIndices(self): indices = self.listIndices() for attrname, tbl in self._tables.items(): for indexName, (columns, unique) in tbl.getIndices().items(): if indexName in indices: continue tbl.createIndex(columns, indexName=indexName, unique=unique) def synchronizeSchema(self): tables = self.listTables() for attrname, tbl in self._tables.items(): tblname = tbl.getTableName() self.alterTableToMatch(tbl) def listTables(self, cursor=None): raise "Unimplemented Error" def listFieldsDict(self, table_name, cursor=None): raise "Unimplemented Error" def listFields(self, table_name, cursor=None): columns = self.listFieldsDict(table_name, cursor=cursor) return columns.keys() ########################################## # Table # class Table: def subclassinit(self): pass def __init__(self,database,table_name, rowClass = None, check = 0, create = 0, rowListClass = None): self.db = database self.__table_name = table_name if rowClass: self.__defaultRowClass = rowClass else: self.__defaultRowClass = database.getDefaultRowClass() if rowListClass: self.__defaultRowListClass = rowListClass else: self.__defaultRowListClass = database.getDefaultRowListClass() # get this stuff ready! self.__column_list = [] self.__vcolumn_list = [] self.__columns_locked = 0 self.__has_value_column = 0 self.__indices = {} # this will be used during init... self.__col_def_hash = None self.__vcol_def_hash = None self.__primary_key_list = None self.__relations_by_table = {} # ask the subclass to def his rows self._defineRows() # get ready to run! self.__lockColumnsAndInit() self.subclassinit() if create: self.createTable() if check: self.checkTable() def _colTypeToSQLType(self, colname, coltype, options): if coltype == kInteger: coltype = "integer" elif coltype == kFixedString: sz = options.get('size', None) if sz is None: coltype = 'char' else: coltype = "char(%s)" % sz elif coltype == kVarString: sz = options.get('size', None) if sz is None: coltype = 'varchar' else: coltype = "varchar(%s)" % sz elif coltype == kBigString: coltype = "text" elif coltype == kIncInteger: coltype = "integer" elif coltype == kDateTime: coltype = "datetime" elif coltype == kTimeStamp: coltype = "timestamp" elif coltype == kReal: coltype = "real" coldef = "%s %s" % (colname, coltype) if options.get('notnull', 0): coldef = coldef + " NOT NULL" if options.get('autoincrement', 0): coldef = coldef + " AUTO_INCREMENT" if options.get('unique', 0): coldef = coldef + " UNIQUE" # if options.get('primarykey', 0): coldef = coldef + " primary key" if options.get('default', None) is not None: coldef = coldef + " DEFAULT %s" % options.get('default') return coldef def getTableName(self): return self.__table_name def setTableName(self, tablename): self.__table_name = tablename def getIndices(self): return self.__indices def _createTableSQL(self): defs = [] for colname, coltype, options in self.__column_list: defs.append(self._colTypeToSQLType(colname, coltype, options)) defs = string.join(defs, ", ") primarykeys = self.getPrimaryKeyList() primarykey_str = "" if primarykeys: primarykey_str = ", PRIMARY KEY (" + string.join(primarykeys, ",") + ")" sql = "create table %s (%s %s)" % (self.__table_name, defs, primarykey_str) return sql def createTable(self, cursor=None): if cursor is None: cursor = self.db.defaultCursor() sql = self._createTableSQL() print "CREATING TABLE:", sql cursor.execute(sql) def dropTable(self, cursor=None): if cursor is None: cursor = self.db.defaultCursor() try: cursor.execute("drop table %s" % self.__table_name) # clean out the table except self.SQLError, reason: pass def renameTable(self, newTableName, cursor=None): if cursor is None: cursor = self.db.defaultCursor() try: cursor.execute("rename table %s to %s" % (self.__table_name, newTableName)) except sel.SQLError, reason: pass self.setTableName(newTableName) def getTableColumnsFromDB(self): return self.db.listFieldsDict(self.__table_name) def checkTable(self, warnflag=1): invalidDBCols = {} invalidAppCols = {} dbcolumns = self.getTableColumnsFromDB() for coldef in self.__column_list: colname = coldef[0] dbcoldef = dbcolumns.get(colname, None) if dbcoldef is None: invalidAppCols[colname] = 1 for colname, row in dbcolumns.items(): coldef = self.__col_def_hash.get(colname, None) if coldef is None: invalidDBCols[colname] = 1 if warnflag == 1: if invalidDBCols: print "----- WARNING ------------------------------------------" print " There are columns defined in the database schema that do" print " not match the application's schema." print " columns:", invalidDBCols.keys() print "--------------------------------------------------------" if invalidAppCols: print "----- WARNING ------------------------------------------" print " There are new columns defined in the application schema" print " that do not match the database's schema." print " columns:", invalidAppCols.keys() print "--------------------------------------------------------" return invalidAppCols, invalidDBCols def alterTableToMatch(self): raise "Unimplemented Error!" def addIndex(self, columns, indexName=None, unique=0): if indexName is None: indexName = self.getTableName() + "_index_" + string.join(columns, "_") self.__indices[indexName] = (columns, unique) def createIndex(self, columns, indexName=None, unique=0, cursor=None): if cursor is None: cursor = self.db.defaultCursor() cols = string.join(columns, ",") if indexName is None: indexName = self.getTableName() + "_index_" + string.join(columns, "_") uniquesql = "" if unique: uniquesql = " unique" sql = "create %s index %s on %s (%s)" % (uniquesql, indexName, self.getTableName(), cols) warn("creating index", sql) cursor.execute(sql) ## Column Definition def getColumnDef(self,column_name): try: return self.__col_def_hash[column_name] except KeyError: try: return self.__vcol_def_hash[column_name] except KeyError: raise eNoSuchColumn, "no column (%s) on table %s" % (column_name,self.__table_name) def getColumnList(self): return self.__column_list + self.__vcolumn_list def getAppColumnList(self): return self.__column_list def databaseSizeForData_ColumnName_(self,data,col_name): try: col_def = self.__col_def_hash[col_name] except KeyError: try: col_def = self.__vcol_def_hash[col_name] except KeyError: raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name) c_name,c_type,c_options = col_def if c_type == kBigString: if c_options.get("compress_ok",0) and self.db.compression_enabled: z_size = len(zlib.compress(data,9)) r_size = len(data) if z_size < r_size: return z_size else: return r_size else: return len(data) else: # really simplistic database size computation: try: a = data[0] return len(data) except: return 4 def columnType(self, col_name): try: col_def = self.__col_def_hash[col_name] except KeyError: try: col_def = self.__vcol_def_hash[col_name] except KeyError: raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name) c_name,c_type,c_options = col_def return c_type def convertDataForColumn(self,data,col_name): try: col_def = self.__col_def_hash[col_name] except KeyError: try: col_def = self.__vcol_def_hash[col_name] except KeyError: raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name) c_name,c_type,c_options = col_def if c_type == kIncInteger: raise eInvalidData, "invalid operation for column (%s:%s) on table (%s)" % (col_name,c_type,self.__table_name) if c_type == kInteger: try: if data is None: data = 0 else: return long(data) except (ValueError,TypeError): raise eInvalidData, "invalid data (%s) for col (%s:%s) on table (%s)" % (repr(data),col_name,c_type,self.__table_name) elif c_type == kReal: try: if data is None: data = 0.0 else: return float(data) except (ValueError,TypeError): raise eInvalidData, "invalid data (%s) for col (%s:%s) on table (%s)" % (repr(data), col_name,c_type,self.__table_name) else: if type(data) == type(long(0)): return "%d" % data else: return str(data) def getPrimaryKeyList(self): return self.__primary_key_list def hasValueColumn(self): return self.__has_value_column def hasColumn(self,name): return self.__col_def_hash.has_key(name) def hasVColumn(self,name): return self.__vcol_def_hash.has_key(name) def _defineRows(self): raise "can't instantiate base odb.Table type, make a subclass and override _defineRows()" def __lockColumnsAndInit(self): # add a 'odb_value column' before we lockdown the table def if self.__has_value_column: self.d_addColumn("odb_value",kBigText,default='') self.__columns_locked = 1 # walk column list and make lookup hashes, primary_key_list, etc.. primary_key_list = [] col_def_hash = {} for a_col in self.__column_list: name,type,options = a_col col_def_hash[name] = a_col if options.has_key('primarykey'): primary_key_list.append(name) self.__col_def_hash = col_def_hash self.__primary_key_list = primary_key_list # setup the value columns! if (not self.__has_value_column) and (len(self.__vcolumn_list) > 0): raise "can't define vcolumns on table without ValueColumn, call d_addValueColumn() in your _defineRows()" vcol_def_hash = {} for a_col in self.__vcolumn_list: name,type,size_data,options = a_col vcol_def_hash[name] = a_col self.__vcol_def_hash = vcol_def_hash def __checkColumnLock(self): if self.__columns_locked: raise "can't change column definitions outside of subclass' _defineRows() method!" # table definition methods, these are only available while inside the # subclass's _defineRows method # # Ex: # # import odb # class MyTable(odb.Table): # def _defineRows(self): # self.d_addColumn("id",kInteger,primarykey = 1,autoincrement = 1) # self.d_addColumn("name",kVarString,120) # self.d_addColumn("type",kInteger, # enum_values = { 0 : "alive", 1 : "dead" } def d_addColumn(self,col_name,ctype,size=None,primarykey = 0, notnull = 0,indexed=0, default=None,unique=0,autoincrement=0,safeupdate=0, enum_values = None, no_export = 0, relations=None,compress_ok=0,int_date=0): self.__checkColumnLock() options = {} options['default'] = default if primarykey: options['primarykey'] = primarykey if unique: options['unique'] = unique if indexed: options['indexed'] = indexed self.addIndex((col_name,)) if safeupdate: options['safeupdate'] = safeupdate if autoincrement: options['autoincrement'] = autoincrement if notnull: options['notnull'] = notnull if size: options['size'] = size if no_export: options['no_export'] = no_export if int_date: if ctype != kInteger: raise eInvalidData, "can't flag columns int_date unless they are kInteger" else: options['int_date'] = int_date if enum_values: options['enum_values'] = enum_values inv_enum_values = {} for k,v in enum_values.items(): if inv_enum_values.has_key(v): raise eInvalidData, "enum_values paramater must be a 1 to 1 mapping for Table(%s)" % self.__table_name else: inv_enum_values[v] = k options['inv_enum_values'] = inv_enum_values if relations: options['relations'] = relations for a_relation in relations: table, foreign_column_name = a_relation if self.__relations_by_table.has_key(table): raise eInvalidData, "multiple relations for the same foreign table are not yet supported" self.__relations_by_table[table] = (col_name,foreign_column_name) if compress_ok: if ctype == kBigString: options['compress_ok'] = 1 else: raise eInvalidData, "only kBigString fields can be compress_ok=1" self.__column_list.append( (col_name,ctype,options) ) def d_addValueColumn(self): self.__checkColumnLock() self.__has_value_column = 1 def d_addVColumn(self,col_name,type,size=None,default=None): self.__checkColumnLock() if (not self.__has_value_column): raise "can't define VColumns on table without ValueColumn, call d_addValueColumn() first" options = {} if default: options['default'] = default if size: options['size'] = size self.__vcolumn_list.append( (col_name,type,options) ) ##################### # _checkColMatchSpec(col_match_spec,should_match_unique_row = 0) # # raise an error if the col_match_spec contains invalid columns, or # (in the case of should_match_unique_row) if it does not fully specify # a unique row. # # NOTE: we don't currently support where clauses with value column fields! # def _fixColMatchSpec(self,col_match_spec, should_match_unique_row = 0): if type(col_match_spec) == type([]): if type(col_match_spec[0]) != type((0,)): raise eInvalidMatchSpec, "invalid types in match spec, use [(,)..] or (,)" elif type(col_match_spec) == type((0,)): col_match_spec = [ col_match_spec ] elif type(col_match_spec) == type(None): if should_match_unique_row: raise eNonUniqueMatchSpec, "can't use a non-unique match spec (%s) here" % col_match_spec else: return None else: raise eInvalidMatchSpec, "invalid types in match spec, use [(,)..] or (,)" if should_match_unique_row: unique_column_lists = [] # first the primary key list my_primary_key_list = [] for a_key in self.__primary_key_list: my_primary_key_list.append(a_key) # then other unique keys for a_col in self.__column_list: col_name,a_type,options = a_col if options.has_key('unique'): unique_column_lists.append( (col_name, [col_name]) ) unique_column_lists.append( ('primary_key', my_primary_key_list) ) new_col_match_spec = [] for a_col in col_match_spec: name,val = a_col # newname = string.lower(name) # what is this doing?? - jeske newname = name if not self.__col_def_hash.has_key(newname): raise eNoSuchColumn, "no such column in match spec: '%s'" % newname new_col_match_spec.append( (newname,val) ) if should_match_unique_row: for name,a_list in unique_column_lists: try: a_list.remove(newname) except ValueError: # it's okay if they specify too many columns! pass if should_match_unique_row: for name,a_list in unique_column_lists: if len(a_list) == 0: # we matched at least one unique colum spec! # log("using unique column (%s) for query %s" % (name,col_match_spec)) return new_col_match_spec raise eNonUniqueMatchSpec, "can't use a non-unique match spec (%s) here" % col_match_spec return new_col_match_spec def __buildWhereClause (self, col_match_spec,other_clauses = None): sql_where_list = [] if not col_match_spec is None: for m_col in col_match_spec: m_col_name,m_col_val = m_col c_name,c_type,c_options = self.__col_def_hash[m_col_name] if c_type in (kIncInteger, kInteger): try: m_col_val_long = long(m_col_val) except ValueError: raise ValueError, "invalid literal for long(%s) in table %s" % (repr(m_col_val),self.__table_name) sql_where_list.append("%s = %d" % (c_name, m_col_val_long)) elif c_type == kReal: try: m_col_val_float = float(m_col_val) except ValueError: raise ValueError, "invalid literal for float(%s) is table %s" % (repr(m_col_val), self.__table_name) sql_where_list.append("%s = %s" % (c_name, m_col_val_float)) else: sql_where_list.append("%s = '%s'" % (c_name, self.db.escape(m_col_val))) if other_clauses is None: pass elif type(other_clauses) == type(""): sql_where_list = sql_where_list + [other_clauses] elif type(other_clauses) == type([]): sql_where_list = sql_where_list + other_clauses else: raise eInvalidData, "unknown type of extra where clause: %s" % repr(other_clauses) return sql_where_list def __fetchRows(self,col_match_spec,cursor = None, where = None, order_by = None, limit_to = None, skip_to = None, join = None): if cursor is None: cursor = self.db.defaultCursor() # build column list sql_columns = [] for name,t,options in self.__column_list: sql_columns.append(name) # build join information joined_cols = [] joined_cols_hash = {} join_clauses = [] if not join is None: for a_table,retrieve_foreign_cols in join: try: my_col,foreign_col = self.__relations_by_table[a_table] for a_col in retrieve_foreign_cols: full_col_name = "%s.%s" % (my_col,a_col) joined_cols_hash[full_col_name] = 1 joined_cols.append(full_col_name) sql_columns.append( full_col_name ) join_clauses.append(" left join %s as %s on %s=%s " % (a_table,my_col,my_col,foreign_col)) except KeyError: eInvalidJoinSpec, "can't find table %s in defined relations for %s" % (a_table,self.__table_name) # start buildling SQL sql = "select %s from %s" % (string.join(sql_columns,","), self.__table_name) # add join clause if join_clauses: sql = sql + string.join(join_clauses," ") # add where clause elements sql_where_list = self.__buildWhereClause (col_match_spec,where) if sql_where_list: sql = sql + " where %s" % (string.join(sql_where_list," and ")) # add order by clause if order_by: sql = sql + " order by %s " % string.join(order_by,",") # add limit if not limit_to is None: if not skip_to is None: # log("limit,skip = %s,%s" % (limit_to,skip_to)) if self.db.db.__module__ == "sqlite.main": sql = sql + " limit %s offset %s " % (limit_to,skip_to) else: sql = sql + " limit %s, %s" % (skip_to,limit_to) else: sql = sql + " limit %s" % limit_to else: if not skip_to is None: raise eInvalidData, "can't specify skip_to without limit_to in MySQL" dlog(DEV_SELECT,sql) cursor.execute(sql) # create defaultRowListClass instance... return_rows = self.__defaultRowListClass() # should do fetchmany! all_rows = cursor.fetchall() for a_row in all_rows: data_dict = {} col_num = 0 # for a_col in cursor.description: # (name,type_code,display_size,internal_size,precision,scale,null_ok) = a_col for name in sql_columns: if self.__col_def_hash.has_key(name) or joined_cols_hash.has_key(name): # only include declared columns! if self.__col_def_hash.has_key(name): c_name,c_type,c_options = self.__col_def_hash[name] if c_type == kBigString and c_options.get("compress_ok",0) and a_row[col_num]: try: a_col_data = zlib.decompress(a_row[col_num]) except zlib.error: a_col_data = a_row[col_num] data_dict[name] = a_col_data elif c_type == kInteger or c_type == kIncInteger: value = a_row[col_num] if not value is None: data_dict[name] = int(value) else: data_dict[name] = None else: data_dict[name] = a_row[col_num] else: data_dict[name] = a_row[col_num] col_num = col_num + 1 newrowobj = self.__defaultRowClass(self,data_dict,joined_cols = joined_cols) return_rows.append(newrowobj) return return_rows def __deleteRow(self,a_row,cursor = None): if cursor is None: cursor = self.db.defaultCursor() # build the where clause! match_spec = a_row.getPKMatchSpec() sql_where_list = self.__buildWhereClause (match_spec) sql = "delete from %s where %s" % (self.__table_name, string.join(sql_where_list," and ")) dlog(DEV_UPDATE,sql) cursor.execute(sql) def __updateRowList(self,a_row_list,cursor = None): if cursor is None: cursor = self.db.defaultCursor() for a_row in a_row_list: update_list = a_row.changedList() # build the set list! sql_set_list = [] for a_change in update_list: col_name,col_val,col_inc_val = a_change c_name,c_type,c_options = self.__col_def_hash[col_name] if c_type != kIncInteger and col_val is None: sql_set_list.append("%s = NULL" % c_name) elif c_type == kIncInteger and col_inc_val is None: sql_set_list.append("%s = 0" % c_name) else: if c_type == kInteger: sql_set_list.append("%s = %d" % (c_name, long(col_val))) elif c_type == kIncInteger: sql_set_list.append("%s = %s + %d" % (c_name,c_name,long(col_inc_val))) elif c_type == kBigString and c_options.get("compress_ok",0) and self.db.compression_enabled: compressed_data = zlib.compress(col_val,9) if len(compressed_data) < len(col_val): sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(compressed_data))) else: sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(col_val))) elif c_type == kReal: sql_set_list.append("%s = %s" % (c_name,float(col_val))) else: sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(col_val))) # build the where clause! match_spec = a_row.getPKMatchSpec() sql_where_list = self.__buildWhereClause (match_spec) if sql_set_list: sql = "update %s set %s where %s" % (self.__table_name, string.join(sql_set_list,","), string.join(sql_where_list," and ")) dlog(DEV_UPDATE,sql) try: cursor.execute(sql) except Exception, reason: if string.find(str(reason), "Duplicate entry") != -1: raise eDuplicateKey, reason raise Exception, reason a_row.markClean() def __insertRow(self,a_row_obj,cursor = None,replace=0): if cursor is None: cursor = self.db.defaultCursor() sql_col_list = [] sql_data_list = [] auto_increment_column_name = None for a_col in self.__column_list: name,type,options = a_col try: data = a_row_obj[name] sql_col_list.append(name) if data is None: sql_data_list.append("NULL") else: if type == kInteger or type == kIncInteger: sql_data_list.append("%d" % data) elif type == kBigString and options.get("compress_ok",0) and self.db.compression_enabled: compressed_data = zlib.compress(data,9) if len(compressed_data) < len(data): sql_data_list.append("'%s'" % self.db.escape(compressed_data)) else: sql_data_list.append("'%s'" % self.db.escape(data)) elif type == kReal: sql_data_list.append("%s" % data) else: sql_data_list.append("'%s'" % self.db.escape(data)) except KeyError: if options.has_key("autoincrement"): if auto_increment_column_name: raise eInternalError, "two autoincrement columns (%s,%s) in table (%s)" % (auto_increment_column_name, name,self.__table_name) else: auto_increment_column_name = name if replace: sql = "replace into %s (%s) values (%s)" % (self.__table_name, string.join(sql_col_list,","), string.join(sql_data_list,",")) else: sql = "insert into %s (%s) values (%s)" % (self.__table_name, string.join(sql_col_list,","), string.join(sql_data_list,",")) dlog(DEV_UPDATE,sql) try: cursor.execute(sql) except Exception, reason: # sys.stderr.write("errror in statement: " + sql + "\n") log("error in statement: " + sql + "\n") if string.find(str(reason), "Duplicate entry") != -1: raise eDuplicateKey, reason raise Exception, reason if auto_increment_column_name: if cursor.__module__ == "sqlite.main": a_row_obj[auto_increment_column_name] = cursor.lastrowid elif cursor.__module__ == "MySQLdb.cursors": a_row_obj[auto_increment_column_name] = cursor.insert_id() else: # fallback to acting like mysql a_row_obj[auto_increment_column_name] = cursor.insert_id() # ---------------------------------------------------- # Helper methods for Rows... # ---------------------------------------------------- ##################### # r_deleteRow(a_row_obj,cursor = None) # # normally this is called from within the Row "delete()" method # but you can call it yourself if you want # def r_deleteRow(self,a_row_obj, cursor = None): curs = cursor self.__deleteRow(a_row_obj, cursor = curs) ##################### # r_updateRow(a_row_obj,cursor = None) # # normally this is called from within the Row "save()" method # but you can call it yourself if you want # def r_updateRow(self,a_row_obj, cursor = None): curs = cursor self.__updateRowList([a_row_obj], cursor = curs) ##################### # InsertRow(a_row_obj,cursor = None) # # normally this is called from within the Row "save()" method # but you can call it yourself if you want # def r_insertRow(self,a_row_obj, cursor = None,replace=0): curs = cursor self.__insertRow(a_row_obj, cursor = curs,replace=replace) # ---------------------------------------------------- # Public Methods # ---------------------------------------------------- ##################### # deleteRow(col_match_spec) # # The col_match_spec paramaters must include all primary key columns. # # Ex: # a_row = tbl.fetchRow( ("order_id", 1) ) # a_row = tbl.fetchRow( [ ("order_id", 1), ("enterTime", now) ] ) def deleteRow(self,col_match_spec, where=None): n_match_spec = self._fixColMatchSpec(col_match_spec) cursor = self.db.defaultCursor() # build sql where clause elements sql_where_list = self.__buildWhereClause (n_match_spec,where) if not sql_where_list: return sql = "delete from %s where %s" % (self.__table_name, string.join(sql_where_list," and ")) dlog(DEV_UPDATE,sql) cursor.execute(sql) ##################### # fetchRow(col_match_spec) # # The col_match_spec paramaters must include all primary key columns. # # Ex: # a_row = tbl.fetchRow( ("order_id", 1) ) # a_row = tbl.fetchRow( [ ("order_id", 1), ("enterTime", now) ] ) def fetchRow(self, col_match_spec, cursor = None): n_match_spec = self._fixColMatchSpec(col_match_spec, should_match_unique_row = 1) rows = self.__fetchRows(n_match_spec, cursor = cursor) if len(rows) == 0: raise eNoMatchingRows, "no row matches %s" % repr(n_match_spec) if len(rows) > 1: raise eInternalError, "unique where clause shouldn't return > 1 row" return rows[0] ##################### # fetchRows(col_match_spec) # # Ex: # a_row_list = tbl.fetchRows( ("order_id", 1) ) # a_row_list = tbl.fetchRows( [ ("order_id", 1), ("enterTime", now) ] ) def fetchRows(self, col_match_spec = None, cursor = None, where = None, order_by = None, limit_to = None, skip_to = None, join = None): n_match_spec = self._fixColMatchSpec(col_match_spec) return self.__fetchRows(n_match_spec, cursor = cursor, where = where, order_by = order_by, limit_to = limit_to, skip_to = skip_to, join = join) def fetchRowCount (self, col_match_spec = None, cursor = None, where = None): n_match_spec = self._fixColMatchSpec(col_match_spec) sql_where_list = self.__buildWhereClause (n_match_spec,where) sql = "select count(*) from %s" % self.__table_name if sql_where_list: sql = "%s where %s" % (sql,string.join(sql_where_list," and ")) if cursor is None: cursor = self.db.defaultCursor() dlog(DEV_SELECT,sql) cursor.execute(sql) try: count, = cursor.fetchone() except TypeError: count = 0 return count ##################### # fetchAllRows() # # Ex: # a_row_list = tbl.fetchRows( ("order_id", 1) ) # a_row_list = tbl.fetchRows( [ ("order_id", 1), ("enterTime", now) ] ) def fetchAllRows(self): try: return self.__fetchRows([]) except eNoMatchingRows: # else return empty list... return self.__defaultRowListClass() def newRow(self,replace=0): row = self.__defaultRowClass(self,None,create=1,replace=replace) for (cname, ctype, opts) in self.__column_list: if opts['default'] is not None and ctype is not kIncInteger: row[cname] = opts['default'] return row class Row: __instance_data_locked = 0 def subclassinit(self): pass def __init__(self,_table,data_dict,create=0,joined_cols = None,replace=0): self._inside_getattr = 0 # stop recursive __getattr__ self._table = _table self._should_insert = create or replace self._should_replace = replace self._rowInactive = None self._joinedRows = [] self.__pk_match_spec = None self.__vcoldata = {} self.__inc_coldata = {} self.__joined_cols_dict = {} for a_col in joined_cols or []: self.__joined_cols_dict[a_col] = 1 if create: self.__coldata = {} else: if type(data_dict) != type({}): raise eInternalError, "rowdict instantiate with bad data_dict" self.__coldata = data_dict self.__unpackVColumn() self.markClean() self.subclassinit() self.__instance_data_locked = 1 def joinRowData(self,another_row): self._joinedRows.append(another_row) def getPKMatchSpec(self): return self.__pk_match_spec def markClean(self): self.__vcolchanged = 0 self.__colchanged_dict = {} for key in self.__inc_coldata.keys(): self.__coldata[key] = self.__coldata.get(key, 0) + self.__inc_coldata[key] self.__inc_coldata = {} if not self._should_insert: # rebuild primary column match spec new_match_spec = [] for col_name in self._table.getPrimaryKeyList(): try: rdata = self[col_name] except KeyError: raise eInternalError, "must have primary key data filled in to save %s:Row(col:%s)" % (self._table.getTableName(),col_name) new_match_spec.append( (col_name, rdata) ) self.__pk_match_spec = new_match_spec def __unpackVColumn(self): if self._table.hasValueColumn(): pass def __packVColumn(self): if self._table.hasValueColumn(): pass ## ----- utility stuff ---------------------------------- def __del__(self): # check for unsaved changes changed_list = self.changedList() if len(changed_list): info = "unsaved Row for table (%s) lost, call discard() to avoid this error. Lost changes: %s\n" % (self._table.getTableName(), repr(changed_list)[:256]) if 0: raise eUnsavedObjectLost, info else: sys.stderr.write(info) def __repr__(self): return "Row from (%s): %s" % (self._table.getTableName(),repr(self.__coldata) + repr(self.__vcoldata)) ## ---- class emulation -------------------------------- def __getattr__(self,key): if self._inside_getattr: raise AttributeError, "recursively called __getattr__ (%s,%s)" % (key,self._table.getTableName()) try: self._inside_getattr = 1 try: return self[key] except KeyError: if self._table.hasColumn(key) or self._table.hasVColumn(key): return None else: raise AttributeError, "unknown field '%s' in Row(%s)" % (key,self._table.getTableName()) finally: self._inside_getattr = 0 def __setattr__(self,key,val): if not self.__instance_data_locked: self.__dict__[key] = val else: my_dict = self.__dict__ if my_dict.has_key(key): my_dict[key] = val else: # try and put it into the rowdata try: self[key] = val except KeyError, reason: raise AttributeError, reason ## ---- dict emulation --------------------------------- def __getitem__(self,key): self.checkRowActive() try: c_type = self._table.columnType(key) except eNoSuchColumn, reason: # Ugh, this sucks, we can't determine the type for a joined # row, so we just default to kVarString and let the code below # determine if this is a joined column or not c_type = kVarString if c_type == kIncInteger: c_data = self.__coldata.get(key, 0) if c_data is None: c_data = 0 i_data = self.__inc_coldata.get(key, 0) if i_data is None: i_data = 0 return c_data + i_data try: return self.__coldata[key] except KeyError: try: return self.__vcoldata[key] except KeyError: for a_joined_row in self._joinedRows: try: return a_joined_row[key] except KeyError: pass raise KeyError, "unknown column %s in %s" % (key,self) def __setitem__(self,key,data): self.checkRowActive() try: newdata = self._table.convertDataForColumn(data,key) except eNoSuchColumn, reason: raise KeyError, reason if self._table.hasColumn(key): self.__coldata[key] = newdata self.__colchanged_dict[key] = 1 elif self._table.hasVColumn(key): self.__vcoldata[key] = newdata self.__vcolchanged = 1 else: for a_joined_row in self._joinedRows: try: a_joined_row[key] = data return except KeyError: pass raise KeyError, "unknown column name %s" % key def __delitem__(self,key,data): self.checkRowActive() if self.table.hasVColumn(key): del self.__vcoldata[key] else: for a_joined_row in self._joinedRows: try: del a_joined_row[key] return except KeyError: pass raise KeyError, "unknown column name %s" % key def copyFrom(self,source): for name,t,options in self._table.getColumnList(): if not options.has_key("autoincrement"): self[name] = source[name] # make sure that .keys(), and .items() come out in a nice order! def keys(self): self.checkRowActive() key_list = [] for name,t,options in self._table.getColumnList(): key_list.append(name) for name in self.__joined_cols_dict.keys(): key_list.append(name) for a_joined_row in self._joinedRows: key_list = key_list + a_joined_row.keys() return key_list def items(self): self.checkRowActive() item_list = [] for name,t,options in self._table.getColumnList(): item_list.append( (name,self[name]) ) for name in self.__joined_cols_dict.keys(): item_list.append( (name,self[name]) ) for a_joined_row in self._joinedRows: item_list = item_list + a_joined_row.items() return item_list def values(elf): self.checkRowActive() value_list = self.__coldata.values() + self.__vcoldata.values() for a_joined_row in self._joinedRows: value_list = value_list + a_joined_row.values() return value_list def __len__(self): self.checkRowActive() my_len = len(self.__coldata) + len(self.__vcoldata) for a_joined_row in self._joinedRows: my_len = my_len + len(a_joined_row) return my_len def has_key(self,key): self.checkRowActive() if self.__coldata.has_key(key) or self.__vcoldata.has_key(key): return 1 else: for a_joined_row in self._joinedRows: if a_joined_row.has_key(key): return 1 return 0 def get(self,key,default = None): self.checkRowActive() if self.__coldata.has_key(key): return self.__coldata[key] elif self.__vcoldata.has_key(key): return self.__vcoldata[key] else: for a_joined_row in self._joinedRows: try: return a_joined_row.get(key,default) except eNoSuchColumn: pass if self._table.hasColumn(key): return default raise eNoSuchColumn, "no such column %s" % key def inc(self,key,count=1): self.checkRowActive() if self._table.hasColumn(key): try: self.__inc_coldata[key] = self.__inc_coldata[key] + count except KeyError: self.__inc_coldata[key] = count self.__colchanged_dict[key] = 1 else: raise AttributeError, "unknown field '%s' in Row(%s)" % (key,self._table.getTableName()) ## ---------------------------------- ## real interface def fillDefaults(self): for field_def in self._table.fieldList(): name,type,size,options = field_def if options.has_key("default"): self[name] = options["default"] ############### # changedList() # # returns a list of tuples for the columns which have changed # # changedList() -> [ ('name', 'fred'), ('age', 20) ] def changedList(self): if self.__vcolchanged: self.__packVColumn() changed_list = [] for a_col in self.__colchanged_dict.keys(): changed_list.append( (a_col,self.get(a_col,None),self.__inc_coldata.get(a_col,None)) ) return changed_list def discard(self): self.__coldata = None self.__vcoldata = None self.__colchanged_dict = {} self.__vcolchanged = 0 def delete(self,cursor = None): self.checkRowActive() fromTable = self._table curs = cursor fromTable.r_deleteRow(self,cursor=curs) self._rowInactive = "deleted" def save(self,cursor = None): toTable = self._table self.checkRowActive() if self._should_insert: toTable.r_insertRow(self,replace=self._should_replace) self._should_insert = 0 self._should_replace = 0 self.markClean() # rebuild the primary key list else: curs = cursor toTable.r_updateRow(self,cursor = curs) # the table will mark us clean! # self.markClean() def checkRowActive(self): if self._rowInactive: raise eInvalidData, "row is inactive: %s" % self._rowInactive def databaseSizeForColumn(self,key): return self._table.databaseSizeForData_ColumnName_(self[key],key) if __name__ == "__main__": print "run odb_test.py"