net.websocket.frames: Allow all methods to work on non-string objects

Instead of using the string library, use methods from the passed object,
which are assumed to be equivalent.

This provides compatibility with objects from util.ringbuffer and
util.dbuffer, for example.
This commit is contained in:
Matthew Wild 2020-09-17 13:00:19 +01:00
parent 095c4f8344
commit 3989ff2ddc

View file

@ -16,13 +16,12 @@ local bor = bit.bor;
local bxor = bit.bxor; local bxor = bit.bxor;
local lshift = bit.lshift; local lshift = bit.lshift;
local rshift = bit.rshift; local rshift = bit.rshift;
local unpack = table.unpack or unpack; -- luacheck: ignore 113
local t_concat = table.concat; local t_concat = table.concat;
local s_byte = string.byte;
local s_char= string.char; local s_char= string.char;
local s_sub = string.sub; local s_pack = string.pack;
local s_pack = string.pack; -- luacheck: ignore 143 local s_unpack = string.unpack;
local s_unpack = string.unpack; -- luacheck: ignore 143
if not s_pack and softreq"struct" then if not s_pack and softreq"struct" then
s_pack = softreq"struct".pack; s_pack = softreq"struct".pack;
@ -30,12 +29,12 @@ if not s_pack and softreq"struct" then
end end
local function read_uint16be(str, pos) local function read_uint16be(str, pos)
local l1, l2 = s_byte(str, pos, pos+1); local l1, l2 = str:byte(pos, pos+1);
return l1*256 + l2; return l1*256 + l2;
end end
-- FIXME: this may lose precision -- FIXME: this may lose precision
local function read_uint64be(str, pos) local function read_uint64be(str, pos)
local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7); local l1, l2, l3, l4, l5, l6, l7, l8 = str:byte(pos, pos+7);
local h = lshift(l1, 24) + lshift(l2, 16) + lshift(l3, 8) + l4; local h = lshift(l1, 24) + lshift(l2, 16) + lshift(l3, 8) + l4;
local l = lshift(l5, 24) + lshift(l6, 16) + lshift(l7, 8) + l8; local l = lshift(l5, 24) + lshift(l6, 16) + lshift(l7, 8) + l8;
return h * 2^32 + l; return h * 2^32 + l;
@ -63,9 +62,15 @@ end
if s_unpack then if s_unpack then
function read_uint16be(str, pos) function read_uint16be(str, pos)
if type(str) ~= "string" then
str, pos = str:sub(pos, pos+1), 1;
end
return s_unpack(">I2", str, pos); return s_unpack(">I2", str, pos);
end end
function read_uint64be(str, pos) function read_uint64be(str, pos)
if type(str) ~= "string" then
str, pos = str:sub(pos, pos+7), 1;
end
return s_unpack(">I8", str, pos); return s_unpack(">I8", str, pos);
end end
end end
@ -73,7 +78,7 @@ end
local function parse_frame_header(frame) local function parse_frame_header(frame)
if #frame < 2 then return; end if #frame < 2 then return; end
local byte1, byte2 = s_byte(frame, 1, 2); local byte1, byte2 = frame:byte(1, 2);
local result = { local result = {
FIN = band(byte1, 0x80) > 0; FIN = band(byte1, 0x80) > 0;
RSV1 = band(byte1, 0x40) > 0; RSV1 = band(byte1, 0x40) > 0;
@ -102,7 +107,7 @@ local function parse_frame_header(frame)
end end
if result.MASK then if result.MASK then
result.key = { s_byte(frame, length_bytes+3, length_bytes+6) }; result.key = { frame:byte(length_bytes+3, length_bytes+6) };
end end
return result, header_length; return result, header_length;
@ -121,7 +126,7 @@ local function apply_mask(str, key, from, to)
for i = from, to do for i = from, to do
local key_index = counter%key_len + 1; local key_index = counter%key_len + 1;
counter = counter + 1; counter = counter + 1;
data[counter] = s_char(bxor(key[key_index], s_byte(str, i))); data[counter] = s_char(bxor(key[key_index], str:byte(i)));
end end
return t_concat(data); return t_concat(data);
end end
@ -189,7 +194,7 @@ local function parse_close(data)
if #data >= 2 then if #data >= 2 then
code = read_uint16be(data, 1); code = read_uint16be(data, 1);
if #data > 2 then if #data > 2 then
message = s_sub(data, 3); message = data:sub(3);
end end
end end
return code, message return code, message