mirror of
https://github.com/bjc/prosody.git
synced 2025-04-03 21:27:38 +03:00
Comment from Matthew: This fixes a potential issue where the Prosody process gets blocked on sockets waiting for them to close. Unlike non-TLS sockets, closing a TLS socket sends layer 7 data, and this can cause problems for sockets which are in the process of being cleaned up. This depends on LuaSec changes which are not yet upstream. From Martijn's original email: First, my analysis of LuaSec. In ssl.c the socket is put into blocking mode right before calling SSL_shutdown() inside meth_destroy(). My best guess as to why is that meth_destroy is linked to the __close and __gc methods, which can't exactly be called multiple times, and LuaSec wants to make sure that a TLS session is shut down as cleanly as possible. I can't say I disagree with this reasoning, and I don't want to change this behaviour. My solution, which does not change the current behaviour, is to introduce a shutdown() method. I am aware that this name conflicts with TCP's shutdown method, but it stays close to the OpenSSL name. This method calls SSL_shutdown() in the current (non)blocking mode of the underlying socket and returns a boolean indicating whether or not the shutdown has completed (matching SSL_shutdown()'s 0 or 1 return values). On error it returns false for completion together with the familiar ssl_ioerror() strings. This error can then be used to determine whether we need wantread/wantwrite to finalize things. Once meth_shutdown() has been called, a shutdown flag is set. This flag tells meth_destroy() that SSL_shutdown() has been handled by the application, so it doesn't need to put the socket into blocking mode. I've left the SSL_shutdown() call in the LSEC_STATE_CONNECTED path to prevent a TOCTOU issue if the application reaches a timeout for the shutdown code; this might allow SSL_shutdown() to clean up anyway at the last possible moment. Another thing I've changed in LuaSec is the call to socket_setblocking() right before calling close(2) in socket_destroy() in usocket.c. 
According to the latest POSIX standard[0]: "Note that the requirement for close() on a socket to block for up to the current linger interval is not conditional on the O_NONBLOCK setting." I read this to mean that removing O_NONBLOCK from the socket before close doesn't change the behaviour and only causes noise in system call tracers. I didn't touch the Windows bits of this, since I don't do Windows. On the Prosody side, I've made the TLS shutdown code resemble interface:onwritable() and put it under a combined guard of self._tls and self.conn.shutdown. The self._tls check prevents getting stuck on this condition, and the self.conn.shutdown check prevents the code from being called on deployments where the patched LuaSec isn't installed. The destroy() method can be called from various places, and I read it as the "we give up" error path. To accommodate these unexpected entry points, I've added a single call to self.conn:shutdown() to prevent the socket from being put into blocking mode. I don't expect it to have any other use here. As before, the self.conn.shutdown check makes sure it's not called on unpatched LuaSec deployments, and self._tls makes sure we don't call shutdown() on TCP sockets. I wouldn't recommend logging the conn:shutdown() error inside close(), since a lot of clients simply close the connection before SSL_shutdown() is done.
1326 lines
33 KiB
Lua
1326 lines
33 KiB
Lua
-- Prosody IM
|
|
-- Copyright (C) 2016-2018 Kim Alvefur
|
|
--
|
|
-- This project is MIT/X11 licensed. Please see the
|
|
-- COPYING file in the source package for more information.
|
|
--
|
|
|
|
|
|
local setmetatable = setmetatable;
|
|
local pcall = pcall;
|
|
local type = type;
|
|
local next = next;
|
|
local pairs = pairs;
|
|
local ipairs = ipairs;
|
|
local traceback = debug.traceback;
|
|
local logger = require "prosody.util.logger";
|
|
local log = logger.init("server_epoll");
|
|
local socket = require "socket";
|
|
local realtime = require "prosody.util.time".now;
|
|
local monotonic = require "prosody.util.time".monotonic;
|
|
local indexedbheap = require "prosody.util.indexedbheap";
|
|
local createtable = require "prosody.util.table".create;
|
|
local dbuffer = require "prosody.util.dbuffer";
|
|
local inet = require "prosody.util.net";
|
|
local inet_pton = inet.pton;
|
|
local _SOCKETINVALID = socket._SOCKETINVALID or -1;
|
|
local new_id = require "prosody.util.id".short;
|
|
local xpcall = require "prosody.util.xpcall".xpcall;
|
|
local sslconfig = require "prosody.util.sslconfig";
|
|
local tls_impl = require "prosody.net.tls_luasec";
|
|
local have_signal, signal = pcall(require, "prosody.util.signal");
|
|
|
|
local poller = require "prosody.util.poll"
|
|
local EEXIST = poller.EEXIST;
|
|
local ENOENT = poller.ENOENT;
|
|
|
|
-- systemd socket activation
-- With socket activation, already-listening sockets are handed over as file
-- descriptors starting at SD_LISTEN_FDS_START; LISTEN_FDS says how many.
local SD_LISTEN_FDS_START = 3;
local SD_LISTEN_FDS = tonumber(os.getenv("LISTEN_FDS")) or 0;

-- Map of "ip:port" -> inherited LuaSocket server object.
-- Filled lazily: the first lookup of any key runs __index, which adopts every
-- passed-in FD, stores it under its bound address, removes the metatable so
-- this scan happens only once, and then answers the original lookup.
local inherited_sockets = setmetatable({}, {
	__index = function(t, k)
		-- Metatable LuaSocket registers for its listening TCP class; applied below
		-- to turn a fresh master socket into a server socket for an adopted FD.
		local serv_mt = debug.getregistry()["tcp{server}"];
		for i = 1, SD_LISTEN_FDS do
			local serv = socket.tcp();
			if serv:getfd() ~= _SOCKETINVALID then
				-- If LuaSocket already allocated an FD for this object, we can't
				-- really close it and replacing it would leak it.
				log("error", "LuaSocket not compatible with socket activation. Upgrade LuaSocket or disable socket activation.");
				setmetatable(t, nil);
				break
			end
			serv:setfd(SD_LISTEN_FDS_START + i - 1);
			debug.setmetatable(serv, serv_mt);
			serv:settimeout(0);
			local ip, port = serv:getsockname();
			t[ip .. ":" .. port] = serv;
			if ip == "0.0.0.0" then
				-- LuaSocket treats '*' as an alias for '0.0.0.0'
				t["*:" .. port] = serv;
			end
		end

		-- Disable lazy-loading mechanism once performed
		setmetatable(t, nil);
		return t[k];
	end;
});
|
|
|
|
-- The single poller instance (from prosody.util.poll) that every socket
-- is registered with via interface:add()/set()/del().
local poll = assert(poller.new());

