Factor out SQL retries
Test if sqlite is multithreading-safe and bail out if not. sqlite versions since at least 2008 are. But, as it still causes errors when 2 threads try to write to the same connection simultanously (We get a "cannot start transaction within a transaction" error), we protect writes with a per class, ie per-connection lock. Factor out the retrying to write when the database is locked. Signed-off-by: Sebastian Spaeth <Sebastian@SSpaeth.de> Signed-off-by: Nicolas Sebrecht <nicolas.s-dev@laposte.net>
This commit is contained in:
		
				
					committed by
					
						
						Nicolas Sebrecht
					
				
			
			
				
	
			
			
			
						parent
						
							af25c2779f
						
					
				
				
					commit
					0af9ef70a7
				
			@@ -16,15 +16,29 @@
 | 
			
		||||
#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301 USA
 | 
			
		||||
import os.path
 | 
			
		||||
import re
 | 
			
		||||
from threading import Lock
 | 
			
		||||
from LocalStatus import LocalStatusFolder, magicline
 | 
			
		||||
try:
 | 
			
		||||
    from pysqlite2 import dbapi2 as sqlite
 | 
			
		||||
    import sqlite3 as sqlite
 | 
			
		||||
except:
 | 
			
		||||
    pass #fail only if needed later on, not on import
 | 
			
		||||
 | 
			
		||||
class LocalStatusSQLiteFolder(LocalStatusFolder):
 | 
			
		||||
    """LocalStatus backend implemented with an SQLite database"""
 | 
			
		||||
    #current version of the db format
 | 
			
		||||
    """LocalStatus backend implemented with an SQLite database
 | 
			
		||||
 | 
			
		||||
    As python-sqlite currently does not allow to access the same sqlite
 | 
			
		||||
    objects from various threads, we need to open get and close a db
 | 
			
		||||
    connection and cursor for all operations. This is a big disadvantage
 | 
			
		||||
    and we might want to investigate if we cannot hold an object open
 | 
			
		||||
    for a thread somehow."""
 | 
			
		||||
    #though. According to sqlite docs, you need to commit() before
 | 
			
		||||
    #the connection is closed or your changes will be lost!"""
 | 
			
		||||
    #get db connection which autocommits
 | 
			
		||||
    #connection = sqlite.connect(self.filename, isolation_level=None)        
 | 
			
		||||
    #cursor = connection.cursor()
 | 
			
		||||
    #return connection, cursor
 | 
			
		||||
 | 
			
		||||
    #current version of our db format
 | 
			
		||||
    cur_version = 1
 | 
			
		||||
 | 
			
		||||
    def __init__(self, root, name, repository, accountname, config):
 | 
			
		||||
