Use context manager for locking
This commit is contained in:
		| @@ -282,45 +282,37 @@ class Application: | ||||
|         is_authenticated = self.is_authenticated(user, password) | ||||
|         is_valid_user = is_authenticated or not user | ||||
|  | ||||
|         lock = None | ||||
|         try: | ||||
|             if is_valid_user: | ||||
|                 if function in (self.do_GET, self.do_HEAD, | ||||
|                                 self.do_OPTIONS, self.do_PROPFIND, | ||||
|                                 self.do_REPORT): | ||||
|                     lock_mode = "r" | ||||
|                 else: | ||||
|                     lock_mode = "w" | ||||
|                 lock = self.Collection.acquire_lock(lock_mode) | ||||
|         # Get content | ||||
|         content_length = int(environ.get("CONTENT_LENGTH") or 0) | ||||
|         if content_length: | ||||
|             content = self.decode( | ||||
|                 environ["wsgi.input"].read(content_length), environ) | ||||
|             self.logger.debug("Request content:\n%s" % content) | ||||
|         else: | ||||
|             content = None | ||||
|  | ||||
|         if is_valid_user: | ||||
|             if function in (self.do_GET, self.do_HEAD, | ||||
|                             self.do_OPTIONS, self.do_PROPFIND, | ||||
|                             self.do_REPORT): | ||||
|                 lock_mode = "r" | ||||
|             else: | ||||
|                 lock_mode = "w" | ||||
|             with self.Collection.acquire_lock(lock_mode): | ||||
|                 items = self.Collection.discover( | ||||
|                     path, environ.get("HTTP_DEPTH", "0")) | ||||
|                 read_allowed_items, write_allowed_items = ( | ||||
|                     self.collect_allowed_items(items, user)) | ||||
|             else: | ||||
|                 read_allowed_items, write_allowed_items = None, None | ||||
|  | ||||
|             # Get content | ||||
|             content_length = int(environ.get("CONTENT_LENGTH") or 0) | ||||
|             if content_length: | ||||
|                 content = self.decode( | ||||
|                     environ["wsgi.input"].read(content_length), environ) | ||||
|                 self.logger.debug("Request content:\n%s" % content) | ||||
|             else: | ||||
|                 content = None | ||||
|  | ||||
|             if is_valid_user and ( | ||||
|                     (read_allowed_items or write_allowed_items) or | ||||
|                     (is_authenticated and function == self.do_PROPFIND) or | ||||
|                     function == self.do_OPTIONS): | ||||
|                 status, headers, answer = function( | ||||
|                     environ, read_allowed_items, write_allowed_items, content, | ||||
|                     user) | ||||
|             else: | ||||
|                 status, headers, answer = NOT_ALLOWED | ||||
|         finally: | ||||
|             if lock: | ||||
|                 lock.release() | ||||
|                 if (read_allowed_items or write_allowed_items or | ||||
|                         is_authenticated and function == self.do_PROPFIND or | ||||
|                         function == self.do_OPTIONS): | ||||
|                     status, headers, answer = function( | ||||
|                         environ, read_allowed_items, write_allowed_items, | ||||
|                         content, user) | ||||
|                 else: | ||||
|                     status, headers, answer = NOT_ALLOWED | ||||
|         else: | ||||
|             status, headers, answer = NOT_ALLOWED | ||||
|  | ||||
|         if (status, headers, answer) == NOT_ALLOWED and not is_authenticated: | ||||
|             # Unknown or unauthorized user | ||||
|   | ||||
| @@ -277,14 +277,13 @@ class BaseCollection: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @classmethod | ||||
|     @contextmanager | ||||
|     def acquire_lock(cls, mode): | ||||
|         """Lock the whole storage. | ||||
|         """Set a context manager to lock the whole storage. | ||||
|  | ||||
|         ``mode`` must either be "r" for shared access or "w" for exclusive | ||||
|         access. | ||||
|  | ||||
|         Returns an object which has a method ``release``. | ||||
|  | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
| @@ -521,6 +520,7 @@ class Collection(BaseCollection): | ||||
|     _lock = threading.Lock() | ||||
|  | ||||
|     @classmethod | ||||
|     @contextmanager | ||||
|     def acquire_lock(cls, mode): | ||||
|         class Lock: | ||||
|             def __init__(self, release_method): | ||||
| @@ -574,4 +574,5 @@ class Collection(BaseCollection): | ||||
|             # TODO: use readers–writer lock | ||||
|             cls._lock.acquire() | ||||
|             lock = Lock(cls._lock.release) | ||||
|         return lock | ||||
|         yield | ||||
|         lock.release() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Unrud
					Unrud