-- From here on, global lookups fail under Lua 5.2+ (only the locals declared
-- above are reachable); Lua 5.1/LuaJIT treat this as an ordinary local.
local _ENV = nil;
-- luacheck: std none
|
|
|
|
-- Built-in defaults for all tunables; `cfg` below reads from this table.
local default_config = { __index = {
	-- If a connection is silent for this long, close it unless onreadtimeout says not to
	read_timeout = 14 * 60;

	-- How long to wait for a socket to become writable after queuing data to send
	send_timeout = 180;

	-- How long to wait for a socket to become writable after creation
	connect_timeout = 20;

	-- Listen backlog passed to socket.bind(), i.e. how many pending connections may queue
	tcp_backlog = 128;

	-- If accepting a new incoming connection fails, wait this long before trying again
	accept_retry_interval = 10;

	-- If LuaSocket still has buffered incoming data (conn:dirty()), wait this long and read again
	read_retry_delay = 1e-06;

	-- Size of chunks to read from sockets
	read_size = 8192;

	-- Maximum size of send buffer, after which additional data is rejected
	max_send_buffer_size = 32*1024*1024;

	-- How many chunks (immutable strings) to keep in the send buffer (passed to dbuffer.new)
	send_buffer_chunks = nil;

	-- Maximum amount of data to send at once (to the TCP buffers), default based on /proc/sys/net/ipv4/tcp_wmem
	max_send_chunk = 4*1024*1024;

	-- Timeout used between steps in TLS handshakes
	ssl_handshake_timeout = 60;

	-- Maximum and minimum amount of time to sleep waiting for events (adjusted for pending timers)
	-- min_wait is also the smallest rate-limit pause onreadable() bothers to schedule
	max_wait = 86400;
	min_wait = 0.001;

	-- Enable extra noisy debug logging (interface:noise())
	verbose = false;

	-- EXPERIMENTAL
	-- Whether to kill connections in case of callback errors.
	fatal_errors = false;

	-- Or disable protection (like server_select) for potential performance gains
	protect_listeners = true;

	-- Attempt writes instantly
	opportunistic_writes = false;

	-- TCP Keepalives; a number additionally sets tcp-keepidle to that many seconds
	tcp_keepalive = false; -- boolean | number

	-- Whether to let the Nagle algorithm stay enabled (false sets tcp-nodelay)
	nagle = true;

	-- Reuse write buffer tables
	keep_buffers = true;

	--- How long to wait after getting the shutdown signal before forcefully tearing down every socket
	shutdown_deadline = 5;

	-- TCP Fast Open (value passed to the "tcp-fastopen" socket option on listeners)
	tcp_fastopen = false;

	-- Defer accept until incoming data is available (only applied when set to a number)
	tcp_defer_accept = false;
}};
-- Active configuration; starts out as the defaults. NOTE(review): presumably
-- replaced elsewhere in this file by user settings inheriting from default_config.
local cfg = default_config.__index;
|
|
|
|
-- Interface objects currently registered in the poller, indexed by FD number
local fds = createtable(10, 0); -- FD -> conn

-- Timer and scheduling --

-- Pending timers, prioritized by absolute monotonic() deadline
local timers = indexedbheap.create();

-- Shared do-nothing function, used to disable methods/callbacks on closed connections
local function noop() end

-- Keep track of recently closed timers to avoid re-adding them
local closedtimers = {};
|
|
|
|
-- Cancel a pending timer. When the heap actually held it, its id is noted in
-- closedtimers, which runtimers() consults before re-arming a timer.
local function closetimer(id)
	local removed = timers:remove(id);
	if not removed then
		return;
	end
	closedtimers[id] = true;
end
|
|
|
|
-- Move an existing timer so that it fires `time` seconds from now.
local function reschedule(id, time)
	local deadline = monotonic() + time;
	timers:reprioritize(id, deadline);
end
|
|
|
|
-- Add relative timer: run `f` `timeout` seconds from now.
-- f is called as f(now, timer_id), plus `param` as a third argument when it
-- is non-nil. Returning a number from f re-arms it (see runtimers).
-- Returns the id of the new timer.
local function addtimer(timeout, f, param)
	local deadline = monotonic() + timeout;
	local callback = f;
	if param ~= nil then
		callback = function(current_time, timer_id)
			return (f(current_time, timer_id, param));
		end
	end
	local timer_id = timers:insert(callback, deadline);
	return timer_id;
end
|
|
|
|
-- Run callbacks of expired timers
-- Return time until next timeout
--
-- Pops every timer whose deadline (a monotonic() timestamp) has passed and
-- calls it as timer(now, id), where `now` is realtime() sampled once per run.
-- Callbacks run under xpcall: errors are logged and don't stop the loop.
-- A callback returning a number is re-armed that many seconds after the start
-- of this run (keeping its id), unless its id is listed in closedtimers.
-- Re-armed timers are inserted only after the loop so they can't run twice
-- in one tick. closedtimers is cleared at the end of every run.
--
-- next_delay: returned unchanged when no timers remain.
-- min_wait:   lower bound for the delay returned when timers remain.
local function runtimers(next_delay, min_wait)
	-- Any timers at all?
	local elapsed = monotonic();
	local now = realtime();
	local peek = timers:peek(); -- deadline of the earliest timer, or nil
	local readd;
	while peek do

		if peek > elapsed then
			break;
		end

		local _, timer, id = timers:pop();
		local ok, ret = xpcall(timer, traceback, now, id);
		if ok and type(ret) == "number" and not closedtimers[id] then
			local next_time = elapsed+ret;
			-- Delay insertion of timers to be re-added
			-- so they don't get called again this tick
			if readd then
				readd[id] = { timer, next_time };
			else
				readd = { [id] = { timer, next_time } };
			end
		elseif not ok then
			log("error", "Error in timer: %s", ret);
		end

		peek = timers:peek();
	end

	if readd then
		-- Re-insert under the same id so holders of the id (e.g. _readtimeout) stay valid
		for id, timer in pairs(readd) do
			timers:insert(timer[1], timer[2], id);
		end
		peek = timers:peek();
	end

	if next(closedtimers) ~= nil then
		closedtimers = {};
	end

	if peek == nil then
		return next_delay;
	else
		next_delay = peek - elapsed;
	end

	if next_delay < min_wait then
		return min_wait;
	end
	return next_delay;
end
|
|
|
|
-- Socket handler interface
|
|
|
|
-- Methods shared by connection and server objects. Instances are plain tables
-- with interface_mt as metatable, so they can override methods per instance
-- (e.g. onreadable/onwritable during TLS setup or close).
local interface = {};
local interface_mt = { __index = interface };
|
|
|
|
-- Human-readable description for log messages:
-- "FD n (peer, peerport, local, localport)" when both ends are known,
-- "FD n (addr, port)" when only one is, or just "FD n".
function interface_mt:__tostring()
	local fd = self:getfd();
	local sockname, peername = self.sockname, self.peername;
	if sockname and peername then
		return ("FD %d (%s, %d, %s, %d)"):format(fd, peername, self.peerport, sockname, self.sockport);
	end
	if sockname or peername then
		return ("FD %d (%s, %d)"):format(fd, sockname or peername, self.sockport or self.peerport);
	end
	return ("FD %d"):format(fd);
end
|
|
|
|
-- Default logger; wrapsocket() and wrapserver() give each object its own.
interface.log = log;

-- Log a debug-level message through this object's logger.
function interface:debug(msg, ...)
	local logfn = self.log;
	logfn("debug", msg, ...);
end
|
|
|
|
-- Extra-verbose debug logging, only emitted when cfg.verbose is enabled.
-- (A former `interface.noise = interface.debug` assignment here was dead code,
-- immediately overwritten by this definition.)
function interface:noise(msg, ...)
	if cfg.verbose then
		return self:debug(msg, ...);
	end
end
|
|
|
|
-- Log an error-level message through this object's logger.
function interface:error(msg, ...)
	local logfn = self.log;
	logfn("error", msg, ...);
end
|
|
|
|
-- Replace the listener table: the outgoing listeners get "ondetach", then the
-- new ones get "onattach" with `data`.
function interface:setlistener(listeners, data)
	self:on("detach"); -- still dispatched to the old listeners
	self.listeners = listeners;
	self:on("attach", data); -- dispatched to the new listeners
