Factor out SQL retries
Test if sqlite is multithreading-safe and bail out if not. sqlite versions since at least 2008 are. But, as it still causes errors when 2 threads try to write to the same connection simultanously (We get a "cannot start transaction within a transaction" error), we protect writes with a per class, ie per-connection lock. Factor out the retrying to write when the database is locked. Signed-off-by: Sebastian Spaeth <Sebastian@SSpaeth.de> Signed-off-by: Nicolas Sebrecht <nicolas.s-dev@laposte.net>
This commit is contained in:
		 Sebastian Spaeth
					Sebastian Spaeth
				
			
				
					committed by
					
						 Nicolas Sebrecht
						Nicolas Sebrecht
					
				
			
			
				
	
			
			
			 Nicolas Sebrecht
						Nicolas Sebrecht
					
				
			
						parent
						
							af25c2779f
						
					
				
				
					commit
					0af9ef70a7
				
			| @@ -16,15 +16,29 @@ | ||||
| #    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301 USA | ||||
| import os.path | ||||
| import re | ||||
| from threading import Lock | ||||
| from LocalStatus import LocalStatusFolder, magicline | ||||
| try: | ||||
|     from pysqlite2 import dbapi2 as sqlite | ||||
|     import sqlite3 as sqlite | ||||
| except: | ||||
|     pass #fail only if needed later on, not on import | ||||
|  | ||||
| class LocalStatusSQLiteFolder(LocalStatusFolder): | ||||
|     """LocalStatus backend implemented with an SQLite database""" | ||||
|     #current version of the db format | ||||
|     """LocalStatus backend implemented with an SQLite database | ||||
|  | ||||
|     As python-sqlite currently does not allow to access the same sqlite | ||||
|     objects from various threads, we need to open get and close a db | ||||
|     connection and cursor for all operations. This is a big disadvantage | ||||
|     and we might want to investigate if we cannot hold an object open | ||||
|     for a thread somehow.""" | ||||
|     #though. According to sqlite docs, you need to commit() before | ||||
|     #the connection is closed or your changes will be lost!""" | ||||
|     #get db connection which autocommits | ||||
|     #connection = sqlite.connect(self.filename, isolation_level=None)         | ||||
|     #cursor = connection.cursor() | ||||
|     #return connection, cursor | ||||
|  | ||||
|     #current version of our db format | ||||
|     cur_version = 1 | ||||
|  | ||||
|     def __init__(self, root, name, repository, accountname, config): | ||||
| @@ -32,33 +46,68 @@ class LocalStatusSQLiteFolder(LocalStatusFolder): | ||||
|                                                       repository,  | ||||
|                                                       accountname, | ||||
|                                                       config)        | ||||
|         #Try to establish connection | ||||
|  | ||||
|         # dblock protects against concurrent writes in same connection | ||||
|         self._dblock = Lock() | ||||
|         #Try to establish connection, no need for threadsafety in __init__ | ||||
|         try: | ||||
|             self.connection = sqlite.connect(self.filename) | ||||
|             self.connection = sqlite.connect(self.filename, check_same_thread = False) | ||||
|         except NameError: | ||||
|             # sqlite import had failed | ||||
|             raise UserWarning('SQLite backend chosen, but no sqlite python ' | ||||
|                               'bindings available. Please install.') | ||||
|  | ||||
|         #Test if the db version is current enough and if the db is | ||||
|         #readable. | ||||
|         #Make sure sqlite is in multithreading SERIALIZE mode | ||||
|         assert sqlite.threadsafety == 1, 'Your sqlite is not multithreading safe.' | ||||
|  | ||||
|         #Test if db version is current enough and if db is readable. | ||||
|         try: | ||||
|             self.cursor = self.connection.cursor() | ||||
|             self.cursor.execute("SELECT value from metadata WHERE key='db_version'") | ||||
|             cursor = self.connection.execute("SELECT value from metadata WHERE key='db_version'") | ||||
|         except sqlite.DatabaseError: | ||||
|             #db file missing or corrupt, recreate it. | ||||
|             self.connection.close() | ||||
|             self.upgrade_db(0) | ||||
|         else: | ||||
|             # fetch db version and upgrade if needed | ||||
|             version = int(self.cursor.fetchone()[0]) | ||||
|             self.cursor.close() | ||||
|             version = int(cursor.fetchone()[0]) | ||||
|             if version < LocalStatusSQLiteFolder.cur_version: | ||||
|                 self.upgrade_db(version) | ||||
|             self.connection.close() | ||||
|  | ||||
|     def sql_write(self, sql, vars=None): | ||||
|         """execute some SQL retrying if the db was locked. | ||||
|  | ||||
|         :param sql: the SQL string passed to execute() :param args: the | ||||
|             variable values to `sql`. E.g. (1,2) or {uid:1, flags:'T'}. See | ||||
|             sqlite docs for possibilities. | ||||
|         :returns: the Cursor() or raises an Exception""" | ||||
|         success = False | ||||
|         while not success: | ||||
|             self._dblock.acquire() | ||||
|             try: | ||||
|                 if vars is None: | ||||
|                     cursor = self.connection.execute(sql) | ||||
|                 else: | ||||
|                     cursor = self.connection.execute(sql, vars) | ||||
|                 success = True | ||||
|                 self.connection.commit() | ||||
|             except sqlite.OperationalError, e: | ||||
|                 if e.args[0] == 'cannot commit - no transaction is active': | ||||
|                     pass | ||||
|                 elif e.args[0] == 'database is locked': | ||||
|                     self.ui.debug('', "Locked sqlite database, retrying.") | ||||
|                     success = False | ||||
|                 else: | ||||
|                     raise | ||||
|             finally: | ||||
|                 self._dblock.release() | ||||
|         return cursor | ||||
|  | ||||
|     def upgrade_db(self, from_ver): | ||||
|         """Upgrade the sqlite format from version 'from_ver' to current""" | ||||
|  | ||||
|         if hasattr(self, 'connection'): | ||||
|             self.connection.close() #close old connections first | ||||
|         self.connection = sqlite.connect(self.filename, check_same_thread = False) | ||||
|  | ||||
|         if from_ver == 0: | ||||
|             # from_ver==0: no db existent: plain text migration? | ||||
|             self.create_db() | ||||
| @@ -74,22 +123,18 @@ class LocalStatusSQLiteFolder(LocalStatusFolder): | ||||
|                                  (self.repository, self)) | ||||
|                 file = open(plaintextfilename, "rt") | ||||
|                 line = file.readline().strip() | ||||
|                 assert(line == magicline) | ||||
|                 connection = sqlite.connect(self.filename) | ||||
|                 cursor = connection.cursor() | ||||
|                 data = [] | ||||
|                 for line in file.xreadlines(): | ||||
|                     line = line.strip() | ||||
|                     uid, flags = line.split(':') | ||||
|                     uid, flags = line.strip().split(':') | ||||
|                     uid = long(uid) | ||||
|                     flags = [x for x in flags] | ||||
|                     flags.sort() | ||||
|                     flags = ''.join(flags) | ||||
|                     self.cursor.execute('INSERT INTO status (id,flags) VALUES (?,?)', | ||||
|                                         (uid,flags)) | ||||
|                 file.close() | ||||
|                     flags = list(flags) | ||||
|                     flags = ''.join(sorted(flags)) | ||||
|                     data.append((uid,flags)) | ||||
|                 self.connection.executemany('INSERT INTO status (id,flags) VALUES (?,?)', | ||||
|                                        data) | ||||
|                 self.connection.commit() | ||||
|                 file.close() | ||||
|                 os.rename(plaintextfilename, plaintextfilename + ".old") | ||||
|                 self.connection.close() | ||||
|         # Future version upgrades come here... | ||||
|         # if from_ver <= 1: ... #upgrade from 1 to 2 | ||||
|         # if from_ver <= 2: ... #upgrade from 2 to 3 | ||||
| @@ -98,12 +143,15 @@ class LocalStatusSQLiteFolder(LocalStatusFolder): | ||||
|         """Create a new db file""" | ||||
|         self.ui._msg('Creating new Local Status db for %s:%s' \ | ||||
|                          % (self.repository, self)) | ||||
|         connection = sqlite.connect(self.filename) | ||||
|         cursor = connection.cursor() | ||||
|         cursor.execute('CREATE TABLE metadata (key VARCHAR(50) PRIMARY KEY, value VARCHAR(128))') | ||||
|         cursor.execute("INSERT INTO metadata VALUES('db_version', '1')") | ||||
|         cursor.execute('CREATE TABLE status (id INTEGER PRIMARY KEY, flags VARCHAR(50))') | ||||
|         self.save() #commit if needed | ||||
|         if hasattr(self, 'connection'): | ||||
|             self.connection.close() #close old connections first | ||||
|         self.connection = sqlite.connect(self.filename, check_same_thread = False) | ||||
|         self.connection.executescript(""" | ||||
|         CREATE TABLE metadata (key VARCHAR(50) PRIMARY KEY, value VARCHAR(128)); | ||||
|         INSERT INTO metadata VALUES('db_version', '1'); | ||||
|         CREATE TABLE status (id INTEGER PRIMARY KEY, flags VARCHAR(50)); | ||||
|         """) | ||||
|         self.connection.commit() | ||||
|  | ||||
|     def isnewfolder(self): | ||||
|         # testing the existence of the db file won't work. It is created | ||||
| @@ -113,12 +161,12 @@ class LocalStatusSQLiteFolder(LocalStatusFolder): | ||||
|  | ||||
|     def deletemessagelist(self): | ||||
|         """delete all messages in the db""" | ||||
|         self.cursor.execute('DELETE FROM status') | ||||
|         self.sql_write('DELETE FROM status') | ||||
|  | ||||
|     def cachemessagelist(self): | ||||
|         self.messagelist = {} | ||||
|         self.cursor.execute('SELECT id,flags from status') | ||||
|         for row in self.cursor: | ||||
|         cursor = self.connection.execute('SELECT id,flags from status') | ||||
|         for row in cursor: | ||||
|                 flags = [x for x in row[1]] | ||||
|                 self.messagelist[row[0]] = {'uid': row[0], 'flags': flags} | ||||
|  | ||||
| @@ -126,24 +174,41 @@ class LocalStatusSQLiteFolder(LocalStatusFolder): | ||||
|         #Noop in this backend | ||||
|         pass | ||||
|  | ||||
|     def uidexists(self,uid): | ||||
|         self.cursor.execute('SELECT id FROM status WHERE id=:id',{'id': uid}) | ||||
|         for row in self.cursor: | ||||
|             if(row[0]==uid): | ||||
|                 return 1 | ||||
|         return 0 | ||||
|  | ||||
|     def getmessageuidlist(self): | ||||
|         self.cursor.execute('SELECT id from status') | ||||
|         r = [] | ||||
|         for row in self.cursor: | ||||
|             r.append(row[0]) | ||||
|         return r | ||||
|  | ||||
|     def getmessagecount(self): | ||||
|         self.cursor.execute('SELECT count(id) from status'); | ||||
|         row = self.cursor.fetchone() | ||||
|         return row[0] | ||||
|     # Following some pure SQLite functions, where we chose to use | ||||
|     # BaseFolder() methods instead. Doing those on the in-memory list is | ||||
|     # quicker anyway. If our db becomes so big that we don't want to | ||||
|     # maintain the in-memory list anymore, these might come in handy | ||||
|     # in the future though. | ||||
|     # | ||||
|     #def uidexists(self,uid): | ||||
|     #    conn, cursor = self.get_cursor() | ||||
|     #    with conn: | ||||
|     #        cursor.execute('SELECT id FROM status WHERE id=:id',{'id': uid}) | ||||
|     #        return cursor.fetchone() | ||||
|     # This would be the pure SQLite solution, use BaseFolder() method, | ||||
|     # to avoid threading with sqlite... | ||||
|     #def getmessageuidlist(self): | ||||
|     #    conn, cursor = self.get_cursor() | ||||
|     #    with conn: | ||||
|     #        cursor.execute('SELECT id from status') | ||||
|     #        r = [] | ||||
|     #        for row in cursor: | ||||
|     #            r.append(row[0]) | ||||
|     #        return r | ||||
|     #def getmessagecount(self): | ||||
|     #    conn, cursor = self.get_cursor() | ||||
|     #    with conn: | ||||
|     #        cursor.execute('SELECT count(id) from status'); | ||||
|     #        return cursor.fetchone()[0] | ||||
|     #def getmessageflags(self, uid): | ||||
|     #    conn, cursor = self.get_cursor() | ||||
|     #    with conn: | ||||
|     #        cursor.execute('SELECT flags FROM status WHERE id=:id', | ||||
|     #                        {'id': uid}) | ||||
|     #        for row in cursor: | ||||
|     #            flags = [x for x in row[0]] | ||||
|     #            return flags | ||||
|     #        assert False,"getmessageflags() called on non-existing message" | ||||
|  | ||||
|     def savemessage(self, uid, content, flags, rtime): | ||||
|         if uid < 0: | ||||
| @@ -155,38 +220,23 @@ class LocalStatusSQLiteFolder(LocalStatusFolder): | ||||
|             return uid | ||||
|  | ||||
|         self.messagelist[uid] = {'uid': uid, 'flags': flags, 'time': rtime} | ||||
|         flags.sort() | ||||
|         flags = ''.join(flags) | ||||
|         self.cursor.execute('INSERT INTO status (id,flags) VALUES (?,?)', | ||||
|         flags = ''.join(sorted(flags)) | ||||
|         self.sql_write('INSERT INTO status (id,flags) VALUES (?,?)', | ||||
|                          (uid,flags)) | ||||
|         self.save() | ||||
|         return uid | ||||
|  | ||||
|     def getmessageflags(self, uid): | ||||
|         self.cursor.execute('SELECT flags FROM status WHERE id=:id', | ||||
|                             {'id': uid}) | ||||
|         for row in self.cursor: | ||||
|             flags = [x for x in row[0]] | ||||
|             return flags | ||||
|         assert False,"getmessageflags() called on non-existing message" | ||||
|  | ||||
|     def getmessagetime(self, uid): | ||||
|         return self.messagelist[uid]['time'] | ||||
|  | ||||
|     def savemessageflags(self, uid, flags): | ||||
|         self.messagelist[uid] = {'uid': uid, 'flags': flags} | ||||
|         flags.sort() | ||||
|         flags = ''.join(flags) | ||||
|         self.cursor.execute('UPDATE status SET flags=? WHERE id=?',(flags,uid)) | ||||
|         self.save() | ||||
|         self.sql_write('UPDATE status SET flags=? WHERE id=?',(flags,uid)) | ||||
|  | ||||
|     def deletemessages(self, uidlist): | ||||
|         # Weed out ones not in self.messagelist | ||||
|         uidlist = [uid for uid in uidlist if uid in self.messagelist] | ||||
|         if not len(uidlist): | ||||
|             return | ||||
|  | ||||
|         for uid in uidlist: | ||||
|             del(self.messagelist[uid]) | ||||
|             #if self.uidexists(uid): | ||||
|             self.cursor.execute('DELETE FROM status WHERE id=:id', {'id': uid}) | ||||
|             self.sql_write('DELETE FROM status WHERE id=?', | ||||
|                              uidlist) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user