Use forking for internal server when available

This commit is contained in:
Unrud 2018-08-18 12:56:41 +02:00
parent ddd99a5329
commit 30a9ecc06b
2 changed files with 60 additions and 27 deletions

View file

@ -22,6 +22,7 @@ Radicale WSGI server.
"""
import contextlib
import multiprocessing
import os
import select
import signal
@ -29,16 +30,20 @@ import socket
import socketserver
import ssl
import sys
import threading
import wsgiref.simple_server
from urllib.parse import unquote
from radicale import Application
from radicale.log import logger
if hasattr(socketserver, "ForkingMixIn"):
ParallelizationMixIn = socketserver.ForkingMixIn
else:
ParallelizationMixIn = socketserver.ThreadingMixIn
class HTTPServer(wsgiref.simple_server.WSGIServer):
"""HTTP server."""
class ParallelHTTPServer(ParallelizationMixIn,
wsgiref.simple_server.WSGIServer):
# These class attributes must be set before creating instance
client_timeout = None
@ -59,7 +64,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
if self.max_connections:
self.connections_guard = threading.BoundedSemaphore(
self.connections_guard = multiprocessing.BoundedSemaphore(
self.max_connections)
else:
# use dummy context manager
@ -75,10 +80,14 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
def get_request(self):
# Set timeout for client
_socket, address = super().get_request()
socket_, address = super().get_request()
if self.client_timeout:
_socket.settimeout(self.client_timeout)
return _socket, address
socket_.settimeout(self.client_timeout)
return socket_, address
def finish_request(self, request, client_address):
with self.connections_guard:
return super().finish_request(request, client_address)
def handle_error(self, request, client_address):
if issubclass(sys.exc_info()[0], socket.timeout):
@ -88,8 +97,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
sys.exc_info()[1], exc_info=True)
class HTTPSServer(HTTPServer):
"""HTTPS server."""
class ParallelHTTPSServer(ParallelHTTPServer):
# These class attributes must be set before creating instance
certificate = None
@ -98,9 +106,11 @@ class HTTPSServer(HTTPServer):
ciphers = None
certificate_authority = None
def __init__(self, address, handler):
def __init__(self, address, handler, bind_and_activate=True):
"""Create server by wrapping HTTP socket in an SSL socket."""
super().__init__(address, handler, bind_and_activate=False)
# Do not bind and activate, as we change the socket
super().__init__(address, handler, False)
self.socket = ssl.wrap_socket(
self.socket, self.key, self.certificate, server_side=True,
@ -110,18 +120,15 @@ class HTTPSServer(HTTPServer):
ssl_version=self.protocol, ciphers=self.ciphers,
do_handshake_on_connect=False)
self.server_bind()
self.server_activate()
if bind_and_activate:
try:
self.server_bind()
self.server_activate()
except BaseException:
self.server_close()
raise
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
def process_request_thread(self, request, client_address):
with self.connections_guard:
return super().process_request_thread(request, client_address)
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
def process_request_thread(self, request, client_address):
def finish_request(self, request, client_address):
try:
try:
request.do_handshake()
@ -135,8 +142,7 @@ class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
finally:
self.shutdown_request(request)
return
with self.connections_guard:
return super().process_request_thread(request, client_address)
return super().finish_request(request, client_address)
class ServerHandler(wsgiref.simple_server.ServerHandler):
@ -197,7 +203,7 @@ def serve(configuration):
# Create collection servers
servers = {}
if configuration.getboolean("server", "ssl"):
server_class = ThreadedHTTPSServer
server_class = ParallelHTTPSServer
server_class.certificate = configuration.get("server", "certificate")
server_class.key = configuration.get("server", "key")
server_class.certificate_authority = configuration.get(
@ -216,7 +222,7 @@ def serve(configuration):
raise RuntimeError("Failed to read SSL %s %r: %s" %
(name, filename, e)) from e
else:
server_class = ThreadedHTTPServer
server_class = ParallelHTTPServer
server_class.client_timeout = configuration.getint("server", "timeout")
server_class.max_connections = configuration.getint(
"server", "max_connections")