end
|
|
|
|
-- Call a listener callback: dispatch event `what` to listeners["on"..what]
-- with (self, ...). Returns the callback's result. With cfg.protect_listeners,
-- errors are caught and logged (through listeners.onerror or a traceback
-- handler); the connection is destroyed if cfg.fatal_errors is set, and
-- nil plus the error is returned.
function interface:on(what, ...)
	local listeners = self.listeners;
	if not listeners then
		self:error("Interface is missing listener callbacks");
		return;
	end
	local callback = listeners["on"..what];
	if not callback then
		self:noise("Missing listener 'on%s'", what); -- only logged when cfg.verbose
		return;
	end
	if not cfg.protect_listeners then
		return callback(self, ...);
	end
	local ok, result = xpcall(callback, listeners.onerror or traceback, self, ...);
	if ok then
		return result;
	end
	if cfg.fatal_errors then
		self:error("Closing due to error calling on%s: %s", what, result);
		self:destroy();
	else
		self:error("Error calling on%s: %s", what, result);
	end
	return nil, result;
end
|
|
|
|
-- Hand received data to the "onincoming" listener. Kept as a separate method
-- so it can be overridden per object.
function interface:onincoming(...)
	return self:on("incoming", ...);
end
|
|
|
|
-- Return the file descriptor number, or _SOCKETINVALID once the socket is gone.
function interface:getfd()
	local sock = self.conn;
	if not sock then
		return _SOCKETINVALID;
	end
	return sock:getfd();
end
|
|
|
|
-- The server object this connection was accepted from, or the object itself
-- when there is none (outgoing connections and servers).
function interface:server()
	local parent = self._server;
	if parent then
		return parent;
	end
	return self;
end
|
|
|
|
-- Get IP address: the peer's if known, otherwise the local one.
function interface:ip()
	local addr = self.peername or self.sockname;
	return addr;
end
|
|
|
|
-- Get a port number, doesn't matter which (peer's preferred, else local).
function interface:port()
	local portnum = self.peerport or self.sockport;
	return portnum;
end
|
|
|
|
-- Client-side port (usually a random high port): our local port for
-- outgoing connections, the peer's port for accepted ones.
function interface:clientport()
	if not self._server then
		return self.sockport;
	end
	return self.peerport;
end
|
|
|
|
-- Get port on the server: our local port for accepted connections, the
-- peer's port for outgoing ones.
function interface:serverport()
	if not self._server then
		return self.peerport;
	end
	return self.sockport;
end
|
|
|
|
-- Return underlying socket (the TLS-wrapped object after inittls(), nil after destroy()).
function interface:socket()
	local sock = self.conn;
	return sock;
end
|
|
|
|
-- Set the read pattern/size passed to conn:receive() by onreadable()
-- (overrides cfg.read_size for this object).
function interface:set_mode(new_mode)
	self.read_size = new_mode;
end
|
|
|
|
-- Set socket option k to v on the underlying socket. Failures, including the
-- method being absent or raising, are logged as noise and returned as
-- false/nil plus a message; success returns the socket's result.
function interface:setoption(k, v)
	local sock = self.conn;
	-- LuaSec doesn't expose setoption :(
	local called, result, err = pcall(sock.setoption, sock, k, v);
	if not called then
		self:noise("Setting option %q = %q failed: %s", k, v, result);
		return called, result;
	end
	if not result then
		self:noise("Setting option %q = %q failed: %s", k, v, err);
		return result, err;
	end
	return result;
end
|
|
|
|
-- Timeout for detecting dead or idle sockets.
-- setreadtimeout(false) cancels it; setreadtimeout(t) (re)arms it to fire in
-- t seconds (default cfg.read_timeout). On expiry, a truthy result from the
-- "onreadtimeout" listener keeps the connection and re-arms the timer for
-- cfg.read_timeout; otherwise it disconnects with "read timeout".
function interface:setreadtimeout(t)
	if t == false then
		local pending = self._readtimeout;
		if pending then
			closetimer(pending);
			self._readtimeout = nil;
		end
		return
	end
	local delay = t or cfg.read_timeout;
	local existing = self._readtimeout;
	if existing then
		reschedule(existing, delay);
		return
	end
	self._readtimeout = addtimer(delay, function ()
		if not self:on("readtimeout") then
			self:debug("Read timeout not handled, disconnecting");
			self:on("disconnect", "read timeout");
			self:destroy();
			return
		end
		self:noise("Read timeout handled");
		return cfg.read_timeout;
	end);
end
|
|
|
|
-- Timeout for detecting dead sockets (pending writes, connects, handshakes).
-- setwritetimeout(false) cancels it; setwritetimeout(t) (re)arms it to fire in
-- t seconds (default cfg.send_timeout). On expiry the connection is always
-- disconnected ("write timeout", or "connection timeout" if never connected).
function interface:setwritetimeout(t)
	if t == false then
		local pending = self._writetimeout;
		if pending then
			closetimer(pending);
			self._writetimeout = nil;
		end
		return
	end
	local delay = t or cfg.send_timeout;
	local existing = self._writetimeout;
	if existing then
		reschedule(existing, delay);
		return
	end
	self._writetimeout = addtimer(delay, function ()
		self:noise("Write timeout");
		local reason = self._connected and "write timeout" or "connection timeout";
		self:on("disconnect", reason);
		self:destroy();
	end);
end
|
|
|
|
-- Register the socket with the poller for read (r) and/or write (w) events;
-- nil keeps the currently wanted flag. If the FD is already registered
-- (EEXIST), fall back to updating its flags with set().
function interface:add(r, w)
	local fd = self:getfd();
	if fd < 0 then
		return nil, "invalid fd";
	end
	local want_read, want_write = r, w;
	if want_read == nil then want_read = self._wantread; end
	if want_write == nil then want_write = self._wantwrite; end
	local ok, err, errno = poll:add(fd, want_read, want_write);
	if ok then
		self._wantread, self._wantwrite = want_read, want_write;
		fds[fd] = self;
		self:noise("Registered in poller");
		return true;
	end
	if errno == EEXIST then
		self:debug("FD already registered in poller! (EEXIST)");
		return self:set(want_read, want_write); -- So try to change its flags
	end
	self:debug("Could not register in poller: %s(%d)", err, errno);
	return ok, err;
end
|
|
|
|
-- Update the read/write interest flags of a registered socket. nil keeps the
-- current value; the poller is only called when something actually changes.
function interface:set(r, w)
	local fd = self:getfd();
	if fd < 0 then
		return nil, "invalid fd";
	end
	local want_read, want_write = r, w;
	if want_read == nil then want_read = self._wantread; end
	if want_write == nil then want_write = self._wantwrite; end
	local unchanged = want_read == self._wantread and want_write == self._wantwrite;
	if unchanged then
		return true
	end
	local ok, err, errno = poll:set(fd, want_read, want_write);
	if ok then
		self._wantread, self._wantwrite = want_read, want_write;
		return true;
	end
	self:debug("Could not update poller state: %s(%d)", err, errno);
	return ok, err;
end
|
|
|
|
-- Remove the socket from the poller and forget its wanted flags. An FD the
-- poller no longer knows (ENOENT) still counts as successfully removed.
function interface:del()
	local fd = self:getfd();
	if fd < 0 then
		return nil, "invalid fd";
	end
	if fds[fd] ~= self then
		return nil, "unregistered fd";
	end
	local ok, err, errno = poll:del(fd);
	local failed = not ok and errno ~= ENOENT;
	if failed then
		self:debug("Could not unregister: %s(%d)", err, errno);
		return ok, err;
	end
	self._wantread, self._wantwrite = nil, nil;
	fds[fd] = nil;
	self:noise("Unregistered from poller");
	return true;