@@ -32,33 +46,68 @@ class LocalStatusSQLiteFolder(LocalStatusFolder):
 | 
			
		||||
                                                      repository, 
 | 
			
		||||
                                                      accountname,
 | 
			
		||||
                                                      config)       
 | 
			
		||||
        #Try to establish connection
 | 
			
		||||
 | 
			
		||||
        # dblock protects against concurrent writes in same connection
 | 
			
		||||
        self._dblock = Lock()
 | 
			
		||||
        #Try to establish connection, no need for threadsafety in __init__
 | 
			
		||||
        try:
 | 
			
		||||
            self.connection = sqlite.connect(self.filename)
 | 
			
		||||
            self.connection = sqlite.connect(self.filename, check_same_thread = False)
 | 
			
		||||
        except NameError:
 | 
			
		||||
            # sqlite import had failed
 | 
			
		||||
            raise UserWarning('SQLite backend chosen, but no sqlite python '
 | 
			
		||||
                              'bindings available. Please install.')
 | 
			
		||||
 | 
			
		||||
        #Test if the db version is current enough and if the db is
 | 
			
		||||
        #readable.
 | 
			
		||||
        #Make sure sqlite is in multithreading SERIALIZE mode
 | 
			
		||||
        assert sqlite.threadsafety == 1, 'Your sqlite is not multithreading safe.'
 | 
			
		||||
 | 
			
		||||
        #Test if db version is current enough and if db is readable.
 | 
			
		||||
        try:
 | 
			
		||||
            self.cursor = self.connection.cursor()
 | 
			
		||||
            self.cursor.execute("SELECT value from metadata WHERE key='db_version'")
 | 
			
		||||
            cursor = self.connection.execute("SELECT value from metadata WHERE key='db_version'")
 | 
			
		||||
        except sqlite.DatabaseError:
 | 
			
		||||
            #db file missing or corrupt, recreate it.
 | 
			
		||||
            self.connection.close()
 | 
			
		||||
            self.upgrade_db(0)
 | 
			
		||||
        else:
 | 
			
		||||
            # fetch db version and upgrade if needed
 | 
			
		||||
            version = int(self.cursor.fetchone()[0])
 | 
			
		||||
            self.cursor.close()
 | 
			
		||||
            version = int(cursor.fetchone()[0])
 | 
			
		||||
            if version < LocalStatusSQLiteFolder.cur_version:
 | 
			
		||||
                self.upgrade_db(version)
 | 
			
		||||
            self.connection.close()
 | 
			
		||||
 | 
			
		||||
    def sql_write(self, sql, vars=None):
 | 
			
		||||
        """execute some SQL retrying if the db was locked.
 | 
			
		||||
 | 
			
		||||
        :param sql: the SQL string passed to execute() :param args: the
 | 
			
		||||
            variable values to `sql`. E.g. (1,2) or {uid:1, flags:'T'}. See
 | 
			
		||||
            sqlite docs for possibilities.
 | 
			
		||||
        :returns: the Cursor() or raises an Exception"""
 | 
			
		||||
        success = False
 | 
			
		||||
        while not success:
 | 
			
		||||
            self._dblock.acquire()
 | 
			
		||||
            try:
 | 
			
		||||
                if vars is None:
 | 
			
		||||
                    cursor = self.connection.execute(sql)
 | 
			
		||||
                else:
 | 
			
		||||
                    cursor = self.connection.execute(sql, vars)
 | 
			
		||||
                success = True
 | 
			
		||||
                self.connection.commit()
 | 
			
		||||
            except sqlite.OperationalError, e:
 | 
			
		||||
                if e.args[0] == 'cannot commit - no transaction is active':
 | 
			
		||||
                    pass
 | 
			
		||||
                elif e.args[0] == 'database is locked':
 | 
			
		||||
                    self.ui.debug('', "Locked sqlite database, retrying.")
 | 
			
		||||
                    success = False
 | 
			
		||||
                else:
 | 
			
		||||
                    raise
 | 
			
		||||
            finally:
 | 
			
		||||
                self._dblock.release()
 | 
			
		||||
        return cursor
 | 
			
		||||
 | 
			
		||||
    def upgrade_db(self, from_ver):
 | 
			
		||||
        """Upgrade the sqlite format from version 'from_ver' to current"""
 | 
			
		||||
 | 
			
		||||
        if hasattr(self, 'connection'):
 | 
			
		||||
            self.connection.close() #close old connections first
 | 
			
		||||
        self.connection = sqlite.connect(self.filename, check_same_thread = False)
 | 
			
		||||
 | 
			
		||||
        if from_ver == 0:
 | 
			
		||||
            # from_ver==0: no db existent: plain text migration?
 | 
			
		||||
            self.create_db()
 | 
			
		||||
@@ -74,22 +123,18 @@ class LocalStatusSQLiteFolder(LocalStatusFolder):
 | 
			
		||||
                                 (self.repository, self))
 | 
			
		||||
                file = open(plaintextfilename, "rt")
 | 
			
		||||
                line = file.readline().strip()
 | 
			
		||||
                assert(line == magicline)
 | 
			
		||||
                connection = sqlite.connect(self.filename)
 | 
			
		||||
                cursor = connection.cursor()
 | 
			
		||||
                data = []
 | 
			
		||||
                for line in file.xreadlines():
 | 
			
		||||
                    line = line.strip()
 | 
			
		||||
                    uid, flags = line.split(':')
 | 
			
		||||
                    uid, flags = line.strip().split(':')
 | 
			
		||||
                    uid = long(uid)
 | 
			
		||||
                    flags = [x for x in flags]
 | 
			
		||||
                    flags.sort()
 | 
			
		||||
                    flags = ''.join(flags)
 | 
			
		||||
                    self.cursor.execute('INSERT INTO status (id,flags) VALUES (?,?)',
 | 
			
		||||
                                        (uid,flags))
 | 
			
		||||
                file.close()
 | 
			
		||||
                    flags = list(flags)
 | 
			
		||||
                    flags = ''.join(sorted(flags))
 | 
			
		||||
                    data.append((uid,flags))
 | 
			
		||||
                self.connection.executemany('INSERT INTO status (id,flags) VALUES (?,?)',
 | 
			
		||||
                                       data)
 | 
			
		||||
                self.connection.commit()
 | 
			
		||||
                file.close()
 | 
			
		||||
                os.rename(plaintextfilename, plaintextfilename + ".old")
 | 
			
		||||
                self.connection.close()
 | 
			
		||||
        # Future version upgrades come here...
 | 
			
		||||
        # if from_ver <= 1: ... #upgrade from 1 to 2
 | 
			
		||||
        # if from_ver <= 2: ... #upgrade from 2 to 3
 | 
			
		||||
