mirror of
https://github.com/bjc/prosody.git
synced 2025-04-04 13:47:41 +03:00
919 lines
26 KiB
Lua
919 lines
26 KiB
Lua
-- Prosody IM
|
|
-- This file is included with Prosody IM. It has modifications,
|
|
-- which are hereby placed in the public domain.
|
|
|
|
-- public domain 20080404 lua@ztact.com
|
|
|
|
|
|
-- todo: quick (default) header generation
|
|
-- todo: nxdomain, error handling
|
|
-- todo: cache results of encodeName
|
|
|
|
|
|
-- reference: http://tools.ietf.org/html/rfc1035
|
|
-- reference: http://tools.ietf.org/html/rfc1876 (LOC)
|
|
|
|
|
|
require 'socket'
|
|
local ztact = require 'util.ztact'
|
|
local require = require
|
|
local os = os;
|
|
|
|
local coroutine, io, math, socket, string, table =
|
|
coroutine, io, math, socket, string, table
|
|
|
|
local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack =
|
|
ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack
|
|
|
|
local get, set = ztact.get, ztact.set
|
|
|
|
|
|
-------------------------------------------------- module dns
|
|
module ('dns')
|
|
local dns = _M;
|
|
|
|
|
|
-- dns type & class codes ------------------------------ dns type & class codes
|
|
|
|
|
|
local append = table.insert
|
|
|
|
|
|
-- Return the high-order byte of a 16-bit value, i.e. the value with its
-- low byte stripped, divided by 256.
local function highbyte (i)
  local low = i % 0x100
  return (i - low) / 0x100
end
|
|
|
|
|
|
-- Build a lookup table from a `code -> mnemonic` list.  The result maps the
-- original key, the mnemonic itself and the lower-cased mnemonic all to the
-- canonical mnemonic, so callers may look a value up by any of the three.
local function augment (t)
  local lookup = {}
  for code, mnemonic in pairs (t) do
    lookup[code] = mnemonic
    lookup[mnemonic] = mnemonic
    lookup[string.lower (mnemonic)] = mnemonic
  end
  return lookup
end
|
|
|
|
|
|
-- Build a table of wire encodings from a `code -> mnemonic` list.  Each
-- numeric code is packed as a two-byte big-endian string, reachable through
-- the code, the mnemonic and the lower-cased mnemonic.
local function encode (t)
  local packed = {}
  for number, mnemonic in pairs (t) do
    local bytes = string.char (highbyte (number), number % 0x100)
    packed[number] = bytes
    packed[mnemonic] = bytes
    packed[string.lower (mnemonic)] = bytes
  end
  return packed
end
|
|
|
|
|
|
-- Resource record type mnemonics, keyed by numeric type code.
dns.types = {
  [  1] = 'A',     [  2] = 'NS',    [  3] = 'MD',    [  4] = 'MF',
  [  5] = 'CNAME', [  6] = 'SOA',   [  7] = 'MB',    [  8] = 'MG',
  [  9] = 'MR',    [ 10] = 'NULL',  [ 11] = 'WKS',   [ 12] = 'PTR',
  [ 13] = 'HINFO', [ 14] = 'MINFO', [ 15] = 'MX',    [ 16] = 'TXT',
  [ 28] = 'AAAA',  [ 29] = 'LOC',   [ 33] = 'SRV',
  [252] = 'AXFR',  [253] = 'MAILB', [254] = 'MAILA', [255] = '*',
}

-- Class mnemonics, keyed by numeric class code.
dns.classes = {
  [  1] = 'IN', [  2] = 'CS', [  3] = 'CH', [  4] = 'HS',
  [255] = '*',
}

-- code / mnemonic / lower-case mnemonic -> canonical mnemonic
dns.type = augment (dns.types)
dns.class = augment (dns.classes)

-- code / mnemonic / lower-case mnemonic -> two-byte wire encoding
dns.typecode = encode (dns.types)
dns.classcode = encode (dns.classes)
|
|
|
|
|
|
|
|
-- Normalise the arguments of a query: the name is given a trailing dot (if
-- it lacks one) and lower-cased; type and class are mapped to their
-- canonical mnemonics, defaulting to 'A' and 'IN'.  An unknown type or
-- class yields nil in that position.
local function standardize (qname, qtype, qclass)
  if qname:sub (-1) ~= '.' then
    qname = qname .. '.'
  end
  local canonical_type = dns.type[qtype or 'A']
  local canonical_class = dns.class[qclass or 'IN']
  return qname:lower (), canonical_type, canonical_class
end
|
|
|
|
|
|
-- Remove stale records from the list `rrs`, in place.
--
--   rrs   array of resource records
--   time  current time; defaults to socket.gettime ()
--   soft  when equal to 'soft', records without a time of death (`tod`)
--         are removed as well; such records are expected to have ttl == 0
--
-- Records carrying `tod` get their `ttl` refreshed from `tod - time` and
-- are removed once that reaches zero or less.  Returns nothing.
--
-- The list is walked from the end so that table.remove () never shifts an
-- entry that has not been visited yet, and the list stays a proper
-- sequence (no nil holes) in both modes.
local function prune (rrs, time, soft)
  time = time or socket.gettime ()

  for i = #rrs, 1, -1 do
    local rr = rrs[i]
    if rr then
      if rr.tod then
        -- rr.tod = rr.tod - 50    -- accelerated decrepitude
        rr.ttl = math.floor (rr.tod - time)
        if rr.ttl <= 0 then table.remove (rrs, i) end
      elseif soft == 'soft' then
        assert (rr.ttl == 0)
        table.remove (rrs, i)
      end
    end
  end
