docker-offlineimap/rfc6555.py
Rodolfo García Peñas (kix) fb43b31e7c Included external libraries
I included these libraries here, to avoid problems sharing with other
users.

I will remove them later.
2020-08-28 23:15:13 +02:00

316 lines
10 KiB
Python

""" Python implementation of the Happy Eyeballs Algorithm described in RFC 6555. """
# Copyright 2017 Seth Michael Larson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import errno
import socket
from selectors2 import DefaultSelector, EVENT_WRITE
# time.perf_counter() is defined in Python 3.3
try:
from time import perf_counter
except (ImportError, AttributeError):
from time import time as perf_counter
# socket.error and IOError were not subclasses of OSError until later
# versions of Python, so all three are listed explicitly and this tuple is
# used wherever a socket failure has to be caught.
_SOCKET_ERRORS = (socket.error, OSError, IOError)
def _detect_ipv6():
    """Detect whether an IPv6 stream socket can be allocated and bound.

    Returns:
        True when the ``socket`` module advertises IPv6 support and a
        ``SOCK_STREAM`` socket can be bound to ``('::1', 0)``; False
        otherwise.
    """
    if not (getattr(socket, 'has_ipv6', False) and hasattr(socket, 'AF_INET6')):
        return False

    probe = None
    try:
        probe = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        probe.bind(('::1', 0))
        return True
    except _SOCKET_ERRORS:
        return False
    finally:
        # The probe socket is only needed for the bind test, so release it
        # on the success path as well as on failure.
        if probe is not None:
            probe.close()
_HAS_IPV6 = _detect_ipv6()

# Error numbers reported by asynchronous (non-blocking) operations which
# RFC 6555 handling can safely treat as non-errors.
_ASYNC_ERRNOS = {errno.EINPROGRESS, errno.EAGAIN, errno.EWOULDBLOCK}
if hasattr(errno, 'WSAWOULDBLOCK'):
    _ASYNC_ERRNOS.add(errno.WSAWOULDBLOCK)

# 10 minutes according to the RFC.
_DEFAULT_CACHE_DURATION = 10 * 60

# This value can be used to disable RFC 6555 globally.
RFC6555_ENABLED = _HAS_IPV6

__all__ = ['RFC6555_ENABLED', 'create_connection', 'cache']

