mirror of
https://github.com/bjc/prosody.git
synced 2025-04-04 21:57:45 +03:00
Instead of using the string library, use methods from the passed object, which are assumed to be equivalent. This provides compatibility with objects from util.ringbuffer and util.dbuffer, for example.
224 lines
5.5 KiB
Lua
224 lines
5.5 KiB
Lua
-- Prosody IM
|
|
-- Copyright (C) 2012 Florian Zeitz
|
|
-- Copyright (C) 2014 Daurnimator
|
|
--
|
|
-- This project is MIT/X11 licensed. Please see the
|
|
-- COPYING file in the source package for more information.
|
|
--
|
|
|
|
local softreq = require "util.dependencies".softreq;
|
|
local random_bytes = require "util.random".bytes;
|
|
|
|
local bit = assert(softreq"bit" or softreq"bit32",
|
|
"No bit module found. See https://prosody.im/doc/depends#bitop");
|
|
local band = bit.band;
|
|
local bor = bit.bor;
|
|
local bxor = bit.bxor;
|
|
local lshift = bit.lshift;
|
|
local rshift = bit.rshift;
|
|
local unpack = table.unpack or unpack; -- luacheck: ignore 113
|
|
|
|
local t_concat = table.concat;
|
|
local s_char= string.char;
|
|
local s_pack = string.pack;
|
|
local s_unpack = string.unpack;
|
|
|
|
-- Lua < 5.3 has no string.pack/string.unpack; fall back to the optional
-- `struct` library when it can be loaded. If neither is available the
-- bit-operation based helpers defined below are used instead.
if not s_pack then
	local struct_lib = softreq"struct";
	if struct_lib then
		s_pack, s_unpack = struct_lib.pack, struct_lib.unpack;
	end
end
|
|
|
|
-- Decode a big-endian unsigned 16-bit integer from the two bytes starting at
-- `pos`. `str` may be a string or any object with a compatible :byte() method.
local function read_uint16be(str, pos)
	local hi, lo = str:byte(pos, pos + 1);
	return hi * 0x100 + lo;
end
|
|
-- Decode a big-endian unsigned 64-bit integer from the eight bytes starting at
-- `pos`. `str` may be a string or any object with a compatible :byte() method.
--
-- Each 32-bit half is assembled with plain arithmetic rather than bit shifts:
-- LuaJIT's `bit.lshift` yields *signed* 32-bit results, so a byte >= 0x80 in
-- the top position of either half would make that half negative and corrupt
-- the decoded length. Arithmetic on doubles is exact for every 32-bit half.
-- FIXME: results above 2^53 still lose precision (they are floats).
local function read_uint64be(str, pos)
	local b1, b2, b3, b4, b5, b6, b7, b8 = str:byte(pos, pos + 7);
	local high = ((b1 * 256 + b2) * 256 + b3) * 256 + b4;
	local low = ((b5 * 256 + b6) * 256 + b7) * 256 + b8;
	return high * 2^32 + low;
end
|
|
-- Encode `x` as a 2-byte big-endian string (bit-operation fallback; replaced
-- below when string.pack or struct.pack is available).
local function pack_uint16be(x)
	local high_byte = rshift(x, 8);
	local low_byte = band(x, 0xFF);
	return s_char(high_byte, low_byte);
end
|
|
-- Extract the 8-bit value of `x` located `n` bits above the least
-- significant bit (e.g. n = 24 gives the top byte of a 32-bit word).
local function get_byte(x, n)
	local shifted = rshift(x, n);
	return band(shifted, 0xFF);
end
|
|
-- Encode `x` as an 8-byte big-endian string (bit-operation fallback; replaced
-- below when string.pack or struct.pack is available). Exact for x < 2^53.
--
-- The two 32-bit halves are split with floor division and modulo rather than
-- by handing `x / 2^32` to a bit operation: bit.band/bit32.band convert
-- non-integral arguments in an unspecified way (rounding may be used), which
-- would make the high word 1 for any x in [2^31, 2^32).
local function pack_uint64be(x)
	local h = math.floor(x / 2^32);
	local l = x % 2^32;
	return s_char(get_byte(h, 24), get_byte(h, 16), get_byte(h, 8), band(h, 0xFF),
		get_byte(l, 24), get_byte(l, 16), get_byte(l, 8), band(l, 0xFF));
