This commit is contained in:
Matthew Wild 2020-09-30 09:46:30 +01:00
commit 4051f5e653
5 changed files with 434 additions and 85 deletions

View file

@ -16,11 +16,10 @@ local bor = bit.bor;
local bxor = bit.bxor;
local lshift = bit.lshift;
local rshift = bit.rshift;
local unpack = table.unpack or unpack; -- luacheck: ignore 113
local t_concat = table.concat;
local s_byte = string.byte;
local s_char= string.char;
local s_sub = string.sub;
local s_pack = string.pack; -- luacheck: ignore 143
local s_unpack = string.unpack; -- luacheck: ignore 143
@ -30,12 +29,12 @@ if not s_pack and softreq"struct" then
end
local function read_uint16be(str, pos)
local l1, l2 = s_byte(str, pos, pos+1);
local l1, l2 = str:byte(pos, pos+1);
return l1*256 + l2;
end
-- FIXME: this may lose precision
local function read_uint64be(str, pos)
local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7);
local l1, l2, l3, l4, l5, l6, l7, l8 = str:byte(pos, pos+7);
local h = lshift(l1, 24) + lshift(l2, 16) + lshift(l3, 8) + l4;
local l = lshift(l5, 24) + lshift(l6, 16) + lshift(l7, 8) + l8;
return h * 2^32 + l;
@ -63,9 +62,15 @@ end
if s_unpack then
function read_uint16be(str, pos)
if type(str) ~= "string" then
str, pos = str:sub(pos, pos+1), 1;
end
return s_unpack(">I2", str, pos);
end
function read_uint64be(str, pos)
if type(str) ~= "string" then
str, pos = str:sub(pos, pos+7), 1;
end
return s_unpack(">I8", str, pos);
end
end
@ -73,7 +78,7 @@ end
local function parse_frame_header(frame)
if #frame < 2 then return; end
local byte1, byte2 = s_byte(frame, 1, 2);
local byte1, byte2 = frame:byte(1, 2);
local result = {
FIN = band(byte1, 0x80) > 0;
RSV1 = band(byte1, 0x40) > 0;
@ -102,7 +107,7 @@ local function parse_frame_header(frame)
end
if result.MASK then
result.key = { s_byte(frame, length_bytes+3, length_bytes+6) };
result.key = { frame:byte(length_bytes+3, length_bytes+6) };
end
return result, header_length;
@ -121,7 +126,7 @@ local function apply_mask(str, key, from, to)
for i = from, to do
local key_index = counter%key_len + 1;
counter = counter + 1;
data[counter] = s_char(bxor(key[key_index], s_byte(str, i)));
data[counter] = s_char(bxor(key[key_index], str:byte(i)));
end
return t_concat(data);
end
@ -136,7 +141,7 @@ end
local function parse_frame(frame)
local result, pos = parse_frame_header(frame);
if result == nil or #frame < (pos + result.length) then return; end
if result == nil or #frame < (pos + result.length) then return nil, nil, result; end
result.data = parse_frame_body(frame, result, pos+1);
return result, pos + result.length;
end
@ -189,7 +194,7 @@ local function parse_close(data)
if #data >= 2 then
code = read_uint16be(data, 1);
if #data > 2 then
message = s_sub(data, 3);
message = data:sub(3);
end
end
return code, message

View file

@ -18,6 +18,7 @@ local contains_token = require "util.http".contains_token;
local portmanager = require "core.portmanager";
local sm_destroy_session = require"core.sessionmanager".destroy_session;
local log = module._log;
local dbuffer = require "util.dbuffer";
local websocket_frames = require"net.websocket.frames";
local parse_frame = websocket_frames.parse;
@ -27,6 +28,9 @@ local parse_close = websocket_frames.parse_close;
local t_concat = table.concat;
local stanza_size_limit = module:get_option_number("c2s_stanza_size_limit", 10 * 1024 * 1024);
local frame_buffer_limit = module:get_option_number("websocket_frame_buffer_limit", 2 * stanza_size_limit);
local frame_fragment_limit = module:get_option_number("websocket_frame_fragment_limit", 8);
local stream_close_timeout = module:get_option_number("c2s_close_timeout", 5);
local consider_websocket_secure = module:get_option_boolean("consider_websocket_secure");
local cross_domain = module:get_option_set("cross_domain_websocket", {});
@ -138,6 +142,65 @@ local function filter_open_close(data)
return data;
end
local function validate_frame(frame, max_length)
local opcode, length = frame.opcode, frame.length;
if max_length and length > max_length then
return false, 1009, "Payload too large";
end
-- Error cases
if frame.RSV1 or frame.RSV2 or frame.RSV3 then -- Reserved bits non zero
return false, 1002, "Reserved bits not zero";
end
if opcode == 0x8 and frame.data then -- close frame
if length == 1 then
return false, 1002, "Close frame with payload, but too short for status code";
elseif length >= 2 then
local status_code = parse_close(frame.data)
if status_code < 1000 then
return false, 1002, "Closed with invalid status code";
elseif ((status_code > 1003 and status_code < 1007) or status_code > 1011) and status_code < 3000 then
return false, 1002, "Closed with reserved status code";
end
end
end
if opcode >= 0x8 then
if length > 125 then -- Control frame with too much payload
return false, 1002, "Payload too large";
end
if not frame.FIN then -- Fragmented control frame
return false, 1002, "Fragmented control frame";
end
end
if (opcode > 0x2 and opcode < 0x8) or (opcode > 0xA) then
return false, 1002, "Reserved opcode";
end
-- Check opcode
if opcode == 0x2 then -- Binary frame
return false, 1003, "Only text frames are supported, RFC 7395 3.2";
elseif opcode == 0x8 then -- Close request
return false, 1000, "Goodbye";
end
-- Other (XMPP-specific) validity checks
if not frame.FIN then
return false, 1003, "Continuation frames are not supported, RFC 7395 3.3.3";
end
if opcode == 0x01 and frame.data and frame.data:byte(1, 1) ~= 60 then
return false, 1007, "Invalid payload start character, RFC 7395 3.3.3";
end
return true;
end
function handle_request(event)
local request, response = event.request, event.response;
local conn = response.conn;
@ -168,90 +231,40 @@ function handle_request(event)
conn:close();
end
local dataBuffer;
local function websocket_handle_error(session, code, message)
if code == 1009 then -- stanza size limit exceeded
-- we close the session, rather than the connection,
-- otherwise a resuming client will simply resend the
-- offending stanza
session:close({ condition = "policy-violation", text = "stanza too large" });
else
websocket_close(code, message);
end
end
local function handle_frame(frame)
local opcode = frame.opcode;
local length = frame.length;
module:log("debug", "Websocket received frame: opcode=%0x, %i bytes", frame.opcode, #frame.data);
-- Error cases
if frame.RSV1 or frame.RSV2 or frame.RSV3 then -- Reserved bits non zero
websocket_close(1002, "Reserved bits not zero");
return false;
-- Check frame makes sense
local frame_ok, err_status, err_text = validate_frame(frame, stanza_size_limit);
if not frame_ok then
return frame_ok, err_status, err_text;
end
if opcode == 0x8 then -- close frame
if length == 1 then
websocket_close(1002, "Close frame with payload, but too short for status code");
return false;
elseif length >= 2 then
local status_code = parse_close(frame.data)
if status_code < 1000 then
websocket_close(1002, "Closed with invalid status code");
return false;
elseif ((status_code > 1003 and status_code < 1007) or status_code > 1011) and status_code < 3000 then
websocket_close(1002, "Closed with reserved status code");
return false;
end
end
end
if opcode >= 0x8 then
if length > 125 then -- Control frame with too much payload
websocket_close(1002, "Payload too large");
return false;
end
if not frame.FIN then -- Fragmented control frame
websocket_close(1002, "Fragmented control frame");
return false;
end
end
if (opcode > 0x2 and opcode < 0x8) or (opcode > 0xA) then
websocket_close(1002, "Reserved opcode");
return false;
end
if opcode == 0x0 and not dataBuffer then
websocket_close(1002, "Unexpected continuation frame");
return false;
end
if (opcode == 0x1 or opcode == 0x2) and dataBuffer then
websocket_close(1002, "Continuation frame expected");
return false;
end
-- Valid cases
if opcode == 0x0 then -- Continuation frame
dataBuffer[#dataBuffer+1] = frame.data;
elseif opcode == 0x1 then -- Text frame
dataBuffer = {frame.data};
elseif opcode == 0x2 then -- Binary frame
websocket_close(1003, "Only text frames are supported");
return;
elseif opcode == 0x8 then -- Close request
websocket_close(1000, "Goodbye");
return;
elseif opcode == 0x9 then -- Ping frame
local opcode = frame.opcode;
if opcode == 0x9 then -- Ping frame
frame.opcode = 0xA;
frame.MASK = false; -- Clients send masked frames, servers don't, see #1484
conn:write(build_frame(frame));
return "";
elseif opcode == 0xA then -- Pong frame, MAY be sent unsolicited, eg as keepalive
return "";
else
elseif opcode ~= 0x1 then -- Not text frame (which is all we support)
log("warn", "Received frame with unsupported opcode %i", opcode);
return "";
end
if frame.FIN then
local data = t_concat(dataBuffer, "");
dataBuffer = nil;
return data;
end
return "";
return frame.data;
end
conn:setlistener(c2s_listener);
@ -269,19 +282,37 @@ function handle_request(event)
session.open_stream = session_open_stream;
session.close = session_close;
local frameBuffer = "";
local frameBuffer = dbuffer.new(frame_buffer_limit, frame_fragment_limit);
add_filter(session, "bytes/in", function(data)
if not frameBuffer:write(data) then
session.log("warn", "websocket frame buffer full - terminating session");
session:close({ condition = "resource-constraint", text = "frame buffer exceeded" });
return;
end
local cache = {};
frameBuffer = frameBuffer .. data;
local frame, length = parse_frame(frameBuffer);
local frame, length, partial = parse_frame(frameBuffer);
while frame do
frameBuffer = frameBuffer:sub(length + 1);
local result = handle_frame(frame);
if not result then return; end
frameBuffer:discard(length);
local result, err_status, err_text = handle_frame(frame);
if not result then
websocket_handle_error(session, err_status, err_text);
break;
end
cache[#cache+1] = filter_open_close(result);
frame, length = parse_frame(frameBuffer);
frame, length, partial = parse_frame(frameBuffer);
end
if partial then
-- The header of the next frame is already in the buffer, run
-- some early validation here
local frame_ok, err_status, err_text = validate_frame(partial, stanza_size_limit);
if not frame_ok then
websocket_handle_error(session, err_status, err_text);
end
end
return t_concat(cache, "");
end);

130
spec/util_dbuffer_spec.lua Normal file
View file

@ -0,0 +1,130 @@
local dbuffer = require "util.dbuffer";
describe("util.dbuffer", function ()
describe("#new", function ()
it("has a constructor", function ()
assert.Function(dbuffer.new);
end);
it("can be created", function ()
assert.truthy(dbuffer.new());
end);
it("won't create an empty buffer", function ()
assert.falsy(dbuffer.new(0));
end);
it("won't create a negatively sized buffer", function ()
assert.falsy(dbuffer.new(-1));
end);
end);
describe(":write", function ()
local b = dbuffer.new();
it("works", function ()
assert.truthy(b:write("hi"));
end);
end);
describe(":read", function ()
it("supports optional bytes parameter", function ()
-- should return the frontmost chunk
local b = dbuffer.new();
assert.truthy(b:write("hello"));
assert.truthy(b:write(" "));
assert.truthy(b:write("world"));
assert.equal("h", b:read(1));
assert.equal("ello", b:read());
assert.equal(" ", b:read());
assert.equal("world", b:read());
end);
end);
describe(":discard", function ()
local b = dbuffer.new();
it("works", function ()
assert.truthy(b:write("hello world"));
assert.truthy(b:discard(6));
assert.equal(5, b:length());
assert.equal("world", b:read(5));
end);
end);
describe(":collapse()", function ()
it("works on an empty buffer", function ()
local b = dbuffer.new();
b:collapse();
end);
end);
describe(":sub", function ()
-- Helper function to compare buffer:sub() with string:sub()
local s = "hello world";
local function test_sub(b, x, y)
local string_result, buffer_result = s:sub(x, y), b:sub(x, y);
assert.equals(string_result, buffer_result, ("buffer:sub(%d, %s) does not match string:sub()"):format(x, y and ("%d"):format(y) or "nil"));
end
it("works", function ()
local b = dbuffer.new();
assert.truthy(b:write("hello world"));
assert.equals("hello", b:sub(1, 5));
end);
it("works after discard", function ()
local b = dbuffer.new(256);
assert.truthy(b:write("foobar"));
assert.equals("foobar", b:sub(1, 6));
assert.truthy(b:discard(3)); -- consume "foo"
assert.equals("bar", b:sub(1, 3));
end);
it("supports optional end parameter", function ()
local b = dbuffer.new();
assert.truthy(b:write("hello world"));
assert.equals("hello world", b:sub(1));
assert.equals("world", b:sub(-5));
end);
it("is equivalent to string:sub", function ()
local b = dbuffer.new(11);
assert.truthy(b:write(s));
for i = -13, 13 do
for j = -13, 13 do
test_sub(b, i, j);
end
end
end);
end);
describe(":byte", function ()
-- Helper function to compare buffer:byte() with string:byte()
local s = "hello world"
local function test_byte(b, x, y)
local string_result, buffer_result = {s:byte(x, y)}, {b:byte(x, y)};
assert.same(string_result, buffer_result, ("buffer:byte(%d, %s) does not match string:byte()"):format(x, y and ("%d"):format(y) or "nil"));
end
it("is equivalent to string:byte", function ()
local b = dbuffer.new(11);
assert.truthy(b:write(s));
test_byte(b, 1);
test_byte(b, 3);
test_byte(b, -1);
test_byte(b, -3);
for i = -13, 13 do
for j = -13, 13 do
test_byte(b, i, j);
end
end
end);
it("works with characters > 127", function ()
local b = dbuffer.new();
b:write(string.char(0, 140));
local r = { b:byte(1, 2) };
assert.same({ 0, 140 }, r);
end);
it("works on an empty buffer", function ()
local b = dbuffer.new();
assert.equal("", b:sub(1,1));
end);
end);
end);

176
util/dbuffer.lua Normal file
View file

@ -0,0 +1,176 @@
local queue = require "util.queue";
local dbuffer_methods = {};
local dynamic_buffer_mt = { __index = dbuffer_methods };
function dbuffer_methods:write(data)
if self.max_size and #data + self._length > self.max_size then
return nil;
end
local ok = self.items:push(data);
if not ok then
self:collapse();
ok = self.items:push(data);
end
if not ok then
return nil;
end
self._length = self._length + #data;
return true;
end
function dbuffer_methods:read_chunk(requested_bytes)
local chunk, consumed = self.items:peek(), self.front_consumed;
if not chunk then return; end
local chunk_length = #chunk;
local remaining_chunk_length = chunk_length - consumed;
if not requested_bytes then
requested_bytes = remaining_chunk_length;
end
if remaining_chunk_length <= requested_bytes then
self.front_consumed = 0;
self._length = self._length - remaining_chunk_length;
self.items:pop();
assert(#chunk:sub(consumed + 1, -1) == remaining_chunk_length);
return chunk:sub(consumed + 1, -1), remaining_chunk_length;
end
local end_pos = consumed + requested_bytes;
self.front_consumed = end_pos;
self._length = self._length - requested_bytes;
assert(#chunk:sub(consumed + 1, end_pos) == requested_bytes);
return chunk:sub(consumed + 1, end_pos), requested_bytes;
end
function dbuffer_methods:read(requested_bytes)
local chunks;
if requested_bytes and requested_bytes > self._length then
return nil;
end
local chunk, read_bytes = self:read_chunk(requested_bytes);
if not requested_bytes then
return chunk;
elseif chunk then
requested_bytes = requested_bytes - read_bytes;
if requested_bytes == 0 then -- Already read everything we need
return chunk;
end
chunks = {};
else
return nil;
end
-- Need to keep reading more chunks
while chunk do
table.insert(chunks, chunk);
if requested_bytes > 0 then
chunk, read_bytes = self:read_chunk(requested_bytes);
requested_bytes = requested_bytes - read_bytes;
else
break;
end
end
return table.concat(chunks);
end
function dbuffer_methods:discard(requested_bytes)
if requested_bytes > self._length then
return nil;
end
local chunk, read_bytes = self:read_chunk(requested_bytes);
if chunk then
requested_bytes = requested_bytes - read_bytes;
if requested_bytes == 0 then -- Already read everything we need
return true;
end
else
return nil;
end
while chunk do
if requested_bytes > 0 then
chunk, read_bytes = self:read_chunk(requested_bytes);
requested_bytes = requested_bytes - read_bytes;
else
break;
end
end
return true;
end
function dbuffer_methods:sub(i, j)
if j == nil then
j = -1;
end
if j < 0 then
j = self._length + (j+1);
end
if i < 0 then
i = self._length + (i+1);
end
if i < 1 then
i = 1;
end
if j > self._length then
j = self._length;
end
if i > j then
return "";
end
self:collapse(j);
return self.items:peek():sub(self.front_consumed+1):sub(i, j);
end
function dbuffer_methods:byte(i, j)
i = i or 1;
j = j or i;
return string.byte(self:sub(i, j), 1, -1);
end
function dbuffer_methods:length()
return self._length;
end
dynamic_buffer_mt.__len = dbuffer_methods.length; -- support # operator
function dbuffer_methods:collapse(bytes)
bytes = bytes or self._length;
local front_chunk = self.items:peek();
if not front_chunk or #front_chunk - self.front_consumed >= bytes then
return;
end
local front_chunks = { front_chunk:sub(self.front_consumed+1) };
local front_bytes = #front_chunks[1];
while front_bytes < bytes do
self.items:pop();
local chunk = self.items:peek();
front_bytes = front_bytes + #chunk;
table.insert(front_chunks, chunk);
end
self.items:replace(table.concat(front_chunks));
self.front_consumed = 0;
end
local function new(max_size, max_chunks)
if max_size and max_size <= 0 then
return nil;
end
return setmetatable({
front_consumed = 0;
_length = 0;
max_size = max_size;
items = queue.new(max_chunks or 32);
}, dynamic_buffer_mt);
end
return {
new = new;
};

View file

@ -51,6 +51,13 @@ local function new(size, allow_wrapping)
end
return t[tail];
end;
replace = function (self, data)
if items == 0 then
return self:push(data);
end
t[tail] = data;
return true;
end;
items = function (self)
--luacheck: ignore 431/t
return function (t, pos)