mirror of
https://github.com/bjc/prosody.git
synced 2025-04-01 20:27:39 +03:00
net.resolvers.service: Honour record 'weight' when picking SRV targets
#NotHappyEyeballs
This commit is contained in:
parent
d1c2e34e61
commit
d26811f5e5
2 changed files with 309 additions and 13 deletions
|
@ -2,23 +2,78 @@ local adns = require "net.adns";
|
|||
local basic = require "net.resolvers.basic";
|
||||
local inet_pton = require "util.net".pton;
|
||||
local idna_to_ascii = require "util.encodings".idna.to_ascii;
|
||||
local unpack = table.unpack or unpack; -- luacheck: ignore 113
|
||||
|
||||
local methods = {};
|
||||
local resolver_mt = { __index = methods };
|
||||
|
||||
local function new_target_selector(rrset)
|
||||
local rr_count = rrset and #rrset;
|
||||
if not rr_count or rr_count == 0 then
|
||||
rrset = nil;
|
||||
else
|
||||
table.sort(rrset, function (a, b) return a.srv.priority < b.srv.priority end);
|
||||
end
|
||||
local rrset_pos = 1;
|
||||
local priority_bucket, bucket_total_weight, bucket_len, bucket_used;
|
||||
return function ()
|
||||
if not rrset then return; end
|
||||
|
||||
if not priority_bucket or bucket_used >= bucket_len then
|
||||
if rrset_pos > rr_count then return; end -- Used up all records
|
||||
|
||||
-- Going to start on a new priority now. Gather up all the next
|
||||
-- records with the same priority and add them to priority_bucket
|
||||
priority_bucket, bucket_total_weight, bucket_len, bucket_used = {}, 0, 0, 0;
|
||||
local current_priority;
|
||||
repeat
|
||||
local curr_record = rrset[rrset_pos].srv;
|
||||
if not current_priority then
|
||||
current_priority = curr_record.priority;
|
||||
elseif current_priority ~= curr_record.priority then
|
||||
break;
|
||||
end
|
||||
table.insert(priority_bucket, curr_record);
|
||||
bucket_total_weight = bucket_total_weight + curr_record.weight;
|
||||
bucket_len = bucket_len + 1;
|
||||
rrset_pos = rrset_pos + 1;
|
||||
until rrset_pos > rr_count;
|
||||
end
|
||||
|
||||
bucket_used = bucket_used + 1;
|
||||
local n, running_total = math.random(0, bucket_total_weight), 0;
|
||||
local target_record;
|
||||
for i = 1, bucket_len do
|
||||
local candidate = priority_bucket[i];
|
||||
if candidate then
|
||||
running_total = running_total + candidate.weight;
|
||||
if running_total >= n then
|
||||
target_record = candidate;
|
||||
bucket_total_weight = bucket_total_weight - candidate.weight;
|
||||
priority_bucket[i] = nil;
|
||||
break;
|
||||
end
|
||||
end
|
||||
end
|
||||
return target_record;
|
||||
end;
|
||||
end
|
||||
|
||||
-- Find the next target to connect to, and
|
||||
-- pass it to cb()
|
||||
function methods:next(cb)
|
||||
if self.targets then
|
||||
if not self.resolver then
|
||||
if #self.targets == 0 then
|
||||
if self.resolver or self._get_next_target then
|
||||
if not self.resolver then -- Do we have a basic resolver currently?
|
||||
-- We don't, so fetch a new SRV target, create a new basic resolver for it
|
||||
local next_srv_target = self._get_next_target and self._get_next_target();
|
||||
if not next_srv_target then
|
||||
-- No more SRV targets left
|
||||
cb(nil);
|
||||
return;
|
||||
end
|
||||
local next_target = table.remove(self.targets, 1);
|
||||
self.resolver = basic.new(unpack(next_target, 1, 4));
|
||||
-- Create a new basic resolver for this SRV target
|
||||
self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra);
|
||||
end
|
||||
-- Look up the next (basic) target from the current target's resolver
|
||||
self.resolver:next(function (...)
|
||||
if self.resolver then
|
||||
self.last_error = self.resolver.last_error;
|
||||
|
@ -31,6 +86,9 @@ function methods:next(cb)
|
|||
end
|
||||
end);
|
||||
return;
|
||||
elseif self.in_progress then
|
||||
cb(nil);
|
||||
return;
|
||||
end
|
||||
|
||||
if not self.hostname then
|
||||
|
@ -39,9 +97,9 @@ function methods:next(cb)
|
|||
return;
|
||||
end
|
||||
|
||||
local targets = {};
|
||||
self.in_progress = true;
|
||||
|
||||
local function ready()
|
||||
self.targets = targets;
|
||||
self:next(cb);
|
||||
end
|
||||
|
||||
|
@ -63,7 +121,7 @@ function methods:next(cb)
|
|||
|
||||
if #answer == 0 then
|
||||
if self.extra and self.extra.default_port then
|
||||
table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra });
|
||||
self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra);
|
||||
else
|
||||
self.last_error = "zero SRV records found";
|
||||
end
|
||||
|
@ -77,10 +135,7 @@ function methods:next(cb)
|
|||
return;
|
||||
end
|
||||
|
||||
table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end);
|
||||
for _, record in ipairs(answer) do
|
||||
table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra });
|
||||
end
|
||||
self._get_next_target = new_target_selector(answer);
|
||||
else
|
||||
self.last_error = err;
|
||||
end
|
||||
|
|
241
spec/net_resolvers_service_spec.lua
Normal file
241
spec/net_resolvers_service_spec.lua
Normal file
|
@ -0,0 +1,241 @@
|
|||
local set = require "util.set";
|
||||
|
||||
insulate("net.resolvers.service", function ()
|
||||
local adns = {
|
||||
resolver = function ()
|
||||
return {
|
||||
lookup = function (_, cb, qname, qtype, qclass)
|
||||
if qname == "_xmpp-server._tcp.example.com"
|
||||
and (qtype or "SRV") == "SRV"
|
||||
and (qclass or "IN") == "IN" then
|
||||
cb({
|
||||
{ -- 60+35+60
|
||||
srv = { target = "xmpp0-a.example.com", port = 5228, priority = 0, weight = 60 };
|
||||
};
|
||||
{
|
||||
srv = { target = "xmpp0-b.example.com", port = 5216, priority = 0, weight = 35 };
|
||||
};
|
||||
{
|
||||
srv = { target = "xmpp0-c.example.com", port = 5200, priority = 0, weight = 0 };
|
||||
};
|
||||
{
|
||||
srv = { target = "xmpp0-d.example.com", port = 5256, priority = 0, weight = 120 };
|
||||
};
|
||||
|
||||
{
|
||||
srv = { target = "xmpp1-a.example.com", port = 5273, priority = 1, weight = 30 };
|
||||
};
|
||||
{
|
||||
srv = { target = "xmpp1-b.example.com", port = 5274, priority = 1, weight = 30 };
|
||||
};
|
||||
|
||||
{
|
||||
srv = { target = "xmpp2.example.com", port = 5275, priority = 2, weight = 0 };
|
||||
};
|
||||
});
|
||||
elseif qname == "_xmpp-server._tcp.single.example.com"
|
||||
and (qtype or "SRV") == "SRV"
|
||||
and (qclass or "IN") == "IN" then
|
||||
cb({
|
||||
{
|
||||
srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 };
|
||||
};
|
||||
});
|
||||
elseif qname == "_xmpp-server._tcp.half.example.com"
|
||||
and (qtype or "SRV") == "SRV"
|
||||
and (qclass or "IN") == "IN" then
|
||||
cb({
|
||||
{
|
||||
srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 };
|
||||
};
|
||||
{
|
||||
srv = { target = "xmpp0-b.example.com", port = 5270, priority = 0, weight = 1 };
|
||||
};
|
||||
});
|
||||
elseif qtype == "A" then
|
||||
local l = qname:match("%-(%a)%.example.com$") or "1";
|
||||
local d = ("%d"):format(l:byte())
|
||||
cb({
|
||||
{
|
||||
a = "127.0.0."..d;
|
||||
};
|
||||
});
|
||||
elseif qtype == "AAAA" then
|
||||
local l = qname:match("%-(%a)%.example.com$") or "1";
|
||||
local d = ("%04d"):format(l:byte())
|
||||
cb({
|
||||
{
|
||||
aaaa = "fdeb:9619:649e:c7d9::"..d;
|
||||
};
|
||||
});
|
||||
else
|
||||
cb(nil);
|
||||
end
|
||||
end;
|
||||
};
|
||||
end;
|
||||
};
|
||||
package.loaded["net.adns"] = mock(adns);
|
||||
local resolver = require "net.resolvers.service";
|
||||
math.randomseed(os.time());
|
||||
it("works for 99% of deployments", function ()
|
||||
-- Most deployments only have a single SRV record, let's make
|
||||
-- sure that works okay
|
||||
|
||||
local expected_targets = set.new({
|
||||
-- xmpp0-a
|
||||
"tcp4 127.0.0.97 5269";
|
||||
"tcp6 fdeb:9619:649e:c7d9::0097 5269";
|
||||
});
|
||||
local received_targets = set.new({});
|
||||
|
||||
local r = resolver.new("single.example.com", "xmpp-server");
|
||||
local done = false;
|
||||
local function handle_target(...)
|
||||
if ... == nil then
|
||||
done = true;
|
||||
-- No more targets
|
||||
return;
|
||||
end
|
||||
received_targets:add(table.concat({ ... }, " ", 1, 3));
|
||||
end
|
||||
r:next(handle_target);
|
||||
while not done do
|
||||
r:next(handle_target);
|
||||
end
|
||||
|
||||
-- We should have received all expected targets, and no unexpected
|
||||
-- ones:
|
||||
assert.truthy(set.xor(received_targets, expected_targets):empty());
|
||||
end);
|
||||
|
||||
it("supports A/AAAA fallback", function ()
|
||||
-- Many deployments don't have any SRV records, so we should
|
||||
-- fall back to A/AAAA records instead when that is the case
|
||||
|
||||
local expected_targets = set.new({
|
||||
-- xmpp0-a
|
||||
"tcp4 127.0.0.97 5269";
|
||||
"tcp6 fdeb:9619:649e:c7d9::0097 5269";
|
||||
});
|
||||
local received_targets = set.new({});
|
||||
|
||||
local r = resolver.new("xmpp0-a.example.com", "xmpp-server", "tcp", { default_port = 5269 });
|
||||
local done = false;
|
||||
local function handle_target(...)
|
||||
if ... == nil then
|
||||
done = true;
|
||||
-- No more targets
|
||||
return;
|
||||
end
|
||||
received_targets:add(table.concat({ ... }, " ", 1, 3));
|
||||
end
|
||||
r:next(handle_target);
|
||||
while not done do
|
||||
r:next(handle_target);
|
||||
end
|
||||
|
||||
-- We should have received all expected targets, and no unexpected
|
||||
-- ones:
|
||||
assert.truthy(set.xor(received_targets, expected_targets):empty());
|
||||
end);
|
||||
|
||||
|
||||
it("works", function ()
|
||||
local expected_targets = set.new({
|
||||
-- xmpp0-a
|
||||
"tcp4 127.0.0.97 5228";
|
||||
"tcp6 fdeb:9619:649e:c7d9::0097 5228";
|
||||
"tcp4 127.0.0.97 5273";
|
||||
"tcp6 fdeb:9619:649e:c7d9::0097 5273";
|
||||
|
||||
-- xmpp0-b
|
||||
"tcp4 127.0.0.98 5274";
|
||||
"tcp6 fdeb:9619:649e:c7d9::0098 5274";
|
||||
"tcp4 127.0.0.98 5216";
|
||||
"tcp6 fdeb:9619:649e:c7d9::0098 5216";
|
||||
|
||||
-- xmpp0-c
|
||||
"tcp4 127.0.0.99 5200";
|
||||
"tcp6 fdeb:9619:649e:c7d9::0099 5200";
|
||||
|
||||
-- xmpp0-d
|
||||
"tcp4 127.0.0.100 5256";
|
||||
"tcp6 fdeb:9619:649e:c7d9::0100 5256";
|
||||
|
||||
-- xmpp2
|
||||
"tcp4 127.0.0.49 5275";
|
||||
"tcp6 fdeb:9619:649e:c7d9::0049 5275";
|
||||
|
||||
});
|
||||
local received_targets = set.new({});
|
||||
|
||||
local r = resolver.new("example.com", "xmpp-server");
|
||||
local done = false;
|
||||
local function handle_target(...)
|
||||
if ... == nil then
|
||||
done = true;
|
||||
-- No more targets
|
||||
return;
|
||||
end
|
||||
received_targets:add(table.concat({ ... }, " ", 1, 3));
|
||||
end
|
||||
r:next(handle_target);
|
||||
while not done do
|
||||
r:next(handle_target);
|
||||
end
|
||||
|
||||
-- We should have received all expected targets, and no unexpected
|
||||
-- ones:
|
||||
assert.truthy(set.xor(received_targets, expected_targets):empty());
|
||||
end);
|
||||
|
||||
it("balances across weights correctly #slow", function ()
|
||||
-- This mimics many repeated connections to 'example.com' (mock
|
||||
-- records defined above), and records the port number of the
|
||||
-- first target. Therefore it (should) only return priority
|
||||
-- 0 records, and the input data is constructed such that the
|
||||
-- last two digits of the port number represent the percentage
|
||||
-- of times that record should (on average) be picked first.
|
||||
|
||||
-- To prevent random test failures, we test across a handful
|
||||
-- of fixed (randomly selected) seeds.
|
||||
for _, seed in ipairs({ 8401877, 3943829, 7830992 }) do
|
||||
math.randomseed(seed);
|
||||
|
||||
local results = {};
|
||||
local function run()
|
||||
local run_results = {};
|
||||
local r = resolver.new("example.com", "xmpp-server");
|
||||
local function record_target(...)
|
||||
if ... == nil then
|
||||
-- No more targets
|
||||
return;
|
||||
end
|
||||
run_results = { ... };
|
||||
end
|
||||
r:next(record_target);
|
||||
return run_results[3];
|
||||
end
|
||||
|
||||
for _ = 1, 1000 do
|
||||
local port = run();
|
||||
results[port] = (results[port] or 0) + 1;
|
||||
end
|
||||
|
||||
local ports = {};
|
||||
for port in pairs(results) do
|
||||
table.insert(ports, port);
|
||||
end
|
||||
table.sort(ports);
|
||||
for _, port in ipairs(ports) do
|
||||
--print("PORT", port, tostring((results[port]/1000) * 100).."% hits (expected "..tostring(port-5200).."%)");
|
||||
local hit_pct = (results[port]/1000) * 100;
|
||||
local expected_pct = port - 5200;
|
||||
--print(hit_pct, expected_pct, math.abs(hit_pct - expected_pct));
|
||||
assert.is_true(math.abs(hit_pct - expected_pct) < 5);
|
||||
end
|
||||
--print("---");
|
||||
end
|
||||
end);
|
||||
end);
|
Loading…
Add table
Add a link
Reference in a new issue