end
|
|
|
|
-- When a real packer exists (Lua 5.3+ string.pack or the struct library),
-- replace the bit-operation encoders with it.
if s_pack then
	pack_uint16be = function(x)
		return s_pack(">I2", x);
	end

	pack_uint64be = function(x)
		return s_pack(">I8", x);
	end
end
|
|
|
|
-- When a real unpacker exists (Lua 5.3+ string.unpack or the struct library),
-- replace the arithmetic decoders with it.
--
-- Non-string inputs (e.g. util.ringbuffer / util.dbuffer objects) are first
-- copied into a string via their :sub() method, since unpack only accepts
-- strings. Only the decoded value is returned, so both the fast path and the
-- fallback have the same single-value contract (the position string.unpack
-- would return is meaningless after the :sub() copy anyway).
if s_unpack then
	function read_uint16be(str, pos)
		if type(str) ~= "string" then
			str, pos = str:sub(pos, pos+1), 1;
		end
		return (s_unpack(">I2", str, pos));
	end

	function read_uint64be(str, pos)
		if type(str) ~= "string" then
			str, pos = str:sub(pos, pos+7), 1;
		end
		local value = s_unpack(">I8", str, pos);
		-- Lua 5.3+ stores the result in a signed 64-bit integer, so values with
		-- the top bit set come back negative. RFC 6455 forbids that bit, but a
		-- peer can still send it; map such values back into the unsigned range
		-- (as a float) so the length is never negative and size limits apply.
		if value < 0 then
			value = value + 2^64;
		end
		return value;
	end
end
|
|
|
|
--- Decode a WebSocket frame header (RFC 6455 section 5.2).
-- `frame` may be a string or any object with compatible #, :byte() and :sub().
-- Returns nothing if fewer bytes than the complete header are available.
-- Otherwise returns a table with the flags FIN, RSV1, RSV2, RSV3 and MASK,
-- the 4-bit opcode, the payload length and, for masked frames, the masking
-- key as an array of 4 bytes; followed by the header length in bytes.
local function parse_frame_header(frame)
	-- The two fixed header bytes must be present before anything is decoded.
	if #frame < 2 then return; end

	local b1, b2 = frame:byte(1, 2);
	local header = {
		FIN = band(b1, 0x80) > 0;
		RSV1 = band(b1, 0x40) > 0;
		RSV2 = band(b1, 0x20) > 0;
		RSV3 = band(b1, 0x10) > 0;
		opcode = band(b1, 0x0F);

		MASK = band(b2, 0x80) > 0;
		length = band(b2, 0x7F);
	};

	-- A 7-bit length of 126 or 127 announces a 16- or 64-bit extended length.
	local ext_len;
	if header.length == 127 then
		ext_len = 8;
	elseif header.length == 126 then
		ext_len = 2;
	else
		ext_len = 0;
	end

	local mask_len = header.MASK and 4 or 0;
	local total = 2 + ext_len + mask_len;
	if #frame < total then return; end

	if ext_len == 8 then
		header.length = read_uint64be(frame, 3);
	elseif ext_len == 2 then
		header.length = read_uint16be(frame, 3);
	end

	if mask_len ~= 0 then
		-- The masking key directly follows the (extended) length field.
		local key_start = 3 + ext_len;
		header.key = { frame:byte(key_start, key_start + 3) };
	end

	return header, total;
end
|
|
|
|
-- XOR bytes `from`..`to` (inclusive; defaults cover the whole input) of `str`
-- with the repeating byte sequence in the array `key`, and return the result
-- as a new string. Negative indices count back from the end, as with
-- string.sub. `str` may be a string or any object with compatible # and
-- :byte().
-- TODO: optimize
local function apply_mask(str, key, from, to)
	from = from or 1;
	if from < 0 then
		from = #str + from + 1; -- negative indices
	end
	to = to or #str;
	if to < 0 then
		to = #str + to + 1; -- negative indices
	end

	local key_len = #key;
	local out = {};
	for i = from, to do
		local offset = i - from; -- 0-based position within the masked span
		out[offset + 1] = s_char(bxor(key[offset % key_len + 1], str:byte(i)));
	end
	return t_concat(out);
