fb43b31e7c
I included these libraries here to avoid problems when sharing with other users. I will remove them later.
316 lines
10 KiB
Python
316 lines
10 KiB
Python
""" Python implementation of the Happy Eyeballs Algorithm described in RFC 6555. """
|
|
|
|
# Copyright 2017 Seth Michael Larson
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import errno
import os
import socket

from selectors2 import DefaultSelector, EVENT_WRITE
|
|
|
|
# time.perf_counter() is defined in Python 3.3
|
|
try:
|
|
from time import perf_counter
|
|
except (ImportError, AttributeError):
|
|
from time import time as perf_counter
|
|
|
|
|
|
# Exceptions a socket operation may raise. socket.error and IOError only
# became aliases of OSError in Python 3.3, so all three are listed to catch
# socket failures on older interpreters as well.
_SOCKET_ERRORS = (socket.error, OSError, IOError)


def _detect_ipv6():
    """Return whether an IPv6 TCP socket can be created and bound to ``::1``.

    The probe socket is always closed before returning, whether or not the
    probe succeeded.
    """
    if not (getattr(socket, 'has_ipv6', False) and hasattr(socket, 'AF_INET6')):
        return False

    probe = None
    try:
        probe = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        probe.bind(('::1', 0))
        return True
    except _SOCKET_ERRORS:
        return False
    finally:
        # Close on success too; returning from the try body must not leak
        # the probe's file descriptor.
        if probe is not None:
            probe.close()


# Evaluated once at import time; create_connection() only races address
# families when this is True.
_HAS_IPV6 = _detect_ipv6()
|
|
|
|
# Error numbers that a non-blocking connect() (or a socket's pending
# SO_ERROR) can report while the connection is still in progress. RFC 6555
# connection racing treats them as "not an error yet" rather than failures.
_ASYNC_ERRNOS = set([errno.EINPROGRESS,
                     errno.EAGAIN,
                     errno.EWOULDBLOCK])

# Windows' "operation would block" code. The errno module spells it
# WSAEWOULDBLOCK (the previous check for "WSAWOULDBLOCK" could never match).
# It may already equal errno.EWOULDBLOCK on some builds; adding it again to a
# set is harmless.
if hasattr(errno, 'WSAEWOULDBLOCK'):
    _ASYNC_ERRNOS.add(errno.WSAEWOULDBLOCK)
|
|
|
|
# How long (seconds) a cached address family stays valid: 10 minutes,
# as recommended by the RFC.
_DEFAULT_CACHE_DURATION = 10 * 60

# This value can be used to disable RFC 6555 globally; set it to False to
# fall back to trying resolved addresses one at a time. It defaults to
# whether IPv6 was detected at import time.
RFC6555_ENABLED = _HAS_IPV6

__all__ = ['RFC6555_ENABLED', 'create_connection', 'cache']

__version__ = '0.0.0'
__author__ = 'Seth Michael Larson'
__email__ = 'sethmichaellarson@protonmail.com'
__license__ = 'Apache-2.0'
|
|
|
|
|
|
class _RFC6555CacheManager(object):
    """Remember which address family last connected to each address.

    ``entries`` maps an address to ``(family, expiry)``, where ``expiry`` is
    a ``perf_counter()`` timestamp. Expired entries are dropped on lookup.
    Setting ``enabled`` to False disables both recording and lookup.
    """

    def __init__(self):
        # Seconds an entry stays valid after it is recorded.
        self.validity_duration = _DEFAULT_CACHE_DURATION
        self.enabled = True
        # address -> (family, expiry timestamp from perf_counter())
        self.entries = {}

    def add_entry(self, address, family):
        """Record that ``family`` successfully connected to ``address``.

        A live entry for the same family is left untouched so its expiry is
        not reset; otherwise a frequently used address would never be
        re-raced across families. A missing or expired entry, or one that
        names a different family, is (re)written with a fresh expiry.
        """
        if not self.enabled:
            return

        current_time = perf_counter()
        entry = self.entries.get(address)
        if entry is not None:
            cached_family, expiry = entry
            # Same liveness rule as get_entry(): expired only once now > expiry.
            if cached_family == family and expiry >= current_time:
                return

        self.entries[address] = (family, current_time + self.validity_duration)

    def get_entry(self, address):
        """Return the cached family for ``address``, or None.

        None is returned when caching is disabled, nothing is cached, or the
        entry has expired (in which case it is also removed).
        """
        if not self.enabled or address not in self.entries:
            return None

        family, expiry = self.entries[address]
        if perf_counter() > expiry:
            del self.entries[address]
            return None

        return family