__version__ = '0.0.0'
__author__ = 'Seth Michael Larson'
__email__ = 'sethmichaellarson@protonmail.com'
__license__ = 'Apache-2.0'
class _RFC6555CacheManager(object):
def __init__(self):
self.validity_duration = _DEFAULT_CACHE_DURATION
self.enabled = True
self.entries = {}
def add_entry(self, address, family):
if self.enabled:
current_time = perf_counter()
# Don't over-write old entries to reset their expiry.
if address not in self.entries or self.entries[address][1] > current_time:
self.entries[address] = (family, current_time + self.validity_duration)
def get_entry(self, address):
if not self.enabled or address not in self.entries:
return None
family, expiry = self.entries[address]
if perf_counter() > expiry:
del self.entries[address]
return None
return family
# Module-level cache instance; it is exported through __all__ as `cache`.
cache = _RFC6555CacheManager()
class _RFC6555ConnectionManager(object):
def __init__(self, address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
self.address = address
self.timeout = timeout
self.source_address = source_address
self._error = None
self._selector = DefaultSelector()
self._sockets = []
self._start_time = None
def create_connection(self):
self._start_time = perf_counter()
host, port = self.address
addr_info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
ret = self._connect_with_cached_family(addr_info)
# If it's a list, then these are the remaining values to try.
if isinstance(ret, list):
addr_info = ret
else:
cache.add_entry(self.address, ret.family)
return ret
# If we don't get any results back then just skip to the end.
if not addr_info:
raise socket.error('getaddrinfo returns an empty list')
sock = self._attempt_connect_with_addr_info(addr_info)
if sock:
cache.add_entry(self.address, sock.family)
return sock
elif self._error:
raise self._error
else:
raise socket.timeout()
def _attempt_connect_with_addr_info(self, addr_info):
sock = None
try:
for family, socktype, proto, _, sockaddr in addr_info:
self._create_socket(family, socktype, proto, sockaddr)
sock = self._wait_for_connection(False)
if sock:
break
if sock is None:
sock = self._wait_for_connection(True)
finally:
self._remove_all_sockets()
return sock
def _connect_with_cached_family(self, addr_info):
family = cache.get_entry(self.address)
if family is None:
return addr_info
is_family = []
not_family = []
for value in addr_info:
if value[0] == family:
is_family.append(value)
else:
not_family.append(value)
sock = self._attempt_connect_with_addr_info(is_family)
if sock is not None:
return sock
return not_family
def _create_socket(self, family, socktype, proto, sockaddr):
sock = None
try:
sock = socket.socket(family, socktype, proto)
# If we're using the 'default' socket timeout we have
# to set it to a real value here as this is the earliest
# opportunity to without pre-allocating a socket just for
# this purpose.
if self.timeout is socket._GLOBAL_DEFAULT_TIMEOUT:
self.timeout = sock.gettimeout()
if self.source_address:
sock.bind(self.source_address)
# Make the socket non-blocking so we can use our selector.
sock.settimeout(0.0)
if self._is_acceptable_errno(sock.connect_ex(sockaddr)):
self._selector.register(sock, EVENT_WRITE)
self._sockets.append(sock)
except _SOCKET_ERRORS as e:
self._error = e
if sock is not None:
_RFC6555ConnectionManager._close_socket(sock)
def _wait_for_connection(self, last_wait):
self._remove_all_errored_sockets()
# This is a safe-guard to make sure sock.gettimeout() is called in the
# case that the default socket timeout is used. If there are no
# sockets then we may not have called sock.gettimeout() yet.
if not self._sockets:
return None
# If this is the last time we're waiting for connections
# then we should wait until we should raise a timeout
# error, otherwise we should only wait >0.2 seconds as
# recommended by RFC 6555.
if last_wait:
if self.timeout is None:
select_timeout = None
else:
select_timeout = self._get_remaining_time()
else:
select_timeout = self._get_select_time()
# Wait for any socket to become writable as a sign of being connected.
for key, _ in self._selector.select(select_timeout):
sock = key.fileobj
if not self._is_socket_errored(sock):
# Restore the old proper timeout of the socket.
sock.settimeout(self.timeout)
# Remove it from this list to exempt the socket from cleanup.
self._sockets.remove(sock)
self._selector.unregister(sock)
return sock
return None
def _get_remaining_time(self):
if self.timeout is None:
return None
return max(self.timeout - (perf_counter() - self._start_time), 0.0)
def _get_select_time(self):
if self.timeout is None:
return 0.2
return min(0.2, self._get_remaining_time())
def _remove_all_errored_sockets(self):
socks = []
for sock in self._sockets:
if self._is_socket_errored(sock):
socks.append(sock)
for sock in socks:
self._selector.unregister(sock)
self._sockets.remove(sock)
_RFC6555ConnectionManager._close_socket(sock)
@staticmethod
def _close_socket(sock):
try:
sock.close()
except _SOCKET_ERRORS:
pass
def _is_acceptable_errno(self, errno):
if errno == 0 or errno in _ASYNC_ERRNOS:
return True
self._error = socket.error()
self._error.errno = errno
return False
def _is_socket_errored(self, sock):
errno = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
return not self._is_acceptable_errno(errno)
def _remove_all_sockets(self):
for sock in self._sockets:
self._selector.unregister(sock)
_RFC6555ConnectionManager._close_socket(sock)
self._sockets = []
def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
    """Connect to ``address`` and return the connected socket.

    When RFC 6555 support is enabled and IPv6 is available the connection
    is made through ``_RFC6555ConnectionManager``; otherwise each
    ``getaddrinfo`` result is tried in order with a blocking connect.

    Args:
        address: ``(host, port)`` pair.
        timeout: socket timeout; left untouched on the fallback path when it
            is ``socket._GLOBAL_DEFAULT_TIMEOUT``.
        source_address: optional address to bind to before connecting.

    Raises:
        socket.error: the last connection error on the fallback path, or
            "getaddrinfo returns an empty list" when there was nothing to try.
    """
    if RFC6555_ENABLED and _HAS_IPV6:
        return _RFC6555ConnectionManager(address, timeout, source_address).create_connection()

    # This code is the same as socket.create_connection() but is
    # here to make sure the same code is used across all Python versions as
    # the source_address parameter was added to socket.create_connection() in 3.2
    # This segment of code is licensed under the Python Software Foundation License
    # See LICENSE: https://github.com/python/cpython/blob/3.6/LICENSE
    host, port = address
    last_error = None
    for candidate in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
        af, socktype, proto, _canonname, sa = candidate
        conn = None
        try:
            conn = socket.socket(af, socktype, proto)
            if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
                conn.settimeout(timeout)
            if source_address:
                conn.bind(source_address)
            conn.connect(sa)
            return conn
        except socket.error as exc:
            # Remember the failure and move on to the next candidate.
            last_error = exc
            if conn is not None:
                conn.close()

    if last_error is None:
        raise socket.error("getaddrinfo returns an empty list")
    raise last_error