end
|
|
|
|
--- Extract the payload described by `header` from `frame`, starting at byte
-- `pos`, unmasking it with `header.key` when the MASK flag is set.
local function parse_frame_body(frame, header, pos)
	local last = pos + header.length - 1;
	if not header.MASK then
		return frame:sub(pos, last);
	end
	return apply_mask(frame, header.key, pos, last);
end
|
|
|
|
--- Parse one complete frame from the start of `frame`.
-- Returns nothing if the header or payload is not yet fully available.
-- Otherwise returns the header table (with the unmasked payload stored in
-- `.data`) and the total number of bytes the frame occupies.
local function parse_frame(frame)
	local header, header_len = parse_frame_header(frame);
	if header == nil then return; end

	local frame_end = header_len + header.length;
	if #frame < frame_end then return; end

	header.data = parse_frame_body(frame, header, header_len + 1);
	return header, frame_end;
end
|
|
|
|
--- Serialize a WebSocket frame (RFC 6455 section 5.2).
-- `desc` fields:
--   opcode (required, 0..15), data (payload string, default ""),
--   FIN, RSV1, RSV2, RSV3 (flags), MASK (flag),
--   key (optional array of 4 masking bytes; random bytes are used if absent).
-- Returns the encoded frame as a string. Raises on an invalid opcode or an
-- oversized control-frame payload.
local function build_frame(desc)
	local data = desc.data or "";

	assert(desc.opcode and desc.opcode >= 0 and desc.opcode <= 0xF, "Invalid WebSocket opcode");
	if desc.opcode >= 0x8 then
		-- RFC 6455 5.5
		assert(#data <= 125, "WebSocket control frames MUST have a payload length of 125 bytes or less.");
	end

	local b1 = bor(desc.opcode,
		desc.FIN and 0x80 or 0,
		desc.RSV1 and 0x40 or 0,
		desc.RSV2 and 0x20 or 0,
		desc.RSV3 and 0x10 or 0);

	local b2 = #data;
	local length_extra;
	if b2 <= 125 then -- 7-bit length
		length_extra = "";
	elseif b2 <= 0xFFFF then -- 2-byte length
		b2 = 126;
		length_extra = pack_uint16be(#data);
	else -- 8-byte length
		b2 = 127;
		length_extra = pack_uint64be(#data);
	end

	local key = "";
	if desc.MASK then
		if desc.key then
			key = s_char(unpack(desc.key, 1, 4));
		else
			key = random_bytes(4);
		end
		b2 = bor(b2, 0x80);
		-- Mask with exactly the four bytes written to the wire; a caller-supplied
		-- key table longer than four entries must not change the mask applied.
		data = apply_mask(data, { key:byte(1, 4) });
	end

	return s_char(b1, b2) .. length_extra .. key .. data;
end
|
|
|
|
--- Decode the payload of a Close frame (RFC 6455 section 5.5.1).
-- Returns the 16-bit status code (nil when the payload is shorter than two
-- bytes) and the reason text (nil unless bytes follow the code).
local function parse_close(data)
	if #data < 2 then
		return nil, nil;
	end

	local code = read_uint16be(data, 1);
	local message;
	if #data > 2 then
		message = data:sub(3);
	end
	return code, message;
end
|
|
|
|
--- Build a Close control frame (RFC 6455 section 5.5.1).
-- `code` is the 16-bit status code. It may be nil to produce a Close frame
-- with an empty body, the case parse_close reports as a nil code.
-- `message` is an optional reason of at most 123 bytes, so the payload fits
-- the 125-byte control-frame limit, and it requires a code.
-- `mask` sets the MASK flag.
local function build_close(code, message, mask)
	local data = "";
	if code ~= nil then
		data = pack_uint16be(code);
		if message then
			assert(#message<=123, "Close reason must be <=123 bytes");
			data = data .. message;
		end
	else
		-- RFC 6455 5.5.1: a reason may only follow a status code.
		assert(not message, "Close reason requires a status code");
	end
	return build_frame({
		opcode = 0x8;
		FIN = true;
		MASK = mask;
		data = data;
	});
end
|
|
|
|
-- Public API of this module: frame parsing and building helpers.
return {
	parse = parse_frame;
	parse_header = parse_frame_header;
	parse_body = parse_frame_body;
	parse_close = parse_close;
	build = build_frame;
	build_close = build_close;
};
|