end
|
|
|
|
|
|
-- metatables & co. ------------------------------------------ metatables & co.
|
|
|
|
|
|
-- Method table for resolver objects.  It is its own __index, so a table
-- whose metatable is `resolver` finds its methods here.
local resolver = {}
resolver.__index = resolver

-- Forward declaration, so that functions defined before the body of
-- SRV_tostring can already refer to it.
local SRV_tostring
|
|
|
|
|
|
local rr_metatable = {}

-- Format one resource record as a single line: class, type, ttl and name in
-- fixed-width columns, followed by a type-specific rendering of the record
-- data.  Types without a renderer get the text ' <UNKNOWN RDATA TYPE>'.
function rr_metatable.__tostring (rr)
  local head = string.format (
    '%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name )
  local kind = rr.type

  if kind == 'A'     then return head .. ' ' .. rr.a end
  if kind == 'MX'    then
    return head .. string.format (' %2i %s', rr.pref, rr.mx)
  end
  if kind == 'CNAME' then return head .. ' ' .. rr.cname end
  if kind == 'LOC'   then return head .. ' ' .. resolver.LOC_tostring (rr) end
  if kind == 'NS'    then return head .. ' ' .. rr.ns end
  if kind == 'SRV'   then return head .. ' ' .. SRV_tostring (rr) end
  if kind == 'TXT'   then return head .. ' ' .. rr.txt end

  return head .. ' <UNKNOWN RDATA TYPE>'
end
|
|
|
|
|
|
local rrs_metatable = {}

-- Render a record list as text: tostring () of every entry, each followed
-- by a newline.  An empty list renders as the empty string.
function rrs_metatable.__tostring (rrs)
  local lines = {}
  for _, rr in pairs (rrs) do
    lines[#lines + 1] = tostring (rr) .. '\n'
  end
  return table.concat (lines)
end
|
|
|
|
|
|
local cache_metatable = {}

-- Render the whole cache as text.  The cache is nested as
-- class -> type -> name -> record list; every list is pruned against the
-- current time before being converted with tostring ().
function cache_metatable.__tostring (cache)
  local now = socket.gettime ()
  local chunks = {}
  for _, by_type in pairs (cache) do
    for _, by_name in pairs (by_type) do
      for _, rrs in pairs (by_name) do
        prune (rrs, now)
        chunks[#chunks + 1] = tostring (rrs)
      end
    end
  end
  return table.concat (chunks)
end
|
|
|
|
|
|
-- Create a new resolver object.
--
-- The object starts with no active queries, an empty cache (printable
-- through cache_metatable), a weak `unsorted` set, empty `wanted` and
-- `yielded` tables and `best_server` set to 1.  The last three are read
-- by query (), receive (), feed () and servfail (); without them a
-- resolver built here could not send a query.
function resolver:new ()
  local r = {
    active = {},
    cache = setmetatable ({}, cache_metatable),
    unsorted = setmetatable ({}, { __mode = 'kv' }),
    wanted = {},
    yielded = {},
    best_server = 1,
  }
  return setmetatable (r, resolver)
end
|
|
|
|
|
|
-- packet layer -------------------------------------------------- packet layer
|
|
|
|
|
|
-- First-use wrapper around math.random: seeds the generator from the
-- current time, then replaces itself with math.random so that later calls
-- go straight to it.
function dns.random (...)
  math.randomseed (10000 * socket.gettime ())
  local random = math.random
  dns.random = random
  return random (...)
end
|
|
|
|
|
|
-- Encode a 12-byte DNS message header.
--
-- `o` is an optional table of header fields; missing fields are filled in
-- (and written back into `o`):
--   id       16b  random when absent
--   rd        1b  1  recursion desired
--   tc        1b  0  truncated response
--   aa        1b  0  authoritative response
--   opcode    4b  0  (0 query, 1 inverse query, 2 server status request,
--                     3-15 reserved)
--   qr        1b  0  (0 query, 1 response)
--   rcode     4b  0  (0 no error, 1 format error, 2 server failure,
--                     3 name error, 4 not implemented, 5 refused,
--                     6-15 reserved)
--   z         3b  0  reserved
--   ra        1b  0  recursion available
--   qdcount  16b  1  number of question RRs
--   ancount  16b  0  number of answer RRs
--   nscount  16b  0  number of nameserver RRs
--   arcount  16b  0  number of additional RRs
--
-- Returns the encoded header string and the id that was used.
local function encodeHeader (o)
  o = o or {}
  o.id = o.id or dns.random (0, 0xffff)

  local defaults = {
    rd = 1, tc = 0, aa = 0, opcode = 0, qr = 0,
    rcode = 0, z = 0, ra = 0,
    qdcount = 1, ancount = 0, nscount = 0, arcount = 0,
  }
  for field, value in pairs (defaults) do
    o[field] = o[field] or value
  end

  -- pack a 16-bit value as two bytes, most significant first
  local function be16 (v)
    return string.char (highbyte (v), v % 0x100)
  end

  local flags_hi = o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr
  local flags_lo = o.rcode + 16*o.z + 128*o.ra

  local header = be16 (o.id)
    .. string.char (flags_hi, flags_lo)
    .. be16 (o.qdcount)
    .. be16 (o.ancount)
    .. be16 (o.nscount)
    .. be16 (o.arcount)

  return header, o.id
end
|
|
|
|
|
|
-- Encode a dotted domain name in wire format: every non-empty label is
-- emitted as a length byte followed by the label text, and the name is
-- terminated by a zero byte.
local function encodeName (name)
  local pieces = {}
  for label in name:gmatch ('[^.]+') do
    pieces[#pieces + 1] = string.char (#label)
    pieces[#pieces + 1] = label
  end
  pieces[#pieces + 1] = string.char (0)
  return table.concat (pieces)
end
|
|
|
|
|
|
-- Encode one question entry: the wire-format name followed by the two-byte
-- type and class codes.  Type defaults to 'a' and class to 'in'.
local function encodeQuestion (qname, qtype, qclass)
  local name = encodeName (qname)
  local typecode = dns.typecode[qtype or 'a']
  local classcode = dns.classcode[qclass or 'in']
  return name .. typecode .. classcode
end
|
|
|
|
|
|
-- Read `len` bytes (default 1) from the packet at the current offset,
-- advance the offset past them and return them as separate numbers.
-- Raises an error when the read would run past the end of the packet.
function resolver:byte (len)
  local count = len or 1
  local first = self.offset
  local last = first + count - 1
  local size = #self.packet
  if last > size then
    error (string.format ('out of bounds: %i>%i', last, size))
  end
  self.offset = first + count
  return string.byte (self.packet, first, last)
end
|
|
|
|
|
|
-- Read a big-endian 16-bit unsigned integer from the packet.
function resolver:word ()
  local hi, lo = self:byte (2)
  return hi * 0x100 + lo
end
|
|
|
|
|
|
-- Read a big-endian 32-bit unsigned integer from the packet.
function resolver:dword ()
  local b1, b2, b3, b4 = self:byte (4)
  return ((b1 * 0x100 + b2) * 0x100 + b3) * 0x100 + b4
end
|
|
|
|
|
|
-- Return the next `len` bytes (default 1) of the packet as a string and
-- advance the offset by `len`.  Unlike byte (), no bounds check is made: a
-- read past the end yields a shorter string.
function resolver:sub (len)
  local count = len or 1
  local start = self.offset
  local chunk = string.sub (self.packet, start, start + count - 1)
  self.offset = start + count
  return chunk
end
|
|
|
|
|
|
-- Decode the message header at the current offset.
--
-- Returns nil right after reading the id when that id has no entry in
-- self.active, unless `force` is set.  Otherwise returns a table with the
-- fields id, rd, tc, aa, opcode, qr, rcode, z, ra, qdcount, ancount,
-- nscount and arcount, all as whole numbers.
function resolver:header (force)
  local id = self:word ()
  if not (self.active[id] or force) then return nil end

  -- drop the fractional part left over by the divisions below
  local function whole (v) return v - v % 1 end

  local flags_hi, flags_lo = self:byte (2)
  local qdcount = self:word ()
  local ancount = self:word ()
  local nscount = self:word ()
  local arcount = self:word ()

  return {
    id      = whole (id),

    rd      = whole (flags_hi % 2),
    tc      = whole (flags_hi / 2 % 2),
    aa      = whole (flags_hi / 4 % 2),
    opcode  = whole (flags_hi / 8 % 16),
    qr      = whole (flags_hi / 128),

    rcode   = whole (flags_lo % 16),
    z       = whole (flags_lo / 16 % 8),
    ra      = whole (flags_lo / 128),

    qdcount = whole (qdcount),
    ancount = whole (ancount),
    nscount = whole (nscount),
    arcount = whole (arcount),
  }
end
|
|
|
|
|
|
-- Decode a domain name at the current offset and return it in dotted form
-- (every label is followed by a '.').
--
-- A length byte of 0xc0 or more introduces a compression pointer: decoding
-- continues at the offset it names, and once the name is complete the
-- packet offset is restored to just after the first pointer.  Following 20
-- pointers in one name raises an error.
function resolver:name ()
  local resume_at           -- offset to return to after following pointers
  local jumps = 0
  local labels = {}

  local len = self:byte ()
  while len > 0 do
    if len < 0xc0 then      -- ordinary label
      labels[#labels + 1] = self:sub (len) .. '.'
    else                    -- compression pointer
      jumps = jumps + 1
      if jumps >= 20 then error ('dns error: 20 pointers') end
      local target = ((len - 0xc0) * 0x100) + self:byte ()
      resume_at = resume_at or self.offset
      self.offset = target + 1    -- packet offsets are 0-based, Lua is 1-based
    end
    len = self:byte ()
  end

  self.offset = resume_at or self.offset
  return table.concat (labels)
end
|
|
|
|
|
|
-- Decode one question entry: returns a table with `name`, `type` and
-- `class`, the latter two translated to mnemonics (nil when unknown).
function resolver:question ()
  local name = self:name ()
  local qtype = dns.type[self:word ()]
  local qclass = dns.class[self:word ()]
  return { name = name, type = qtype, class = qclass }
end
|
|
|
|
|
|
-- Parse A record data: four bytes, stored in rr.a as a dotted-quad string.
function resolver:A (rr)
  rr.a = string.format ('%i.%i.%i.%i', self:byte (4))
end
|
|
|
|
|
|
-- Parse CNAME record data: a single domain name, stored in rr.cname.
function resolver:CNAME (rr)
  local canonical = self:name ()
  rr.cname = canonical
end
|
|
|
|
|
|
-- Parse MX record data: a 16-bit preference (rr.pref) followed by the
-- exchange host name (rr.mx).
function resolver:MX (rr)
  local preference = self:word ()
  rr.pref = preference
  local exchange = self:name ()
  rr.mx = exchange
end
|
|
|
|
|
|
-- Read one byte and interpret it as two nibbles: the result is the high
-- nibble multiplied by ten to the power of the low nibble.
function resolver:LOC_nibble_power ()
  local b = self:byte ()
  local exponent = b % 0x10
  local mantissa = (b - exponent) / 0x10
  return mantissa * (10 ^ exponent)
end
|
|
|
|
|
|
-- Parse LOC record data.  The version byte is always stored in rr.version;
-- only version 0 is decoded further, filling rr.loc with size, horiz_pre,
-- vert_pre (nibble-power encoded) and latitude, longitude, altitude
-- (32-bit values).
function resolver:LOC (rr)
  rr.version = self:byte ()
  if rr.version ~= 0 then return end

  local loc = rr.loc or {}
  rr.loc = loc

  loc.size      = self:LOC_nibble_power ()
  loc.horiz_pre = self:LOC_nibble_power ()
  loc.vert_pre  = self:LOC_nibble_power ()
  loc.latitude  = self:dword ()
  loc.longitude = self:dword ()
  loc.altitude  = self:dword ()
end
|
|
|
|
|
|
-- Format a LOC latitude/longitude value as "deg min sec hemisphere".
--
--   f    raw 32-bit value; 0x80000000 is the zero point and one unit is a
--        thousandth of an arc second
--   pos  letter used for values at or above the zero point
--   neg  letter used for values below it
local function LOC_tostring_degrees (f, pos, neg)
  local offset = f - 0x80000000
  local hemisphere = pos
  if offset < 0 then
    hemisphere, offset = neg, -offset
  end

  local msec = offset % 60000
  local whole_minutes = (offset - msec) / 60000
  local min = whole_minutes % 60
  local deg = (whole_minutes - min) / 60

  return string.format ('%3d %2d %2.3f %s', deg, min, msec / 1000, hemisphere)
end
|
|
|
|
|
|
-- Render the decoded LOC data of `rr` as text: latitude, longitude, then
-- altitude, size, horizontal precision and vertical precision in metres
-- (the stored values are divided by 100; altitude is first reduced by
-- 10000000).
function resolver.LOC_tostring (rr)
  local loc = rr.loc
  local latitude = LOC_tostring_degrees (loc.latitude, 'N', 'S')
  local longitude = LOC_tostring_degrees (loc.longitude, 'E', 'W')

  return string.format (
    '%s %s %.2fm %.2fm %.2fm %.2fm',
    latitude,
    longitude,
    (loc.altitude - 10000000) / 100,
    loc.size / 100,
    loc.horiz_pre / 100,
    loc.vert_pre / 100 )
end
|
|
|
|
|
|
-- Parse NS record data: a single domain name, stored in rr.ns.
function resolver:NS (rr)
  local nameserver = self:name ()
  rr.ns = nameserver
end
|
|
|
|
|
|
-- SOA record data is not decoded: this parser is intentionally a no-op, so
-- an SOA record only carries the generic fields and its raw rdata.
function resolver:SOA (rr)
  return
end
|
|
|
|
|
|
-- Parse SRV record data into rr.srv: three 16-bit values (priority, weight,
-- port) followed by the target host name.
function resolver:SRV (rr)
  local srv = {}
  rr.srv = srv
  srv.priority = self:word ()
  srv.weight   = self:word ()
  srv.port     = self:word ()
  srv.target   = self:name ()
end
|
|
|
|
|
|
-- Render the SRV data of `rr` as "priority weight port target", the three
-- numbers right-aligned in five-character columns.
function SRV_tostring (rr)
  local srv = rr.srv
  local layout = '%5d %5d %5d %s'
  return layout:format (srv.priority, srv.weight, srv.port, srv.target)
end
|
|
|
|
|
|
-- Parse TXT record data into rr.txt.
--
-- TXT RDATA is a sequence of <character-string>s, each a length octet
-- followed by that many octets of text (RFC 1035, 3.3.14).  The strings
-- are concatenated without their length octets; previously the raw RDATA,
-- length octets included, was stored as the text.
--
-- Reading never goes beyond this record's RDATA (rr.rdlength bytes) nor
-- beyond the end of the packet, so a malformed length octet yields a
-- shortened string rather than data from the following record.
function resolver:TXT (rr)
  local stop = self.offset + rr.rdlength    -- first offset past the RDATA
  local packet_end = #self.packet
  local pieces = {}

  while self.offset < stop and self.offset <= packet_end do
    local len = self:byte ()
    local available = stop - self.offset
    if len > available then len = available end
    pieces[#pieces + 1] = self:sub (len)
  end

  rr.txt = table.concat (pieces)
end
|
|
|
|
|
|
-- Decode one resource record at the current offset.
--
-- The returned table (with rr_metatable) has name, type, class, ttl,
-- rdlength and rdata; `tod` (time of death, self.time + ttl) is set only
-- when ttl is non-zero.  If the resolver has a method named after the
-- record type it is run over the record data to add type-specific fields;
-- the offset is then rewound so that the raw data can be captured too.
function resolver:rr ()
  local rr = setmetatable ({}, rr_metatable)

  rr.name = self:name ()
  rr.type = dns.type[self:word ()]
  rr.class = dns.class[self:word ()]

  local ttl_hi = self:word ()
  local ttl_lo = self:word ()
  rr.ttl = 0x10000 * ttl_hi + ttl_lo
  rr.rdlength = self:word ()

  if rr.ttl ~= 0 then
    rr.tod = self.time + rr.ttl
  end

  local rdata_start = self.offset
  local parse = self[dns.type[rr.type]]
  if parse then parse (self, rr) end

  self.offset = rdata_start
  rr.rdata = self:sub (rr.rdlength)
  return rr
end
|
|
|
|
|
|
-- Decode `count` consecutive resource records and return them as an array.
function resolver:rrs (count)
  local records = {}
  for i = 1, count do
    records[i] = self:rr ()
  end
  return records
end
|
|
|
|
|
|
-- Decode a complete DNS message.
--
-- Returns a table with `header`, `question` (array of questions, plus the
-- undecoded question section in `question.raw`), `answer`, `authority` and
-- `additional`.  Unless `force` is set, nil is returned for any message
-- whose id and raw question do not match a query recorded in self.active.
function resolver:decode (packet, force)
  self.packet, self.offset = packet, 1

  local header = self:header (force)
  if not header then return nil end

  local questions = {}
  local response = { header = header, question = questions }

  local section_start = self.offset
  for _ = 1, header.qdcount do
    questions[#questions + 1] = self:question ()
  end
  questions.raw = string.sub (self.packet, section_start, self.offset - 1)

  if not force then
    local pending = self.active[header.id]
    if not (pending and pending[questions.raw]) then return nil end
  end

  response.answer     = self:rrs (header.ancount)
  response.authority  = self:rrs (header.nscount)
  response.additional = self:rrs (header.arcount)

  return response
end
|
|
|
|
|
|
-- socket layer -------------------------------------------------- socket layer
|
|
|
|
|
|
-- Retry schedule shared by all resolvers: entry n is the amount added to
-- socket.gettime () to obtain the retry deadline of the n-th round of
-- attempts; a query is dropped once its round exceeds the table length.
resolver.delays = { 1, 3 }
|
|
|
|
|
|
-- Append `address` to this resolver's list of nameservers, creating the
-- list on first use.
function resolver:addnameserver (address)
  local servers = self.server
  if not servers then
    servers = {}
    self.server = servers
  end
  append (servers, address)
end
|
|
|
|
|
|
-- Replace the nameserver list with the single entry `address`.
function resolver:setnameserver (address)
  self.server = {}    -- forget every server configured so far
  self:addnameserver (address)
end
|
|
|
|
|
|
-- Populate the nameserver list from the system configuration.
--
-- * If /etc/resolv.conf can be opened, every `nameserver` line carrying a
--   dotted-quad IPv4 address is added.  Comments introduced by '#' or ';'
--   are ignored.  The file is closed again once it has been read.
-- * Otherwise, if the WINDIR environment variable is set, two fixed public
--   resolver addresses are used.
-- * If no server is configured after that, 127.0.0.1 is used.
function resolver:adddefaultnameservers ()
  local resolv_conf = io.open ("/etc/resolv.conf")
  if resolv_conf then
    for line in resolv_conf:lines () do
      local address = line:gsub ("[#;].*$", "")
        :match ('^%s*nameserver%s+(%d+%.%d+%.%d+%.%d+)%s*$')
      if address then self:addnameserver (address) end
    end
    resolv_conf:close ()
  elseif os.getenv ("WINDIR") then
    self:addnameserver ("208.67.222.222")
    self:addnameserver ("208.67.220.220")
  end

  if not self.server or #self.server == 0 then
    self:addnameserver ("127.0.0.1")
  end
end
|
|
|
|
|
|
-- Return the UDP socket used to talk to nameserver number `servernum`,
-- creating it on first use.
--
-- A new socket is passed through self.socket_wrapper (when one is set),
-- made non-blocking, bound to any local port and connected to port 53 of
-- the server.  It is then recorded in both directions:
-- self.socket[servernum] = sock and self.socketset[sock] = servernum.
function resolver:getsocket (servernum)
  self.socket = self.socket or {}
  self.socketset = self.socketset or {}

  local existing = self.socket[servernum]
  if existing then return existing end

  local sock = socket.udp ()
  if self.socket_wrapper then
    sock = self.socket_wrapper (sock, self)
  end
  sock:settimeout (0)
  -- todo: attempt to use a random port, fallback to 0
  sock:setsockname ('*', 0)
  sock:setpeername (self.server[servernum], 53)

  self.socket[servernum] = sock
  self.socketset[sock] = servernum
  return sock
end
|
|
|
|
-- Forget a socket without closing it.  `sock` may be either a server
-- number (a key of self.socket) or a socket object (a key of
-- self.socketset); the matching entries are removed from both tables.
function resolver:voidsocket (sock)
  local by_number = self.socket[sock]
  if by_number then
    -- `sock` is a server number, `by_number` the socket object
    self.socketset[by_number] = nil
    self.socket[sock] = nil
    return
  end

  local number = self.socketset[sock]
  if number then
    -- `sock` is a socket object, `number` its server number
    self.socket[number] = nil
    self.socketset[sock] = nil
  end
end
|
|
|
|
-- Install `func` as the socket wrapper.  getsocket () calls it as
-- func (sock, resolver) for every socket it creates and uses the value it
-- returns in place of the raw socket.
function resolver:socket_wrapper_set (func)
  self.socket_wrapper = func
end
|
|
|
|
|
|
-- Close every socket this resolver has open and forget them.
--
-- Both bookkeeping tables are reset: self.socket (server number -> socket)
-- and self.socketset (socket -> server number).  Leaving closed sockets in
-- self.socketset would make receive () keep reading from them.  pairs ()
-- is used rather than ipairs () because voidsocket () can leave gaps in
-- self.socket, and a resolver that never opened a socket is handled too.
function resolver:closeall ()
  for _, sock in pairs (self.socket or {}) do
    sock:close ()
  end
  self.socket = {}
  self.socketset = {}
end
|
|
|
|
|
|
-- Store the record `rr` in the cache.
--
--   rr    decoded resource record (uses rr.class, rr.type, rr.name)
--   type  type of the question that produced the record
--
-- For a '*' question the record is filed under type '*'.  For any other
-- question it is filed under its own type and, if a '*' list already exists
-- for the same class and name, appended to that list as well.  Lists that
-- receive an MX record are flagged in self.unsorted.
--
-- The cache is created before the first lookup; previously the nil guard
-- ran only after get (self.cache, ...) had already been called.
function resolver:remember (rr, type)
  self.cache = self.cache or setmetatable ({}, cache_metatable)

  if type ~= '*' then
    type = rr.type
    local all = get (self.cache, rr.class, '*', rr.name)
    if all then append (all, rr) end
  end

  -- NOTE(review): relies on ztact.set returning the stored value — confirm
  local rrs = get (self.cache, rr.class, type, rr.name) or
    set (self.cache, rr.class, type, rr.name, setmetatable ({}, rrs_metatable))
  append (rrs, rr)

  if type == 'MX' then self.unsorted[rrs] = true end
end
|
|
|
|
|
|
-- Ordering function for MX records: lower preference first, ties broken by
-- exchange name.
local function comp_mx (a, b)
  if a.pref == b.pref then
    return a.mx < b.mx
  end
  return a.pref < b.pref
end
|
|
|
|
|
|
-- Look a query up in the cache without touching the network.
--
-- Returns the cached record list, or nil when nothing (or nothing that is
-- still valid) is cached.  Expired records are pruned first; a list that
-- ends up empty is removed from the cache.
--
-- A list flagged in self.unsorted is sorted with comp_mx and the flag is
-- then cleared, so the same list is not sorted again on every peek;
-- remember () sets the flag again whenever it appends an MX record.
function resolver:peek (qname, qtype, qclass)
  qname, qtype, qclass = standardize (qname, qtype, qclass)

  local rrs = get (self.cache, qclass, qtype, qname)
  if not rrs then return nil end

  local pruned = prune (rrs, socket.gettime ())
  if (pruned and qtype == '*') or not next (rrs) then
    set (self.cache, qclass, qtype, qname, nil)
    return nil
  end

  if self.unsorted[rrs] then
    table.sort (rrs, comp_mx)
    self.unsorted[rrs] = nil
  end
  return rrs
end
|
|
|
|
|
|
-- Empty the cache.
--
-- With soft == 'soft' every cached record list is pruned in soft mode
-- against the current time (which is also stored in self.time).  Any other
-- argument discards the whole cache; the replacement table keeps
-- cache_metatable so that tostring (cache) continues to work.
function resolver:purge (soft)
  if soft ~= 'soft' then
    self.cache = setmetatable ({}, cache_metatable)
    return
  end

  self.time = socket.gettime ()
  for _, by_type in pairs (self.cache or {}) do
    for _, by_name in pairs (by_type) do
      for _, rrs in pairs (by_name) do
        prune (rrs, self.time, 'soft')
      end
    end
  end
end
|
|
|
|
|
|
-- Start a query.
--
-- Returns the cached record list when the answer is already cached.
-- Otherwise the request is encoded, recorded in self.active under its id
-- and encoded question, the running coroutine (if any) is registered in
-- self.wanted, the packet is sent to the current best server, and nothing
-- is returned.  The default nameservers are loaded on first use.
function resolver:query (qname, qtype, qclass)
  qname, qtype, qclass = standardize (qname, qtype, qclass)

  if not self.server then self:adddefaultnameservers () end

  local question = encodeQuestion (qname, qtype, qclass)
  local cached = self:peek (qname, qtype, qclass)
  if cached then return cached end

  local header, id = encodeHeader ()
  local request = {
    packet = header .. question,
    server = self.best_server,
    delay  = 1,
    retry  = socket.gettime () + self.delays[1],
  }

  -- remember the query
  local pending = self.active[id]
  if not pending then
    pending = {}
    self.active[id] = pending
  end
  pending[question] = request

  -- remember which coroutine wants the answer
  local co = coroutine.running ()
  if co then
    set (self.wanted, qclass, qtype, qname, co, true)
  end

  self:getsocket (request.server):send (request.packet)
end
|
|
|
|
-- Handle the failure of a nameserver socket.
--
-- The socket is forgotten, and every active query that was addressed to
-- its server is moved to the next server (wrapping to the first) and
-- resent.  A query that has been moved as many times as there are servers
-- is dropped instead.  If the failed server was the preferred one,
-- self.best_server advances to the next server as well.
function resolver:servfail (sock)
  local failed = self.socketset[sock]

  -- Socket is dead now
  self:voidsocket (sock)

  self.time = socket.gettime ()
  local server_count = #self.server

  for _, queries in pairs (self.active) do
    for question, request in pairs (queries) do
      if request.server == failed then
        -- move on to the next server, wrapping around
        local nextserver = request.server + 1
        if nextserver > server_count then nextserver = 1 end
        request.server = nextserver

        request.retries = (request.retries or 0) + 1
        if request.retries >= server_count then
          queries[question] = nil    -- every server has been tried
        else
          local target = self:getsocket (request.server)
          if target then target:send (request.packet) end
        end
      end
    end
  end

  if failed == self.best_server then
    local best = self.best_server + 1
    if best > server_count then
      best = 1    -- exhausted all servers, try the first again
    end
    self.best_server = best
  end
end
|
|
|
|
-- Read at most one packet from each socket in `rset` (default: all of the
-- resolver's sockets) and process the responses.
--
-- For every packet that decodes to a response matching an active query:
--   * all records of the answer, authority and additional sections are
--     stored in the cache,
--   * the query is removed from self.active (all sockets are closed once
--     no query is left),
--   * every coroutine registered in self.wanted for the question is
--     resumed if it is suspended, and the registration is dropped.
--
-- The waiting coroutines are looked up with the first question of the
-- response.  The previous code passed the question list itself, whose
-- class, type and name are nil, so no coroutine was ever found.
--
-- Returns the result of the last decode attempt (nil if the last packet
-- was not accepted or no packet was read).
function resolver:receive (rset)
  self.time = socket.gettime ()
  rset = rset or self.socket

  local response
  for _, sock in pairs (rset) do
    local packet = self.socketset[sock] and sock:receive ()
    if packet then
      response = self:decode (packet)
      if response then
        local q = response.question[1]
        local qtype = q and q.type

        for _, section in pairs { 'answer', 'authority', 'additional' } do
          for _, rr in pairs (response[section]) do
            self:remember (rr, qtype)
          end
        end

        -- retire the query
        local id = response.header.id
        local queries = self.active[id]
        queries[response.question.raw] = nil
        if not next (queries) then self.active[id] = nil end
        if not next (self.active) then self:closeall () end

        -- was the query on the wanted list?
        if q then
          local cos = get (self.wanted, q.class, q.type, q.name)
          if cos then
            for co in pairs (cos) do
              set (self.yielded, co, q.class, q.type, q.name, nil)
              if coroutine.status (co) == "suspended" then
                coroutine.resume (co)
              end
            end
            set (self.wanted, q.class, q.type, q.name, nil)
          end
        end
      end
    end
  end

  return response
end
|
|
|
|
|
|
-- Process a packet that was read elsewhere (`sock` itself is not used).
--
-- If the packet decodes to a response matching an active query, its
-- records are cached, the query is retired (closing all sockets once none
-- is left) and the coroutines waiting for the first question are resumed.
-- Returns the decoded response, or nil if the packet was not accepted.
function resolver:feed (sock, packet)
  self.time = socket.gettime ()

  local response = self:decode (packet)
  if not response then return response end

  for _, section in pairs { 'answer', 'authority', 'additional' } do
    for _, rr in pairs (response[section]) do
      self:remember (rr, response.question[1].type)
    end
  end

  -- retire the query
  local id = response.header.id
  local raw = response.question.raw
  local queries = self.active[id]
  if queries[raw] then queries[raw] = nil end
  if not next (queries) then self.active[id] = nil end
  if not next (self.active) then self:closeall () end

  -- was the query on the wanted list?
  local q = response.question[1]
  local cos = q and get (self.wanted, q.class, q.type, q.name)
  if cos then
    for co in pairs (cos) do
      set (self.yielded, co, q.class, q.type, q.name, nil)
      if coroutine.status (co) == "suspended" then coroutine.resume (co) end
    end
    set (self.wanted, q.class, q.type, q.name, nil)
  end

  return response
end
|
|
|
|
-- Withdraw a coroutine from the wanted list.  The first three entries of
-- `data` are the keys under which it was registered in self.wanted and
-- data[4] is the coroutine itself.
function resolver:cancel (data)
  local key1, key2, key3 = unpack (data, 1, 3)
  local waiting = get (self.wanted, key1, key2, key3)
  if waiting then
    waiting[data[4]] = nil
  end
end
|
|
|
|
-- Drive outstanding queries: drain incoming responses, then retry or expire
-- queries whose retry time has passed.
--
-- A due query moves to the next server; after the last server it wraps to
-- the first one and advances to the next entry of self.delays.  When the
-- delays are exhausted the query is dropped, otherwise the packet is resent
-- (only if a socket for that server already exists) and a new retry time
-- is set.
--
-- Returns true while queries remain active, nil otherwise.
function resolver:pulse ()
  repeat until not self:receive ()
  if not next (self.active) then return nil end

  self.time = socket.gettime ()
  for id, queries in pairs (self.active) do
    for question, request in pairs (queries) do
      if self.time >= request.retry then

        request.server = request.server + 1
        if request.server > #self.server then
          request.server = 1
          request.delay = request.delay + 1
        end

        if request.delay > #self.delays then
          -- timed out: give up on this query
          queries[question] = nil
          if not next (queries) then self.active[id] = nil end
          if not next (self.active) then return nil end
        else
          local sock = self.socket[request.server]
          if sock then sock:send (request.packet) end
          request.retry = self.time + self.delays[request.delay]
        end
      end
    end
  end

  return next (self.active) and true or nil
end
|
|
|
|
|
|
-- Blocking lookup: send the query, keep pulsing (waiting on the resolver's
-- sockets between rounds) until no query is active any more, then return
-- whatever the cache holds for the question (nil if nothing).
function resolver:lookup (qname, qtype, qclass)
  self:query (qname, qtype, qclass)

  local busy = self:pulse ()
  while busy do
    socket.select (self.socket, nil, 4)
    busy = self:pulse ()
  end

  return self:peek (qname, qtype, qclass)
end
|
|
|
|
-- Non-blocking lookup: return the cached records if there are any,
-- otherwise start a query and return its result.  `handler` is accepted
-- but not used.
function resolver:lookupex (handler, qname, qtype, qclass)
  local cached = self:peek (qname, qtype, qclass)
  if cached then return cached end
  -- parentheses keep this at exactly one return value
  return (self:query (qname, qtype, qclass))
end
|
|
|
|
|
|
--print ---------------------------------------------------------------- print
|
|
|
|
|
|
-- Human-readable descriptions of header and record field values, used by
-- resolver.print ().  Each entry maps a field value to its description.
local hints = {
  qr     = { [0] = 'query', [1] = 'response' },
  opcode = { [0] = 'query', [1] = 'inverse query',
             [2] = 'server status request' },
  aa     = { [0] = 'non-authoritative', [1] = 'authoritative' },
  tc     = { [0] = 'complete', [1] = 'truncated' },
  rd     = { [0] = 'recursion not desired', [1] = 'recursion desired' },
  ra     = { [0] = 'recursion not available', [1] = 'recursion available' },
  z      = { [0] = '(reserved)' },
  rcode  = { [0] = 'no error', [1] = 'format error', [2] = 'server failure',
             [3] = 'name error', [4] = 'not implemented' },

  type   = dns.type,
  class  = dns.class,
}
|
|
|
|
|
|
-- Return the description of field `s` of table `p`, or '' when the field
-- or its current value has no description.
local function hint (p, s)
  local descriptions = hints[s]
  return (descriptions and descriptions[p[s]]) or ''
end
|
|
|
|
|
|
-- Dump a decoded response with print (): every header field with its
-- description, every question, and for each record of the answer,
-- authority and additional sections first the common fields and then all
-- remaining (type-specific) fields.
function resolver.print (response)
  local header_fields = {
    'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z',
    'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' }
  for _, field in pairs (header_fields) do
    print ( string.format ('%-30s', 'header.'..field),
            response.header[field], hint (response.header, field) )
  end

  for i, question in ipairs (response.question) do
    print (string.format ('question[%i].name         ', i), question.name)
    print (string.format ('question[%i].type         ', i), question.type)
    print (string.format ('question[%i].class        ', i), question.class)
  end

  local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 }
  local label
  for _, section in pairs {'answer', 'authority', 'additional'} do
    for i, rr in pairs (response[section]) do
      for _, field in pairs { 'name', 'type', 'class', 'ttl', 'rdlength' } do
        label = string.format ('%s[%i].%s', section, i, field)
        print (string.format ('%-30s', label), rr[field], hint (rr, field))
      end
      for key, value in pairs (rr) do
        if not common[key] then
          label = string.format ('%s[%i].%s', section, i, key)
          print (string.format ('%-30s  %s', tostring (label), tostring (value)))
        end
      end
    end
  end
end
|
|
|
|
|
|
-- module api ------------------------------------------------------ module api
|
|
|
|
|
|
-- Call `func` with the module's shared resolver (created on first use and
-- kept in dns._resolver) followed by the remaining arguments.
local function resolve (func, ...)
  local shared = dns._resolver
  if not shared then
    shared = dns.resolver ()
    dns._resolver = shared
  end
  return func (shared, ...)
end
|
|
|
|
|
|
-- Create a resolver object with empty query, cache and coroutine
-- bookkeeping tables and the first nameserver as the preferred one.
-- (This function seems to be redundant with resolver.new ().)
function dns.resolver ()
  local r = setmetatable ({}, resolver)

  r.active      = {}
  r.cache       = setmetatable ({}, cache_metatable)
  r.unsorted    = setmetatable ({}, { __mode = 'kv' })
  r.wanted      = {}
  r.yielded     = {}
  r.best_server = 1

  return r
end
|
|
|
|
|
|
-- Module-level convenience functions.  Each forwards its arguments to the
-- method of the same name on the module's shared resolver.

function dns.lookup (...)
  return resolve (resolver.lookup, ...)
end

function dns.purge (...)
  return resolve (resolver.purge, ...)
end

function dns.peek (...)
  return resolve (resolver.peek, ...)
end

function dns.query (...)
  return resolve (resolver.query, ...)
end

function dns.feed (...)
  return resolve (resolver.feed, ...)
end

function dns.cancel (...)
  return resolve (resolver.cancel, ...)
end

-- Declared with ':' so that it is called as dns:socket_wrapper_set (func);
-- the implicit self (the module table) is dropped and only the explicit
-- arguments are forwarded.
function dns:socket_wrapper_set (...)
  return resolve (resolver.socket_wrapper_set, ...)
end
|
|
|
|
|
|
return dns
|