@@ -98,12 +143,15 @@ class LocalStatusSQLiteFolder(LocalStatusFolder):
 | 
			
		||||
        """Create a new db file"""
 | 
			
		||||
        self.ui._msg('Creating new Local Status db for %s:%s' \
 | 
			
		||||
                         % (self.repository, self))
 | 
			
		||||
        connection = sqlite.connect(self.filename)
 | 
			
		||||
        cursor = connection.cursor()
 | 
			
		||||
        cursor.execute('CREATE TABLE metadata (key VARCHAR(50) PRIMARY KEY, value VARCHAR(128))')
 | 
			
		||||
        cursor.execute("INSERT INTO metadata VALUES('db_version', '1')")
 | 
			
		||||
        cursor.execute('CREATE TABLE status (id INTEGER PRIMARY KEY, flags VARCHAR(50))')
 | 
			
		||||
        self.save() #commit if needed
 | 
			
		||||
        if hasattr(self, 'connection'):
 | 
			
		||||
            self.connection.close() #close old connections first
 | 
			
		||||
        self.connection = sqlite.connect(self.filename, check_same_thread = False)
 | 
			
		||||
        self.connection.executescript("""
 | 
			
		||||
        CREATE TABLE metadata (key VARCHAR(50) PRIMARY KEY, value VARCHAR(128));
 | 
			
		||||
        INSERT INTO metadata VALUES('db_version', '1');
 | 
			
		||||
        CREATE TABLE status (id INTEGER PRIMARY KEY, flags VARCHAR(50));
 | 
			
		||||
        """)
 | 
			
		||||
        self.connection.commit()
 | 
			
		||||
 | 
			
		||||
    def isnewfolder(self):
 | 
			
		||||
        # testing the existence of the db file won't work. It is created
 | 
			
		||||
@@ -113,37 +161,54 @@ class LocalStatusSQLiteFolder(LocalStatusFolder):
 | 
			
		||||
 | 
			
		||||
    def deletemessagelist(self):
 | 
			
		||||
        """delete all messages in the db"""
 | 
			
		||||
        self.cursor.execute('DELETE FROM status')
 | 
			
		||||
        self.sql_write('DELETE FROM status')
 | 
			
		||||
 | 
			
		||||
    def cachemessagelist(self):
 | 
			
		||||
        self.messagelist = {}
 | 
			
		||||
        self.cursor.execute('SELECT id,flags from status')
 | 
			
		||||
        for row in self.cursor:
 | 
			
		||||
            flags = [x for x in row[1]]
 | 
			
		||||
            self.messagelist[row[0]] = {'uid': row[0], 'flags': flags}
 | 
			
		||||
        cursor = self.connection.execute('SELECT id,flags from status')
 | 
			
		||||
        for row in cursor:
 | 
			
		||||
                flags = [x for x in row[1]]
 | 
			
		||||
                self.messagelist[row[0]] = {'uid': row[0], 'flags': flags}
 | 
			
		||||
 | 
			
		||||
    def save(self):
 | 
			
		||||
        #Noop in this backend
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def uidexists(self,uid):
 | 
			
		||||
        self.cursor.execute('SELECT id FROM status WHERE id=:id',{'id': uid})
 | 
			
		||||
        for row in self.cursor:
 | 
			
		||||
            if(row[0]==uid):
 | 
			
		||||
                return 1
 | 
			
		||||
        return 0
 | 
			
		||||
 | 
			
		||||
    def getmessageuidlist(self):
 | 
			
		||||
        self.cursor.execute('SELECT id from status')
 | 
			
		||||
        r = []
 | 
			
		||||
        for row in self.cursor:
 | 
			
		||||
            r.append(row[0])
 | 
			
		||||
        return r
 | 
			
		||||
 | 
			
		||||
    def getmessagecount(self):
 | 
			
		||||
        self.cursor.execute('SELECT count(id) from status');
 | 
			
		||||
        row = self.cursor.fetchone()
 | 
			
		||||
        return row[0]
 | 
			
		||||
    # Following some pure SQLite functions, where we chose to use
 | 
			
		||||
    # BaseFolder() methods instead. Doing those on the in-memory list is
 | 
			
		||||
    # quicker anyway. If our db becomes so big that we don't want to
 | 
			
		||||
    # maintain the in-memory list anymore, these might come in handy
 | 
			
		||||
    # in the future though.
 | 
			
		||||
    #
 | 
			
		||||
    #def uidexists(self,uid):
 | 
			
		||||
    #    conn, cursor = self.get_cursor()
 | 
			
		||||
    #    with conn:
 | 
			
		||||
    #        cursor.execute('SELECT id FROM status WHERE id=:id',{'id': uid})
 | 
			
		||||
    #        return cursor.fetchone()
 | 
			
		||||
    # This would be the pure SQLite solution, use BaseFolder() method,
 | 
			
		||||
    # to avoid threading with sqlite...
 | 
			
		||||
    #def getmessageuidlist(self):
 | 
			
		||||
    #    conn, cursor = self.get_cursor()
 | 
			
		||||
    #    with conn:
 | 
			
		||||
    #        cursor.execute('SELECT id from status')
 | 
			
		||||
    #        r = []
 | 
			
		||||
    #        for row in cursor:
 | 
			
		||||
    #            r.append(row[0])
 | 
			
		||||
    #        return r
 | 
			
		||||
    #def getmessagecount(self):
 | 
			
		||||
    #    conn, cursor = self.get_cursor()
 | 
			
		||||
    #    with conn:
 | 
			
		||||
    #        cursor.execute('SELECT count(id) from status');
 | 
			
		||||
    #        return cursor.fetchone()[0]
 | 
			
		||||
    #def getmessageflags(self, uid):
 | 
			
		||||
    #    conn, cursor = self.get_cursor()
 | 
			
		||||
    #    with conn:
 | 
			
		||||
    #        cursor.execute('SELECT flags FROM status WHERE id=:id',
 | 
			
		||||
    #                        {'id': uid})
 | 
			
		||||
    #        for row in cursor:
 | 
			
		||||
    #            flags = [x for x in row[0]]
 | 
			
		||||
    #            return flags
 | 
			
		||||
    #        assert False,"getmessageflags() called on non-existing message"
 | 
			
		||||
 | 
			
		||||
    def savemessage(self, uid, content, flags, rtime):
 | 
			
		||||
        if uid < 0:
 | 
			
		||||