end
|
|
|
|
-- Set read/write interest, adding to or removing from the poller as needed:
-- unregistered + nothing wanted is a no-op, unregistered + something wanted
-- registers, registered + nothing wanted unregisters, otherwise update flags.
function interface:setflags(r, w)
	local registered = self._wantread or self._wantwrite;
	local wanted = r or w;
	if registered then
		if wanted then
			return self:set(r, w);
		end
		return self:del();
	end
	if wanted then
		return self:add(r, w);
	end
	return true; -- no change
end
|
|
|
|
-- Called when socket is readable.
-- Reads up to self.read_size (default cfg.read_size) and passes data,
-- including any partial read, to onincoming(). "wantread"/"wantwrite" from a
-- TLS socket re-arm the matching poller flag and are treated like "timeout";
-- "closed" or any other error disconnects and destroys the connection.
-- Afterwards applies the optional rate limit (self._limit, seconds per byte),
-- then either schedules another read if data is still buffered inside the
-- socket object (conn:dirty()) or re-arms the read timeout.
function interface:onreadable()
	local data, err, partial = self.conn:receive(self.read_size or cfg.read_size);
	if data then
		self:onconnect();
		self:onincoming(data);
	else
		if err == "wantread" then
			self:set(true, nil);
			err = "timeout";
		elseif err == "wantwrite" then
			self:set(nil, true);
			self:setwritetimeout();
			err = "timeout";
		elseif err == "timeout" and not self._connected then
			err = "connection timeout";
		end
		if partial and partial ~= "" then
			self:onconnect();
			self:onincoming(partial, err);
		end
		if err == "closed" and self._connected then
			self:debug("Connection closed by remote");
			self:on("disconnect", err);
			self:destroy();
			return;
		elseif err ~= "timeout" then
			self:debug("Read error, closing (%s)", err);
			self:on("disconnect", err);
			self:destroy();
			return;
		end
	end
	-- Listener callbacks above may have closed the connection
	if not self.conn then return; end
	if self._limit and (data or partial) then
		-- Rate limit: pause reading for (bytes read * seconds per byte)
		local cost = self._limit * #(data or partial);
		if cost > cfg.min_wait then
			self:setreadtimeout(false);
			self:pausefor(cost);
			return;
		end
	end
	-- Reading may have been paused meanwhile (e.g. via pause())
	if not self._wantread then return end
	if self.conn:dirty() then
		-- More data is already buffered; come back for it shortly rather than
		-- waiting for the poller, which won't signal buffered data
		self:setreadtimeout(false);
		self:pausefor(cfg.read_retry_delay);
	else
		self:setreadtimeout();
	end