|
|
|
|
|
|
# Process-wide cache of the address family that last connected to each
# address; _RFC6555ConnectionManager reads and updates it. It is listed in
# __all__, so callers may also toggle `cache.enabled` or change
# `cache.validity_duration`.
cache = _RFC6555CacheManager()
|
|
|
|
|
|
class _RFC6555ConnectionManager(object):
    """Race TCP connection attempts to one address, as described by RFC 6555.

    ``create_connection()`` resolves ``address`` and starts a non-blocking
    connect to each resolved address in turn. Between attempts it pauses at
    most 0.2 seconds, bounded by the remaining timeout. The first socket
    that becomes writable without a pending socket error is returned, with
    the caller's timeout restored. Every other socket is closed.

    If the module-level ``cache`` remembers a family for the address,
    addresses of that family are tried first.

    Each instance is meant for a single ``create_connection()`` call; the
    selector is closed when that call finishes.
    """

    def __init__(self, address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
        self.address = address
        # May still be socket._GLOBAL_DEFAULT_TIMEOUT until the first socket
        # is created; _create_socket() then resolves it via sock.gettimeout().
        self.timeout = timeout
        self.source_address = source_address

        self._error = None           # most recent connection error, if any
        self._selector = DefaultSelector()
        self._sockets = []           # in-progress sockets, all registered
        self._start_time = None      # perf_counter() at create_connection()

    def create_connection(self):
        """Connect to ``self.address`` and return the connected socket.

        On success the winning socket's family is recorded in ``cache``.

        Raises:
            socket.gaierror: propagated from ``socket.getaddrinfo()``.
            socket.error: 'getaddrinfo returns an empty list' when resolution
                yields nothing; otherwise the most recent connection error.
            socket.timeout: when no attempt succeeded or failed in time.
        """
        self._start_time = perf_counter()
        try:
            host, port = self.address
            addr_info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
            if not addr_info:
                raise socket.error('getaddrinfo returns an empty list')

            ret = self._connect_with_cached_family(addr_info)
            if isinstance(ret, list):
                # The cached family (if any) did not connect; race whatever
                # addresses remain. An empty remainder yields no socket and
                # falls through to the error/timeout reporting below rather
                # than being misreported as an empty getaddrinfo() result.
                sock = self._attempt_connect_with_addr_info(ret)
            else:
                sock = ret

            if sock is not None:
                cache.add_entry(self.address, sock.family)
                return sock
            if self._error is not None:
                raise self._error
            raise socket.timeout()
        finally:
            # Release the selector's resources eagerly instead of relying on
            # garbage collection.
            self._selector.close()

    def _attempt_connect_with_addr_info(self, addr_info):
        """Race connections to every entry of ``addr_info``.

        Returns the first connected socket, or None. All losing or
        still-pending sockets are closed before returning.
        """
        sock = None
        try:
            for family, socktype, proto, _, sockaddr in addr_info:
                self._create_socket(family, socktype, proto, sockaddr)
                sock = self._wait_for_connection(False)
                if sock is not None:
                    break
            if sock is None:
                sock = self._wait_for_connection(True)
        finally:
            self._remove_all_sockets()
        return sock

    def _connect_with_cached_family(self, addr_info):
        """Try addresses of the cached family (if any) first.

        Returns the connected socket when one of those addresses connects;
        otherwise returns the list of addresses still to try (all of
        ``addr_info`` when no family is cached).
        """
        family = cache.get_entry(self.address)
        if family is None:
            return addr_info

        is_family = [info for info in addr_info if info[0] == family]
        not_family = [info for info in addr_info if info[0] != family]

        sock = self._attempt_connect_with_addr_info(is_family)
        if sock is not None:
            return sock
        return not_family

    def _create_socket(self, family, socktype, proto, sockaddr):
        """Start a non-blocking connect to ``sockaddr``.

        On success the socket is registered for write readiness and added to
        ``self._sockets``. On failure the error is recorded in
        ``self._error`` and the socket is closed.
        """
        sock = None
        try:
            sock = socket.socket(family, socktype, proto)

            # If we're using the 'default' socket timeout we have to set it
            # to a real value here, as this is the earliest opportunity to do
            # so without pre-allocating a socket just for this purpose.
            if self.timeout is socket._GLOBAL_DEFAULT_TIMEOUT:
                self.timeout = sock.gettimeout()

            if self.source_address:
                sock.bind(self.source_address)

            # Make the socket non-blocking so we can use our selector.
            sock.settimeout(0.0)

            if self._is_acceptable_errno(sock.connect_ex(sockaddr)):
                self._selector.register(sock, EVENT_WRITE)
                self._sockets.append(sock)
            else:
                # The connect failed immediately (error already recorded);
                # close the socket rather than leaking it.
                _RFC6555ConnectionManager._close_socket(sock)

        except _SOCKET_ERRORS as e:
            self._error = e
            if sock is not None:
                _RFC6555ConnectionManager._close_socket(sock)

    def _wait_for_connection(self, last_wait):
        """Wait for one of the racing sockets to connect.

        A non-final wait (``last_wait`` False) is a single pause of at most
        0.2 seconds before the next attempt is started. The final wait keeps
        waiting until a socket connects, every socket has failed, or the
        timeout expires. Returns the connected socket or None.
        """
        while True:
            self._remove_all_errored_sockets()

            # This is also a safe-guard for the default-timeout case: if no
            # socket exists, self.timeout may not have been resolved yet.
            if not self._sockets:
                return None

            if last_wait:
                select_timeout = self._get_remaining_time()
            else:
                select_timeout = self._get_select_time()

            # A socket becoming writable signals that its connect finished.
            for key, _ in self._selector.select(select_timeout):
                sock = key.fileobj

                if self._is_socket_errored(sock):
                    # Reading SO_ERROR clears it, so drop the socket now; on
                    # a later pass it would look healthy and could be
                    # returned as "connected".
                    self._remove_socket(sock)
                    continue

                # Restore the caller's timeout on the winning socket.
                sock.settimeout(self.timeout)

                # Remove it from our list to exempt it from cleanup.
                self._sockets.remove(sock)
                self._selector.unregister(sock)
                return sock

            # Only the final wait keeps going, so that one failed attempt
            # does not abandon others that are still in progress.
            if not last_wait or (select_timeout is not None and select_timeout <= 0):
                return None

    def _get_remaining_time(self):
        """Seconds left before the overall timeout (None means no timeout)."""
        if self.timeout is None:
            return None
        return max(self.timeout - (perf_counter() - self._start_time), 0.0)

    def _get_select_time(self):
        """Pause between attempts: 0.2 s, capped by the remaining time."""
        if self.timeout is None:
            return 0.2
        return min(0.2, self._get_remaining_time())

    def _remove_socket(self, sock):
        """Unregister, forget and close one racing socket."""
        self._selector.unregister(sock)
        self._sockets.remove(sock)
        _RFC6555ConnectionManager._close_socket(sock)

    def _remove_all_errored_sockets(self):
        """Drop every racing socket that has a pending socket error."""
        # _is_socket_errored() also records the error in self._error.
        errored = [sock for sock in self._sockets if self._is_socket_errored(sock)]
        for sock in errored:
            self._remove_socket(sock)

    @staticmethod
    def _close_socket(sock):
        """Close ``sock``, ignoring socket errors raised while closing."""
        try:
            sock.close()
        except _SOCKET_ERRORS:
            pass

    def _is_acceptable_errno(self, err_no):
        """Return True if ``err_no`` means success or "still connecting".

        Otherwise record a ``socket.error`` carrying ``err_no`` and its
        message in ``self._error`` and return False.
        """
        if err_no == 0 or err_no in _ASYNC_ERRNOS:
            return True
        try:
            message = os.strerror(err_no)
        except ValueError:
            # os.strerror() raises ValueError for unknown codes on some platforms.
            message = 'Unknown error %d' % err_no
        self._error = socket.error(err_no, message)
        return False

    def _is_socket_errored(self, sock):
        """Return True (recording the error) if ``sock`` has a pending error.

        Reading SO_ERROR clears it, so callers must discard a socket once it
        has been reported as errored.
        """
        err_no = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
        return not self._is_acceptable_errno(err_no)

    def _remove_all_sockets(self):
        """Unregister and close every socket still racing."""
        for sock in self._sockets:
            self._selector.unregister(sock)
            _RFC6555ConnectionManager._close_socket(sock)
        self._sockets = []
|
|
|
|
|
|
def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
    """Open a TCP connection to ``address`` (a ``(host, port)`` pair).

    When RFC 6555 is enabled and IPv6 is available, attempts are raced
    across all resolved addresses by ``_RFC6555ConnectionManager``.
    Otherwise each resolved address is tried in turn, as in
    ``socket.create_connection()``.

    Raises:
        socket.error: the last connection error (including timeouts), or
            when getaddrinfo() returns no addresses.
    """
    if RFC6555_ENABLED and _HAS_IPV6:
        return _RFC6555ConnectionManager(address, timeout, source_address).create_connection()

    # This code is the same as socket.create_connection() but is
    # here to make sure the same code is used across all Python versions as
    # the source_address parameter was added to socket.create_connection() in 3.2
    # This segment of code is licensed under the Python Software Foundation License
    # See LICENSE: https://github.com/python/cpython/blob/3.6/LICENSE
    host, port = address
    last_error = None
    candidates = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
    for family, sock_type, protocol, _canonname, sock_addr in candidates:
        conn = None
        try:
            conn = socket.socket(family, sock_type, protocol)
            if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
                conn.settimeout(timeout)
            if source_address:
                conn.bind(source_address)
            conn.connect(sock_addr)
            return conn
        except socket.error as exc:
            last_error = exc
            if conn is not None:
                conn.close()

    if last_error is None:
        raise socket.error("getaddrinfo returns an empty list")
    raise last_error
|