@@ -155,38 +220,23 @@ class LocalStatusSQLiteFolder(LocalStatusFolder):
 | 
			
		||||
            return uid
 | 
			
		||||
 | 
			
		||||
        self.messagelist[uid] = {'uid': uid, 'flags': flags, 'time': rtime}
 | 
			
		||||
        flags.sort()
 | 
			
		||||
        flags = ''.join(flags)
 | 
			
		||||
        self.cursor.execute('INSERT INTO status (id,flags) VALUES (?,?)',
 | 
			
		||||
                            (uid,flags))
 | 
			
		||||
        self.save()
 | 
			
		||||
        flags = ''.join(sorted(flags))
 | 
			
		||||
        self.sql_write('INSERT INTO status (id,flags) VALUES (?,?)',
 | 
			
		||||
                         (uid,flags))
 | 
			
		||||
        return uid
 | 
			
		||||
 | 
			
		||||
    def getmessageflags(self, uid):
 | 
			
		||||
        self.cursor.execute('SELECT flags FROM status WHERE id=:id',
 | 
			
		||||
                            {'id': uid})
 | 
			
		||||
        for row in self.cursor:
 | 
			
		||||
            flags = [x for x in row[0]]
 | 
			
		||||
            return flags
 | 
			
		||||
        assert False,"getmessageflags() called on non-existing message"
 | 
			
		||||
 | 
			
		||||
    def getmessagetime(self, uid):
 | 
			
		||||
        return self.messagelist[uid]['time']
 | 
			
		||||
 | 
			
		||||
    def savemessageflags(self, uid, flags):
 | 
			
		||||
        self.messagelist[uid] = {'uid': uid, 'flags': flags}
 | 
			
		||||
        flags.sort()
 | 
			
		||||
        flags = ''.join(flags)
 | 
			
		||||
        self.cursor.execute('UPDATE status SET flags=? WHERE id=?',(flags,uid))
 | 
			
		||||
        self.save()
 | 
			
		||||
        self.sql_write('UPDATE status SET flags=? WHERE id=?',(flags,uid))
 | 
			
		||||
 | 
			
		||||
    def deletemessages(self, uidlist):
 | 
			
		||||
        # Weed out ones not in self.messagelist
 | 
			
		||||
        uidlist = [uid for uid in uidlist if uid in self.messagelist]
 | 
			
		||||
        if not len(uidlist):
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        for uid in uidlist:
 | 
			
		||||
            del(self.messagelist[uid])
 | 
			
		||||
            #if self.uidexists(uid):
 | 
			
		||||
            self.cursor.execute('DELETE FROM status WHERE id=:id', {'id': uid})
 | 
			
		||||
            self.sql_write('DELETE FROM status WHERE id=?',
 | 
			
		||||
                             uidlist)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user