centralize format_address

This commit is contained in:
Peter Bieringer 2025-03-13 21:47:44 +01:00
parent b0d649f8b9
commit e22fbe282b
2 changed files with 23 additions and 19 deletions

View file

@ -58,19 +58,7 @@ elif sys.platform == "win32":
# IPv4 (host, port) and IPv6 (host, port, flowinfo, scopeid) # IPv4 (host, port) and IPv6 (host, port, flowinfo, scopeid)
ADDRESS_TYPE = Union[Tuple[Union[str, bytes, bytearray], int], ADDRESS_TYPE = utils.ADDRESS_TYPE
Tuple[str, int, int, int]]
def format_address(address: ADDRESS_TYPE) -> str:
host, port, *_ = address
if not isinstance(host, str):
raise NotImplementedError("Unsupported address format: %r" %
(address,))
if host.find(":") == -1:
return "%s:%d" % (host, port)
else:
return "[%s]:%d" % (host, port)
class ParallelHTTPServer(socketserver.ThreadingMixIn, class ParallelHTTPServer(socketserver.ThreadingMixIn,
@ -321,20 +309,20 @@ def serve(configuration: config.Configuration,
try: try:
getaddrinfo = socket.getaddrinfo(address_port[0], address_port[1], 0, socket.SOCK_STREAM, socket.IPPROTO_TCP) getaddrinfo = socket.getaddrinfo(address_port[0], address_port[1], 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)
except OSError as e: except OSError as e:
logger.warning("cannot retrieve IPv4 or IPv6 address of '%s': %s" % (format_address(address_port), e)) logger.warning("cannot retrieve IPv4 or IPv6 address of '%s': %s" % (utils.format_address(address_port), e))
continue continue
logger.debug("getaddrinfo of '%s': %s" % (format_address(address_port), getaddrinfo)) logger.debug("getaddrinfo of '%s': %s" % (utils.format_address(address_port), getaddrinfo))
for (address_family, socket_kind, socket_proto, socket_flags, socket_address) in getaddrinfo: for (address_family, socket_kind, socket_proto, socket_flags, socket_address) in getaddrinfo:
logger.debug("try to create server socket on '%s'" % (format_address(socket_address))) logger.debug("try to create server socket on '%s'" % (utils.format_address(socket_address)))
try: try:
server = server_class(configuration, address_family, (socket_address[0], socket_address[1]), RequestHandler) server = server_class(configuration, address_family, (socket_address[0], socket_address[1]), RequestHandler)
except OSError as e: except OSError as e:
logger.warning("cannot create server socket on '%s': %s" % (format_address(socket_address), e)) logger.warning("cannot create server socket on '%s': %s" % (utils.format_address(socket_address), e))
continue continue
servers[server.socket] = server servers[server.socket] = server
server.set_app(application) server.set_app(application)
logger.info("Listening on %r%s", logger.info("Listening on %r%s",
format_address(server.server_address), utils.format_address(server.server_address),
" with SSL" if use_ssl else "") " with SSL" if use_ssl else "")
if not servers: if not servers:
raise RuntimeError("No servers started") raise RuntimeError("No servers started")

View file

@ -20,7 +20,7 @@
import ssl import ssl
import sys import sys
from importlib import import_module, metadata from importlib import import_module, metadata
from typing import Callable, Sequence, Type, TypeVar, Union from typing import Callable, Sequence, Tuple, Type, TypeVar, Union
from radicale import config from radicale import config
from radicale.log import logger from radicale.log import logger
@ -36,6 +36,11 @@ RADICALE_MODULES: Sequence[str] = ("radicale", "vobject", "passlib", "defusedxml
"pam") "pam")
# IPv4 (host, port) and IPv6 (host, port, flowinfo, scopeid)
ADDRESS_TYPE = Union[Tuple[Union[str, bytes, bytearray], int],
Tuple[str, int, int, int]]
def load_plugin(internal_types: Sequence[str], module_name: str, def load_plugin(internal_types: Sequence[str], module_name: str,
class_name: str, base_class: Type[_T_co], class_name: str, base_class: Type[_T_co],
configuration: "config.Configuration") -> _T_co: configuration: "config.Configuration") -> _T_co:
@ -74,6 +79,17 @@ def packages_version():
return " ".join(versions) return " ".join(versions)
def format_address(address: ADDRESS_TYPE) -> str:
host, port, *_ = address
if not isinstance(host, str):
raise NotImplementedError("Unsupported address format: %r" %
(address,))
if host.find(":") == -1:
return "%s:%d" % (host, port)
else:
return "[%s]:%d" % (host, port)
def ssl_context_options_by_protocol(protocol: str, ssl_context_options): def ssl_context_options_by_protocol(protocol: str, ssl_context_options):
logger.debug("SSL protocol string: '%s' and current SSL context options: '0x%x'", protocol, ssl_context_options) logger.debug("SSL protocol string: '%s' and current SSL context options: '0x%x'", protocol, ssl_context_options)
# disable any protocol by default # disable any protocol by default