diff --git a/offlineimap/folder/LocalStatusSQLite.py b/offlineimap/folder/LocalStatusSQLite.py index 059f741..b1246d4 100644 --- a/offlineimap/folder/LocalStatusSQLite.py +++ b/offlineimap/folder/LocalStatusSQLite.py @@ -41,6 +41,9 @@ class LocalStatusSQLiteFolder(BaseFolder): # Current version of our db format. cur_version = 2 + # Keep track on how many threads need access to the database. + threads_open_count = 0 + threads_open_lock = Lock() def __init__(self, name, repository): self.sep = '.' # Needs to be set before super.__init__() @@ -62,34 +65,39 @@ class LocalStatusSQLiteFolder(BaseFolder): self.connection = None def openfiles(self): - # Try to establish connection, no need for threadsafety in __init__. - try: - self.connection = sqlite.connect(self.filename, check_same_thread=False) - except sqlite.OperationalError as e: - # Operation had failed. - six.reraise(UserWarning, - UserWarning( - "cannot open database file '%s': %s.\nYou might " - "want to check the rights to that file and if it " - "cleanly opens with the 'sqlite<3>' command."% - (self.filename, e)), - exc_info()[2]) + # Protect the creation/upgrade of database accross threads. + with LocalStatusSQLiteFolder.threads_open_lock: + # Try to establish connection, no need for threadsafety in __init__. - # Make sure sqlite is in multithreading SERIALIZE mode. - assert sqlite.threadsafety == 1, 'Your sqlite is not multithreading safe.' + try: + self.connection = sqlite.connect(self.filename, + check_same_thread=False) + LocalStatusSQLiteFolder.threads_open_count += 1 + except sqlite.OperationalError as e: + # Operation had failed. + six.reraise(UserWarning, + UserWarning( + "cannot open database file '%s': %s.\nYou might" + " want to check the rights to that file and if " + "it cleanly opens with the 'sqlite<3>' command"% + (self.filename, e)), + exc_info()[2]) - # Test if db version is current enough and if db is readable. - try: - cursor = self.connection.execute( - "SELECT value from metadata WHERE key='db_version'") - except sqlite.DatabaseError: - # db file missing or corrupt, recreate it. - self.__create_db() - else: - # Fetch db version and upgrade if needed. - version = int(cursor.fetchone()[0]) - if version < LocalStatusSQLiteFolder.cur_version: - self.__upgrade_db(version) + # Make sure sqlite is in multithreading SERIALIZE mode. + assert sqlite.threadsafety == 1, 'Your sqlite is not multithreading safe.' + + # Test if db version is current enough and if db is readable. + try: + cursor = self.connection.execute( + "SELECT value from metadata WHERE key='db_version'") + except sqlite.DatabaseError: + # db file missing or corrupt, recreate it. + self.__create_db() + else: + # Fetch db version and upgrade if needed. + version = int(cursor.fetchone()[0]) + if version < LocalStatusSQLiteFolder.cur_version: + self.__upgrade_db(version) def purge(self): """Remove any pre-existing database.""" @@ -159,8 +167,8 @@ class LocalStatusSQLiteFolder(BaseFolder): def __upgrade_db(self, from_ver): """Upgrade the sqlite format from version 'from_ver' to current""" - if hasattr(self, 'connection'): - self.connection.close() #close old connections first + if self.connection is not None: + self.connection.close() # Close old connections first. self.connection = sqlite.connect(self.filename, check_same_thread=False) @@ -185,8 +193,8 @@ class LocalStatusSQLiteFolder(BaseFolder): self.connection must point to the opened and valid SQlite database connection.""" - self.ui._msg('Creating new Local Status db for %s:%s' \ - % (self.repository, self)) + self.ui._msg('Creating new Local Status db for %s:%s'% + (self.repository, self)) self.connection.executescript(""" CREATE TABLE metadata (key VARCHAR(50) PRIMARY KEY, value VARCHAR(128)); INSERT INTO metadata VALUES('db_version', '2'); @@ -229,10 +237,13 @@ class LocalStatusSQLiteFolder(BaseFolder): self.messagelist[uid]['mtime'] = row[2] def closefiles(self): - try: - self.connection.close() - except: - pass + with LocalStatusSQLiteFolder.threads_open_lock: + LocalStatusSQLiteFolder.threads_open_count -= 1 + if self.threads_open_count < 1: + try: + self.connection.close() + except: + pass # Interface from LocalStatusFolder def save(self):