end
|
|
|
|
-- Called when socket is writable.
-- Sends up to cfg.max_send_chunk bytes of the write buffer (a plain string or
-- a dbuffer). On a complete flush the buffer is released (or emptied and kept
-- when cfg.keep_buffers), write interest and the write timeout are cleared
-- and ondrain() runs. On a partial send the unsent remainder is kept and
-- write interest re-armed. "wantread"/"wantwrite" (TLS) re-arm the matching
-- poller flag; any other error disconnects and destroys the connection.
-- Returns the send result, or true plus the soft error while the connection
-- stays open; nil plus an error once it has been torn down.
function interface:onwritable()
	self._writing = true; -- prevent reentrant writes etc
	self:onconnect();
	if not self.conn then return nil, "no-conn"; end -- could have been closed in onconnect
	self:on("predrain");
	if not self.conn then return nil, "no-conn"; end -- or by a failing predrain listener (cfg.fatal_errors)
	local buffer = self.writebuffer or "";
	-- Naming things ... s/data/slice/ ?
	local data = buffer:sub(1, cfg.max_send_chunk);
	local ok, err, partial = self.conn:send(data);
	self._writable = ok;
	if ok and #data < #buffer then
		-- Sent the whole 'data' but there's more in the buffer
		ok, err, partial = nil, "timeout", ok;
	end
	self:debug("Sent %d out of %d buffered bytes", ok and #data or partial or 0, #buffer);
	if ok then -- all the data we had was sent successfully
		self:set(nil, false);
		if cfg.keep_buffers and type(buffer) == "table" then
			buffer:discard(ok);
		else -- string or don't keep buffers
			self.writebuffer = nil;
		end
		self._writing = nil;
		self:setwritetimeout(false);
		self:ondrain(); -- Be aware of writes in ondrain
		return ok;
	elseif partial then
		if type(buffer) == "table" then
			buffer:discard(partial);
		else
			-- 'data' may be only the first max_send_chunk bytes of the string
			-- buffer: keep everything after the sent bytes of the *whole* buffer,
			-- otherwise the tail beyond the chunk would be silently dropped.
			self.writebuffer = buffer:sub(partial + 1);
		end
		self:set(nil, true);
		self:setwritetimeout();
	end
	self._writing = nil;
	if err == "wantwrite" or err == "timeout" then
		self:set(nil, true);
		self:setwritetimeout();
	elseif err == "wantread" then
		self:set(true, nil);
		self:setreadtimeout();
	else
		self:on("disconnect", err);
		self:destroy();
		return ok, err;
	end
	return true, err;
end
|
|
|
|
-- The write buffer has been successfully emptied: notify the "ondrain"
-- listener. close() and starttls() may override this per object to run
-- once pending data has been written.
function interface:ondrain()
	return self:on("drain");
end
|
|
|
|
-- Append `data` to self.writebuffer, which must be a dbuffer on entry.
-- If it doesn't fit and writes aren't paused, flush once to make room and
-- retry; if it still doesn't fit, disconnect and destroy the connection.
-- Returns true when the data was queued, false otherwise.
local function append_to_writebuffer(self, data)
	if self.writebuffer:write(data) then
		return true;
	end
	if self._write_lock then
		return false;
	end
	-- Try to flush buffer to make room
	self:onwritable();
	if not self.conn then
		return false; -- the flush attempt tore the connection down
	end
	local buffer = self.writebuffer;
	if type(buffer) ~= "table" then
		-- A complete flush with cfg.keep_buffers disabled releases the buffer
		-- (and a drain listener may then have queued a plain string). Continue
		-- with a fresh dbuffer rather than the detached one, whose contents
		-- would never be sent.
		local pending = buffer;
		buffer = dbuffer.new(cfg.max_send_buffer_size, cfg.send_buffer_chunks);
		self.writebuffer = buffer;
		if pending and not buffer:write(pending) then
			buffer = nil;
		end
	end
	if buffer and buffer:write(data) then
		return true;
	end
	self:on("disconnect", "no space left in buffer");
	self:destroy();
	return false;
end

-- Add data to write buffer and set flag for wanting to write.
-- Buffer representation: (nil) -> save string; (string) -> convert to a
-- dbuffer; (dbuffer) -> append. Unless writes are paused or a flush is in
-- progress, either flushes right away (cfg.opportunistic_writes, returning
-- onwritable()'s result) or arms write interest and the send timeout.
-- Returns #data on success, false if the data could not be queued.
function interface:write(data)
	local buffer = self.writebuffer;
	if not buffer then
		self.writebuffer = data;
	else
		if type(buffer) == "string" then
			self.writebuffer = dbuffer.new(cfg.max_send_buffer_size, cfg.send_buffer_chunks);
			if not append_to_writebuffer(self, buffer) then
				return false;
			end
		end
		if not append_to_writebuffer(self, data) then
			return false;
		end
	end
	if not self._write_lock and not self._writing then
		if self._writable and cfg.opportunistic_writes and not self._opportunistic_write then
			self._opportunistic_write = true;
			local ret, err = self:onwritable();
			self._opportunistic_write = nil;
			return ret, err;
		end
		self:setwritetimeout();
		self:set(nil, true);
	end
	return #data;
end
interface.send = interface.write;
|
|
|
|
-- Close, possibly after writing is done.
-- 1. If data is still buffered: stop reading, refuse further writes, flush,
--    and call close() again from ondrain.
-- 2. Otherwise, for TLS connections on a LuaSec providing conn:shutdown():
--    run a non-blocking TLS shutdown, re-entering close() from the
--    readable/writable callbacks until it completes or fails.
-- 3. Finally fire "ondisconnect" and destroy the connection.
function interface:close()
	local status, err;
	if self.writebuffer and #self.writebuffer ~= 0 then
		self._connected = false;
		self:set(false, true); -- Flush final buffer contents
		self:setreadtimeout(false);
		self:setwritetimeout();
		self.write, self.send = noop, noop; -- No more writing
		self:debug("Close after writing remaining buffered data");
		self.ondrain = interface.close;
	elseif self.conn.shutdown and self._tls then
		-- No more writing here either: onwritable is replaced below, so newly
		-- buffered data would never be flushed and close() would just keep
		-- re-arming write interest until the write timeout.
		self.write, self.send = noop, noop;
		status, err = self.conn:shutdown();
		-- Continue the shutdown on the next socket event
		self.onreadable = interface.close;
		self.onwritable = interface.close;
		if err == nil then
			if status == true then
				self._tls = false; -- shutdown complete; next pass closes the socket
			end
			return self:close();
		elseif err == "wantread" then
			self:set(true, nil);
			self:setreadtimeout();
		elseif err == "wantwrite" then
			self:set(nil, true);
			self:setwritetimeout();
		else
			-- Shutdown failed (often the peer simply closed first): give up on a
			-- clean TLS close and tear down now instead of waiting for another event.
			self._tls = false;
			return self:close();
		end
	else
		self:debug("Closing now");
		self.write, self.send = noop, noop;
		self.close = noop;
		self:on("disconnect");
		self:destroy();
	end
end
|
|
|
|
-- Tear the connection down immediately: unregister from the poller, cancel
-- timers, turn every callback/method that could touch the socket into a
-- no-op, and close the socket. Subsequent calls do nothing.
function interface:destroy()
	local sock = self.conn;
	-- make sure tls sockets aren't put in blocking mode
	-- (an explicit shutdown() lets a patched LuaSec skip its blocking
	-- SSL_shutdown() when the socket is closed)
	if sock.shutdown and self._tls then
		sock:shutdown();
	end
	self:del();
	self:setwritetimeout(false);
	self:setreadtimeout(false);
	self.onreadable, self.onwritable = noop, noop;
	self.destroy, self.close, self.on = noop, noop, noop;
	sock:close();
	self.conn = nil;
end
|
|
|
|
-- Whether TLS is active on this connection (the _tls flag set by inittls()).
function interface:ssl()
	local tls = self._tls;
	return tls;
end
|
|
|
|
-- Set the TLS context used by later starttls()/inittls() calls.
-- Those methods and sslctx() all use self.tls_ctx, so store it there; the
-- old _sslctx field is still set for compatibility (NOTE(review): it appears
-- unused in this file).
function interface:set_sslctx(sslctx)
	self._sslctx = sslctx;
	self.tls_ctx = sslctx;
end
|
|
|
|
-- The TLS context used (or to be used) by starttls()/inittls().
function interface:sslctx()
	local ctx = self.tls_ctx;
	return ctx
end
|
|
|
|
-- TLS session info from the socket's info() method, if available.
function interface:ssl_info()
	local sock = self.conn;
	if not sock then
		return nil, "not-connected"
	elseif not sock.info then
		return nil, "not-implemented";
	end
	return sock:info();
end
|
|
|
|
-- The peer's certificate from the socket's getpeercertificate(), if available.
function interface:ssl_peercertificate()
	local sock = self.conn;
	if not sock then
		return nil, "not-connected"
	elseif not sock.getpeercertificate then
		return nil, "not-implemented";
	end
	return sock:getpeercertificate();
end
|
|
|
|
-- Peer certificate verification result from the socket's
-- getpeerverification(); without support, reports nil plus a nested error table.
function interface:ssl_peerverification()
	local sock = self.conn;
	if not sock then
		return nil, "not-connected"
	elseif not sock.getpeerverification then
		return nil, { { "Chain verification not supported" } };
	end
	return sock:getpeerverification();
end
|
|
|
|
-- The peer's TLS Finished message from the socket's getpeerfinished(), if available.
function interface:ssl_peerfinished()
	local sock = self.conn;
	if not sock then
		return nil, "not-connected"
	elseif not sock.getpeerfinished then
		return nil, "not-implemented";
	end
	return sock:getpeerfinished();
end
|
|
|
|
-- Export TLS keying material via the socket's exportkeyingmaterial().
-- Like the other ssl_* accessors, reports nil, "not-implemented" when the
-- socket lacks the method (previously this returned nothing at all).
function interface:ssl_exportkeyingmaterial(label, len, context)
	local sock = self.conn;
	if not sock then return nil, "not-connected" end
	if not sock.exportkeyingmaterial then return nil, "not-implemented"; end
	return sock:exportkeyingmaterial(label, len, context);
end
|
|
|
|
|
|
-- Upgrade the connection to TLS, optionally with a new context `tls_ctx`.
-- If data is still buffered, defer until it has been written (ondrain calls
-- starttls again); otherwise route the next readable/writable event to
-- inittls() and arm the handshake timeout.
function interface:starttls(tls_ctx)
	if tls_ctx then self.tls_ctx = tls_ctx; end
	-- Replace the method on this instance so it can't be invoked again
	-- (NOTE(review): presumably also lets callers see STARTTLS is no longer available)
	self.starttls = false;
	if self.writebuffer and #self.writebuffer ~= 0 then
		self:debug("Start TLS after write");
		self.ondrain = interface.starttls;
		self:set(nil, true); -- make sure wantwrite is set
	else
		if self.ondrain == interface.starttls then
			self.ondrain = nil;
		end
		self.onwritable = interface.inittls;
		self.onreadable = interface.inittls;
		self:set(true, true);
		self:setreadtimeout(false);
		self:setwritetimeout(cfg.ssl_handshake_timeout);
		self:debug("Prepared to start TLS");
	end
end
|
|
|
|
-- Wrap the plain socket in TLS using tls_ctx (or the stored self.tls_ctx)
-- and begin the handshake. With `now`, the first handshake step runs
-- immediately; otherwise it starts on the next socket event. Does nothing if
-- TLS was already started. On wrap failure the connection is destroyed and
-- nil plus the error is returned.
function interface:inittls(tls_ctx, now)
	if self._tls then return end
	if tls_ctx then self.tls_ctx = tls_ctx; end
	self._tls = true;
	self.starttls = false;
	self:debug("Starting TLS now");
	self:updatenames(); -- Can't getpeer/sockname after wrap()
	local conn, err = self.tls_ctx:wrap(self.conn);
	if not conn then
		self:on("disconnect", err);
		self:destroy();
		return conn, err;
	end
	conn:settimeout(0);
	self.conn = conn;
	-- Server Name Indication: send self.servername if set (presumably an
	-- outgoing connection), otherwise pass the context's per-name contexts
	if conn.sni then
		if self.servername then
			conn:sni(self.servername);
		elseif next(self.tls_ctx._sni_contexts) ~= nil then
			conn:sni(self.tls_ctx._sni_contexts, true);
		end
	end
	-- DANE: load TLSA records supplied in `extra` when the TLS library supports it
	-- NOTE(review): only settlsa is checked, setdane is assumed to exist alongside it
	if self.extra and self.extra.tlsa and conn.settlsa then
		-- TODO Error handling
		if not conn:setdane(self.servername or self.extra.dane_hostname) then
			self:debug("Could not enable DANE on connection");
		else
			self:debug("Enabling DANE with %d TLSA records", #self.extra.tlsa);
			self:noise("DANE hostname is %q", self.servername or self.extra.dane_hostname);
			for _, tlsa in ipairs(self.extra.tlsa) do
				self:noise("TLSA: %q", tlsa);
				conn:settlsa(tlsa.use, tlsa.select, tlsa.match, tlsa.data);
			end
		end
	end
	self:on("starttls");
	self.ondrain = nil;
	-- Route socket events to the handshake until it completes
	self.onwritable = interface.tlshandshake;
	self.onreadable = interface.tlshandshake;
	if now then
		return self:tlshandshake()
	end
	self:setreadtimeout(false);
	self:setwritetimeout(cfg.ssl_handshake_timeout);
	self:set(true, true);
end
|
|
|
|
-- Advance the TLS handshake by one step (installed as onreadable/onwritable
-- by inittls()). On completion, drop the per-instance handlers so the default
-- interface methods apply again, report "ssl-handshake-complete" and proceed
-- as a freshly connected socket. "wantread"/"wantwrite" wait for that event
-- under cfg.ssl_handshake_timeout; any other error disconnects.
function interface:tlshandshake()
	self:setreadtimeout(false);
	self:noise("Continuing TLS handshake");
	local ok, err = self.conn:dohandshake();
	if ok then
		local info = self.conn.info and self.conn:info();
		if type(info) == "table" then
			self:debug("TLS handshake complete (%s with %s)", info.protocol, info.cipher);
		else
			self:debug("TLS handshake complete");
		end
		self:setwritetimeout(false);
		-- Remove instance overrides; interface.onwritable/onreadable take over
		self.onwritable = nil;
		self.onreadable = nil;
		self:on("status", "ssl-handshake-complete");
		self:set(true, true);
		self:onconnect();
		self:onreadable();
	elseif err == "wantread" then
		self:noise("TLS handshake to wait until readable");
		self:set(true, false);
		self:setwritetimeout(cfg.ssl_handshake_timeout);
	elseif err == "wantwrite" then
		self:noise("TLS handshake to wait until writable");
		self:set(false, true);
		self:setwritetimeout(cfg.ssl_handshake_timeout);
	else
		self:debug("TLS handshake error: %s", err);
		self:on("disconnect", err);
		self:destroy();
	end
end
|
|
|
|
-- Wrap a connected socket in a new interface object (made non-blocking).
-- read_size and tls_ctx fall back to the accepting server's settings, and
-- extra.servername (if given) becomes the name used for SNI.
local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra) -- luasocket object -> interface object
	client:settimeout(0);
	local conn_id = ("conn%s"):format(new_id());
	local conn = setmetatable({
		conn = client;
		_server = server;
		created = realtime();
		listeners = listeners;
		read_size = read_size or (server and server.read_size);
		writebuffer = nil;
		tls_ctx = tls_ctx or (server and server.tls_ctx);
		tls_direct = server and server.tls_direct;
		id = conn_id;
		log = logger.init(conn_id);
		extra = extra;
	}, interface_mt);
	local servername = extra and extra.servername;
	if servername then
		conn.servername = servername;
	end
	return conn;
end
|
|
|
|
-- Cache peer and local address/port (port defaults to 0). Lookups run under
-- pcall, and fields are left untouched when a lookup fails or is unsupported.
function interface:updatenames()
	local sock = self.conn;
	local got_peer, peer_ip, peer_port = pcall(sock.getpeername, sock);
	if got_peer and peer_ip then
		self.peername, self.peerport = peer_ip, peer_port or 0;
	end
	local got_local, local_ip, local_port = pcall(sock.getsockname, sock);
	if got_local and local_ip then
		self.sockname, self.sockport = local_ip, local_port or 0;
	end
end
|
|
|
|
-- A server interface has new incoming connections waiting
-- This replaces the onreadable callback (wrapserver() installs it).
-- Accepts one client, wraps it with the server's listeners, then either
-- starts TLS right away (direct-TLS listeners) or marks it connected and
-- attempts a first read. If accept() fails, the server is paused for
-- cfg.accept_retry_interval seconds.
function interface:onacceptable()
	local conn, err = self.conn:accept();
	if not conn then
		self:debug("Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
		self:pausefor(cfg.accept_retry_interval);
		return;
	end
	local client = wrapsocket(conn, self, nil, self.listeners);
	client:updatenames();
	client:debug("New connection %s on server %s", client, self);
	client:defaultoptions();
	-- With opportunistic writes enabled, the new socket is assumed writable
	client._writable = cfg.opportunistic_writes;
	if self.tls_direct then
		client:add(true, true);
		client:inittls(self.tls_ctx, true);
	else
		client:add(true, false);
		client:onconnect();
		-- NOTE(review): presumably picks up data that may already be waiting
		client:onreadable();
	end
end
|
|
|
|
-- Initialization for outgoing connections: arm the connect timeout, apply
-- default socket options and register for read and write events
-- (onwritable() calls onconnect() once the socket becomes writable).
function interface:init()
	self:setwritetimeout(cfg.connect_timeout);
	self:defaultoptions();
	local registered, err = self:add(true, true);
	return registered, err;
end
|
|
|
|
-- Apply configured TCP options: disable Nagle when cfg.nagle is false, and
-- enable keepalives (with tcp-keepidle when cfg.tcp_keepalive is a number).
function interface:defaultoptions()
	if cfg.nagle == false then
		self:setoption("tcp-nodelay", true);
	end
	local keepalive = cfg.tcp_keepalive;
	if not keepalive then
		return;
	end
	self:setoption("keepalive", true);
	if type(keepalive) == "number" then
		self:setoption("tcp-keepidle", keepalive);
	end
end
|
|
|
|
-- Stop reading (and cancel the read timeout) until resume() is called.
function interface:pause()
	self:noise("Pause reading");
	self:setreadtimeout(false);
	local ok, err = self:set(false);
	return ok, err;
end
|
|
|
|
-- Resume reading after pause(), re-arming the read timeout.
function interface:resume()
	self:noise("Resume reading");
	self:setreadtimeout();
	local ok, err = self:set(true);
	return ok, err;
end
|
|
|
|
-- Pause connection for some time: stop reading for t seconds, replacing any
-- pause already pending; pausefor(false) just cancels a pending pause.
-- On resume, any data already buffered in the socket is processed at once.
function interface:pausefor(t)
	self:noise("Pause for %fs", t);
	local pending = self._pausefor;
	if pending then
		closetimer(pending);
		self._pausefor = nil;
	end
	if t == false then return; end
	self:set(false);
	self._pausefor = addtimer(t, function ()
		self._pausefor = nil;
		self:set(true);
		self:noise("Resuming after pause");
		local sock = self.conn;
		if sock and sock:dirty() then
			self:noise("Have buffered incoming data to process");
			self:onreadable();
		end
	end);
end
|
|
|
|
-- Limit the receive rate to roughly Bps bytes per second; onreadable()
-- pauses reading for (bytes read / Bps) seconds after each read.
-- A non-positive or missing rate removes the limit (nil used to raise an error).
function interface:setlimit(Bps)
	if Bps and Bps > 0 then
		self._limit = 1/Bps;
	else
		self._limit = nil;
	end
end
|
|
|
|
--- Stop watching this connection for writability until resume_writes().
-- Idempotent: does nothing when writes are already paused. The write
-- timeout is disabled while the lock is held.
function interface:pause_writes()
	if not self._write_lock then
		self:noise("Pause writes");
		self._write_lock = true;
		self:setwritetimeout(false);
		self:set(nil, false);
	end
end
|
|
|
|
--- Undo pause_writes().
-- Does nothing unless writes are currently paused. Write interest (and the
-- write timeout) is only re-enabled when there is pending data in
-- `self.writebuffer`.
function interface:resume_writes()
	if not self._write_lock then return; end
	self:noise("Resume writes");
	self._write_lock = nil;
	local pending = self.writebuffer;
	if not pending or #pending == 0 then return; end
	self:setwritetimeout();
	self:set(nil, true);
end
|
|
|
|
-- Called once the connection has been established
|
|
--- Mark the connection as established and fire the "connect" event.
-- The instance field `onconnect` is replaced with `noop` before the event
-- fires, so further onconnect calls on this object do nothing.
function interface:onconnect()
	self._connected = true;
	self:updatenames();
	self:debug("Connected (%s)", self);
	-- Shadow the method on the instance: only the first call does any work.
	self.onconnect = noop;
	self:on("connect");
end
|
|
|
|
--- Wrap a listening socket in a server object and register it for reads.
-- Incoming readiness is handled by `interface.onacceptable`. The optional
-- `config` table may carry `read_size`, `tls_ctx`, `tls_direct` and
-- `sni_hosts`. Global `cfg.tcp_fastopen` / `cfg.tcp_defer_accept` are
-- applied as socket options.
-- @return the server object
local function wrapserver(conn, addr, port, listeners, config)
	-- Same as `config and config.<key>`: tolerates a missing config table.
	local function opt(key)
		return config and config[key];
	end

	local server = setmetatable({
		conn = conn;
		created = realtime();
		listeners = listeners;
		read_size = opt("read_size");
		onreadable = interface.onacceptable;
		tls_ctx = opt("tls_ctx");
		tls_direct = opt("tls_direct");
		hosts = opt("sni_hosts");
		sockname = addr;
		sockport = port;
		log = logger.init(("serv%s"):format(new_id()));
	}, interface_mt);
	server:debug("Server %s created", server);

	local fastopen = cfg.tcp_fastopen;
	if fastopen then
		server:setoption("tcp-fastopen", fastopen);
	end

	local defer_accept = cfg.tcp_defer_accept;
	if type(defer_accept) == "number" then
		server:setoption("tcp-defer-accept", defer_accept);
	end

	server:add(true, false);
	return server;
end
|
|
|
|
--- Start listening on `addr`:`port`.
-- A socket previously inherited for this address (see `inherited_sockets`)
-- is reused; otherwise a new one is bound with `cfg.tcp_backlog` and made
-- non-blocking.
-- @return the server object, or nil plus an error from socket.bind
local function listen(addr, port, listeners, config)
	local inherited = inherited_sockets[addr .. ":" .. port];
	if not inherited then
		local sock, err = socket.bind(addr, port, cfg.tcp_backlog);
		if not sock then return sock, err; end
		sock:settimeout(0);
		return wrapserver(sock, addr, port, listeners, config);
	end

	local server = wrapserver(inherited, addr, port, listeners, config);
	-- Sockets created by systemd must not be :close()d, since we may lack
	-- the privileges to create them again; destroying only unregisters them.
	server.destroy = interface.del;
	return server;
end
|
|
|
|
-- COMPAT
|
|
--- Legacy positional-argument wrapper around listen().
-- Supplying a TLS context implies direct TLS (`tls_direct = true`).
local function addserver(addr, port, listeners, read_size, tls_ctx)
	local config = {
		read_size = read_size;
		tls_ctx = tls_ctx;
		tls_direct = not not tls_ctx;
	};
	return listen(addr, port, listeners, config);
end
|
|
|
|
-- COMPAT
|
|
--- Legacy helper: wrap an existing outgoing socket and register it.
-- `addr`/`port` are used as the peer name only when the socket cannot
-- report one itself. TLS is started when `tls_ctx` is given.
-- @return the client object, or a falsy value plus an error from init()
local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra)
	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra);
	client:updatenames();
	if not client.peername then
		client.peername, client.peerport = addr, port;
	end

	local ok, err = client:init();
	if not ok then
		return ok, err;
	end

	if tls_ctx then client:starttls(tls_ctx); end
	return client;
end
|
|
|
|
-- New outgoing TCP connection
|
|
--- Create a new outgoing TCP connection to `addr`:`port`.
-- @param addr IP address literal; when `typ` is absent its family selects tcp4/tcp6
-- @param port remote port
-- @param listeners listener callbacks handed to wrapsocket
-- @param read_size optional read size handed to wrapsocket
-- @param tls_ctx optional TLS context; when set, starttls() is called
-- @param typ optional name of the LuaSocket constructor to use (e.g. "tcp4", "tcp6")
-- @param extra optional extra data handed to wrapsocket
-- @return the client object and the raw socket, or a falsy value plus an error
local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra)
	if not typ then
		local n = inet_pton(addr);
		if not n then return nil, "invalid-ip"; end
		-- A 16-byte binary address is IPv6; otherwise treat it as IPv4.
		if #n == 16 then
			typ = "tcp6";
		else
			typ = "tcp4";
		end
	end
	-- `typ` is always set here, so no further guard is needed.
	local create = socket[typ];
	if type(create) ~= "function" then
		return nil, "invalid socket type";
	end

	local conn, err = create();
	if not conn then return conn, err; end

	local ok;
	ok, err = conn:settimeout(0);
	if not ok then
		-- Release the descriptor now instead of leaving it to the GC.
		conn:close();
		return ok, err;
	end

	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra);
	if cfg.tcp_fastopen then
		client:setoption("tcp-fastopen-connect", 1);
	end

	-- A non-blocking connect reports "timeout" while still in progress;
	-- any other error is a real failure and the socket is discarded.
	ok, err = conn:setpeername(addr, port);
	if not ok and err ~= "timeout" then
		conn:close();
		return ok, err;
	end

	client:updatenames();
	ok, err = client:init();
	if not client.peername then
		-- otherwise not set until connected
		client.peername, client.peerport = addr, port;
	end
	if not ok then return ok, err; end

	client:debug("Client %s created", client);
	if tls_ctx then
		client:starttls(tls_ctx);
	end
	return client, conn;
end
|
|
|
|
--- Watch an arbitrary descriptor with the event loop.
-- `fd` is either a raw numeric descriptor, exposed through getfd(), or a
-- LuaSocket-compatible object. The returned watcher's close() only calls
-- del(); it does not close `fd` itself.
-- @param fd numeric descriptor or socket-like object
-- @param onreadable read callback (also passed to add() as the read flag)
-- @param onwritable write callback (also passed to add() as the write flag)
-- @return the watcher object
local function watchfd(fd, onreadable, onwritable)
	local watcher = setmetatable({
		conn = fd;
		onreadable = onreadable;
		onwritable = onwritable;
		close = function (self)
			self:del();
		end;
	}, interface_mt);

	if type(fd) == "number" then
		watcher.getfd = function ()
			return fd;
		end;
	end

	local id = new_id();
	watcher.id = id;
	watcher.log = logger.init(("fdwatch%s"):format(id));
	watcher:add(onreadable, onwritable);
	return watcher;
end;
|
|
|
|
-- Dump all data from one connection into another
|
|
--- Pipe every chunk read from `from` into `to`.
-- Reading from `from` is paused after each chunk and resumed from `to`'s
-- ondrain callback, giving simple back-pressure.
-- @param from source connection
-- @param to destination connection
-- @param read_size passed to `from:set_mode`
local function link(from, to, read_size)
	from:debug("Linking to %s", to.id);

	from.onincoming = function (self, data)
		self:pause();
		to:write(data);
	end;

	to.ondrain = function ()
		from:resume();
	end;

	from:set_mode(read_size);
	from:set(true, nil);
	to:set(nil, true);
end
|
|
|
|
-- COMPAT
|
|
-- net.adns calls this but then replaces :send so this can be a noop
|
|
--- Intentionally a no-op; kept for API compatibility.
-- net.adns calls this and then replaces :send on its own.
interface.set_send = function (self, new_send) -- luacheck: ignore 212
end
|
|
|
|
-- Close all connections and servers
|
|
--- Call close() on every registered connection and server.
-- A close() may clear its own entry in `fds`; clearing existing fields
-- during a pairs() traversal is permitted by Lua.
local function closeall()
	for _, conn in pairs(fds) do
		conn:close();
	end
end
|
|
|
|
-- Shutdown flag: nil while running, "quitting" once setquitting(true) ran.
local quitting;
|
|
|
|
-- Signal the main loop to shut down via the `quitting` upvalue declared above
|
|
--- Enter or leave the shutdown state.
-- When entering, all connections are closed immediately. A 1-second timer
-- then repeats closeall() for as long as the state persists; returning a
-- number from the callback is presumably the util.timer convention for
-- rescheduling. If `cfg.shutdown_deadline` is set, every remaining
-- connection is destroyed once that deadline passes.
-- @param quit truthy to start shutting down, falsy to cancel
local function setquitting(quit)
	if not quit then
		quitting = nil;
		return;
	end

	quitting = "quitting";
	closeall();

	addtimer(1, function ()
		if not quitting then return; end
		closeall();
		return 1;
	end);

	local deadline = cfg.shutdown_deadline;
	if deadline then
		addtimer(deadline, function ()
			if not quitting then return; end
			for _, conn in pairs(fds) do
				conn:destroy();
			end
		end);
	end
end
|
|
|
|
--- Run due timers once, then handle at most one ready descriptor without blocking.
-- @return nothing when an event was handled; otherwise the (falsy)
--         descriptor and reason returned by poll:wait(0)
local function loop_once()
	runtimers(); -- single pass, so the next-timer delay it returns is unused
	local fd, r, w = poll:wait(0);
	if not fd then
		return fd, r;
	end

	local conn = fds[fd];
	if not conn then
		log("debug", "Removing unknown fd %d", fd);
		poll:del(fd);
		return;
	end

	if r then conn:onreadable(); end
	if w then conn:onwritable(); end
end
|
|
|
|
-- Main loop
|
|
--- Run the event loop.
-- With `once`, delegates to loop_once(). Otherwise it runs until
-- `quitting` is set. After handling an event it polls again without
-- waiting; on a poll timeout it runs timers and waits for the delay they
-- return.
-- NOTE(review): the loop exits as soon as `quitting` is set, so the
-- retry/deadline timers that setquitting() schedules only fire if loop() is
-- entered again. Confirm this is intended.
-- @return the value of `quitting` (or loop_once()'s results when `once`)
local function loop(once)
	if once then
		return loop_once();
	end

	local timeout = 0;
	while not quitting do
		local fd, r, w = poll:wait(timeout);
		if not fd then
			if r == "timeout" then
				timeout = runtimers(cfg.max_wait, cfg.min_wait);
			elseif r ~= "signal" then
				log("debug", "epoll_wait error: %s[%d]", r, w);
			end
		else
			timeout = 0;
			local conn = fds[fd];
			if not conn then
				log("debug", "Removing unknown fd %d", fd);
				poll:del(fd);
			else
				if r then conn:onreadable(); end
				if w then conn:onwritable(); end
			end
		end
	end

	return quitting;
end
|
|
|
|
-- Signal hooking via signalfd; stays nil when that facility is unavailable.
local hook_signal;

if have_signal and signal.signalfd then
	-- Forward everything read from the signalfd to the "signal" listener.
	local function dispatch(self)
		return self:on("signal", self.conn:read());
	end

	--- Watch `signum` through a signalfd and invoke `cb` when it arrives.
	-- @return the watcher, or nil and "failed" if the signalfd could not be created
	hook_signal = function (signum, cb)
		local sigfd = signal.signalfd(signum);
		if sigfd then
			local watch = watchfd(sigfd, dispatch);
			watch.listeners = { onsignal = cb };
			-- Drop watchfd's close so lookup falls back to the metatable's close.
			watch.close = nil;
			watch:noise("Signal handler %d ready", signum);
			return watch;
		end
		log("error", "Could not hook signal %d", signum);
		return nil, "failed";
	end;
end
|
|
|
|
-- Convert a monotonic-clock time into wall-clock time (util.timer compatible).
local function to_absolute_time(t)
	return t - monotonic() + realtime();
end

-- Replace the active configuration; unset keys fall back to default_config.
local function set_config(newconfig)
	cfg = setmetatable(newconfig, default_config);
end

-- Build a TLS configuration object backed by this backend's TLS implementation.
local function tls_builder(basedir)
	return sslconfig._new(tls_impl.new_context, basedir)
end

--- Deprecated libevent emulation: watch `fd` and call `callback` on events.
-- A callback result of -1 (EV_LEAVE) stops watching; any other truthy
-- result re-arms the watcher with the interest given by `mode`.
-- @param fd numeric descriptor
-- @param mode "r", "w" or "rw"
-- @param callback invoked as a method on the watcher
-- @return the watcher, or a falsy value plus an error from add()
local function addevent(fd, mode, callback)
	log("warn", "Using deprecated libevent emulation, please update code to use watchfd API instead");
	local want_read = mode == "r" or mode == "rw";
	local want_write = mode == "w" or mode == "rw";

	local function onevent(self)
		local ret = self:callback();
		if ret == -1 then
			self:set(false, false);
		elseif ret then
			self:set(want_read, want_write);
		end
	end

	local watcher = setmetatable({
		getfd = function () return fd; end;
		callback = callback;
		onreadable = onevent;
		onwritable = onevent;
		close = function (self)
			self:del();
			fds[fd] = nil;
		end;
	}, interface_mt);
	watcher.id = watcher:getfd();
	watcher.log = logger.init(("fdwatch%d"):format(watcher.id));

	local ok, err = watcher:add(want_read, want_write);
	if not ok then return ok, err; end
	return watcher;
end

return {
	get_backend = function () return "epoll"; end;
	addserver = addserver;
	addclient = addclient;
	add_task = addtimer;
	timer = {
		-- API-compatible with util.timer
		add_task = addtimer;
		stop = closetimer;
		reschedule = reschedule;
		to_absolute_time = to_absolute_time;
	};
	listen = listen;
	loop = loop;
	closeall = closeall;
	setquitting = setquitting;
	wrapclient = wrapclient;
	wrapserver = wrapserver;
	watchfd = watchfd;
	link = link;
	set_config = set_config;
	hook_signal = hook_signal;
	tls_builder = tls_builder;

	-- libevent emulation
	event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
	addevent = addevent;
};
|