mirror of
https://github.com/bjc/prosody.git
synced 2025-04-03 21:27:38 +03:00
There were 3 very similar methods: :execute(), :execute_query() and :execute_update(). The first one returns the prepared statement and is mainly used internally by the library for CREATE statements. The latter two differ only in how the results are returned, so they are now implemented as one main method plus a small wrapper that only picks out the iterator.
408 lines
14 KiB
Lua
408 lines
14 KiB
Lua
|
|
-- luacheck: ignore 113/unpack 211 212 411 213
|
|
local setmetatable, getmetatable = setmetatable, getmetatable;
|
|
local ipairs, unpack, select = ipairs, table.unpack or unpack, select;
|
|
local tonumber, tostring = tonumber, tostring;
|
|
local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
|
|
local error = error
|
|
local type = type
|
|
local t_concat = table.concat;
|
|
local t_insert = table.insert;
|
|
local s_char = string.char;
|
|
local array = require "prosody.util.array";
|
|
local log = require "prosody.util.logger".init("sql");
|
|
|
|
local lsqlite3 = require "lsqlite3";
|
|
local build_url = require "socket.url".build;
|
|
local ROW, DONE = lsqlite3.ROW, lsqlite3.DONE;
|
|
|
|
-- Error registry for SQLite result codes.
-- Codes and descriptions are from sqlite3.h, no copyright claimed.
-- FIXME xmpp error conditions?
local sqlite_errors;
do
	-- Each entry: { result code, error type, condition, description }
	local definitions = {
		{ 1, "modify", "ERROR", "Generic error" };
		{ 2, "cancel", "INTERNAL", "Internal logic error in SQLite" };
		{ 3, "auth", "PERM", "Access permission denied" };
		{ 4, "cancel", "ABORT", "Callback routine requested an abort" };
		{ 5, "wait", "BUSY", "The database file is locked" };
		{ 6, "wait", "LOCKED", "A table in the database is locked" };
		{ 7, "wait", "NOMEM", "A malloc() failed" };
		{ 8, "cancel", "READONLY", "Attempt to write a readonly database" };
		{ 9, "cancel", "INTERRUPT", "Operation terminated by sqlite3_interrupt()" };
		{ 10, "wait", "IOERR", "Some kind of disk I/O error occurred" };
		{ 11, "cancel", "CORRUPT", "The database disk image is malformed" };
		{ 12, "modify", "NOTFOUND", "Unknown opcode in sqlite3_file_control()" };
		{ 13, "wait", "FULL", "Insertion failed because database is full" };
		{ 14, "auth", "CANTOPEN", "Unable to open the database file" };
		{ 15, "cancel", "PROTOCOL", "Database lock protocol error" };
		{ 16, "continue", "EMPTY", "Internal use only" };
		{ 17, "modify", "SCHEMA", "The database schema changed" };
		{ 18, "modify", "TOOBIG", "String or BLOB exceeds size limit" };
		{ 19, "modify", "CONSTRAINT", "Abort due to constraint violation" };
		{ 20, "modify", "MISMATCH", "Data type mismatch" };
		{ 21, "modify", "MISUSE", "Library used incorrectly" };
		{ 22, "cancel", "NOLFS", "Uses OS features not supported on host" };
		{ 23, "auth", "AUTH", "Authorization denied" };
		{ 24, "modify", "FORMAT", "Not used" };
		{ 25, "modify", "RANGE", "2nd parameter to sqlite3_bind out of range" };
		{ 26, "cancel", "NOTADB", "File opened that is not a database file" };
		{ 27, "continue", "NOTICE", "Notifications from sqlite3_log()" };
		{ 28, "continue", "WARNING", "Warnings from sqlite3_log()" };
		{ 100, "continue", "ROW", "sqlite3_step() has another row ready" };
		{ 101, "continue", "DONE", "sqlite3_step() has finished executing" };
	};
	-- Expand the compact list into the keyed registry util.error expects,
	-- indexed by the numeric result code.
	local registry = {};
	for _, def in ipairs(definitions) do
		local code = def[1];
		registry[code] = { code = code; type = def[2]; condition = def[3]; text = def[4] };
	end
	sqlite_errors = require"prosody.util.error".init("util.sqlite3", registry);
end
|
|
|
|
-- Keep the builtin so the wrapper below can delegate to it.
local raw_assert = assert;

--- assert() variant for this module: the (cond, error) pair is passed
-- through sqlite_errors.coerce() before reaching the builtin assert, so a
-- falsy `cond` is presumably raised as a registry error object (see util.error).
-- @param cond value to check; returned when truthy
-- @param errno error code/message used when `err` is not given
-- @param err optional explicit error, preferred over `errno`
local function assert(cond, errno, err)
	local reason = err or errno;
	return raw_assert(sqlite_errors.coerce(cond, reason));
end

-- No globals past this point; everything must be an upvalue.
local _ENV = nil;
-- luacheck: std none
|
|
|
|
-- Marker metatables identifying schema objects; the is_* predicates compare
-- getmetatable() against these. (An op_mt was once planned but is unused.)
local column_mt, table_mt, query_mt, index_mt = {}, {}, {}, {};
|
|
|
|
-- Build a predicate that reports whether a value carries metatable `mt`.
local function has_metatable(mt)
	return function(x)
		return getmetatable(x) == mt;
	end
end

--- Type predicates for schema objects (Column, Index, Table, query).
local is_column = has_metatable(column_mt);
local is_index = has_metatable(index_mt);
local is_table = has_metatable(table_mt);
local is_query = has_metatable(query_mt);

--- Placeholder column-type constructors; the argument is currently ignored.
local function Integer(_n)
	return "Integer()";
end

local function String(_n)
	return "String()";
end
|
|
|
|
--- Tag a column definition table (read elsewhere for name, type, nullable,
-- primary_key, auto_increment) as a Column; the same table is returned.
local function Column(definition)
	local column = setmetatable(definition, column_mt);
	return column;
end
|
|
--- Build a Table schema object from a definition list.
-- The array part of `definition` holds Column and Index objects, and
-- `definition.name` is the table name. Columns are collected into `c`, keyed
-- both by position (1..n, without gaps) and by column name. Index entries are
-- stamped with the table name so engine:_create_index() knows their table.
-- @param definition table
-- @return Table object with fields __table__, c and name
local function Table(definition)
	local c = {};
	local n = 0;
	for _, col in ipairs(definition) do
		if is_column(col) then
			-- Number columns densely. Using the position within `definition`
			-- leaves holes whenever an Index precedes a Column, and then
			-- ipairs()/#c in engine:_create_table() silently skip columns.
			n = n + 1;
			c[n], c[col.name] = col, col;
		elseif is_index(col) then
			col.table = definition.name;
		end
	end
	return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
end
|
|
--- Tag an index definition (array of column names plus `name` and an
-- optional `unique` flag) as an Index; the same table is returned.
local function Index(definition)
	local index = setmetatable(definition, index_mt);
	return index;
end
|
|
|
|
--- Render a Table as 'Table{ name="...", <tostring of each definition entry> }'.
function table_mt:__tostring()
	local def = self.__table__;
	local parts = { 'name="' .. def.name .. '"' };
	for _, entry in ipairs(def) do
		t_insert(parts, tostring(entry));
	end
	return "Table{ " .. t_concat(parts, ", ") .. " }";
end
|
|
-- Methods available on Table objects.
table_mt.__index = {
	--- Create this table (and its indexes) using the given engine.
	create = function(self, db_engine)
		return db_engine:_create_table(self);
	end;
};
|
|
-- Calling a Table object is reserved for future use; it currently does nothing.
table_mt.__call = function(self, ...)
	-- TODO
end
|
|
--- Render a Column as 'Column{ name="...", type="..." }'.
function column_mt:__tostring()
	local name, col_type = self.name, self.type;
	return 'Column{ name="' .. name .. '", type="' .. col_type .. '" }';
end
|
|
--- Render an Index as 'Index{ name="...", "col1", "col2" }'. Backslashes and
-- double quotes inside column names are escaped with a backslash.
function index_mt:__tostring()
	local parts = { 'Index{ name="' .. self.name .. '"' };
	for i = 1, #self do
		local escaped = self[i]:gsub("[\\\"]", "\\%1");
		parts[#parts + 1] = '"' .. escaped .. '"';
	end
	return t_concat(parts, ", ") .. " }";
end
|
|
|
|
--- Decode %XX percent-escapes in `s`. A nil/false `s` is returned unchanged.
local function urldecode(s)
	if not s then
		return s;
	end
	local decoded = s:gsub("%%(%x%x)", function (hex)
		return s_char(tonumber(hex, 16));
	end);
	return decoded;
end
|
|
--- Split a database URL of the form
--   scheme://[user[:password]@]host[:port][/database]
-- into its components.
-- @param url string
-- @return table with scheme (lower-cased), username, password (URL-decoded),
--   host, port (number) and database (nil when empty)
-- Raises (via assert) on a URL without a scheme or with a non-numeric port.
local function parse_url(url)
	local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
	assert(scheme, "Invalid URL format");
	local username, password, host, port;
	-- The host part is one or more non-'@' characters. (A "[^@+]" set here
	-- would capture only the first character of the host.)
	local authpart, hostpart = secondpart:match("([^@]+)@([^@]+)");
	if not authpart then hostpart = secondpart; end
	if authpart then
		username, password = authpart:match("([^:]*):(.*)");
		username = username or authpart;
		password = password and urldecode(password);
	end
	if hostpart then
		host, port = hostpart:match("([^:]*):(.*)");
		host = host or hostpart;
		port = port and assert(tonumber(port), "Invalid URL format");
	end
	return {
		scheme = scheme:lower();
		username = username; password = password;
		host = host; port = port;
		database = #database > 0 and database or nil;
	};
end
|
|
|
|
-- Prototype for engine instances created by create_engine().
local engine = {};

--- Open the SQLite database named by self.params.database, unless a
-- connection is already held.
-- On success it stores the handle in self.conn, resets the prepared-statement
-- cache, runs set_encoding() and then the onconnect() hook.
-- @return true on success; otherwise a falsy value plus an error
-- NOTE(review): on set_encoding()/onconnect() failure self.conn stays set, so a
-- later connect() returns true without re-running them — confirm intended.
function engine:connect()
	if self.conn then
		return true;
	end

	local params = self.params;
	assert(params.driver == "SQLite3", "Only sqlite3 is supported");

	local dbh, open_err = sqlite_errors.coerce(lsqlite3.open(params.database));
	if not dbh then
		return nil, open_err;
	end
	self.conn, self.prepared = dbh, {};

	local enc_ok, enc_err = self:set_encoding();
	if not enc_ok then
		return enc_ok, enc_err;
	end

	-- Only an explicit false from the hook counts as failure.
	local hook_ok, hook_err = self:onconnect();
	if hook_ok == false then
		return hook_ok, hook_err;
	end

	return true;
end
|
|
--- Hook run at the end of a successful connect(). create_engine() stores a
-- per-instance replacement when one is given. Returning false (plus an error)
-- makes connect() fail.
engine.onconnect = function(self)
	-- Override from create_engine()
end

--- Hook run by transaction() when it finds the connection missing or closed,
-- just before retrying. create_engine() stores a per-instance replacement.
engine.ondisconnect = function(self) -- luacheck: ignore 212/self
	-- Override from create_engine()
end
|
|
--- Run a SQL statement on the connection, connecting first if needed.
-- Without bind arguments the statement runs directly through conn:exec() and
-- true is returned. With arguments the statement is prepared (and cached in
-- self.prepared), the arguments are bound, and the bound statement object is
-- returned for the caller to step.
-- @return true or a bound statement on success; nil (or the falsy value from
--   connect()) plus an error on failure
function engine:execute(sql, ...)
	local success, err = self:connect();
	if not success then return success, err; end
	local prepared = self.prepared;

	if select('#', ...) == 0 then
		local ret = self.conn:exec(sql);
		if ret ~= lsqlite3.OK then
			-- Build the error from the SQLite result code, and return it in the
			-- second slot so callers testing the first value see the failure.
			local exec_err = sqlite_errors.new(ret);
			exec_err.text = self.conn:errmsg();
			return nil, exec_err;
		end
		return true;
	end

	local stmt = prepared[sql];
	if not stmt then
		local code;
		-- lsqlite3's prepare() returns nil plus a result code on failure.
		stmt, code = self.conn:prepare(sql);
		if not stmt then
			local prep_err = sqlite_errors.new(code);
			prep_err.text = self.conn:errmsg();
			return nil, prep_err;
		end
		prepared[sql] = stmt;
	end

	local ret = stmt:bind_values(...);
	if ret ~= lsqlite3.OK then return nil, sqlite_errors.new(ret, { message = self.conn:errmsg() }); end
	return stmt;
end
|
|
|
|
--- Return a stateful iterator over `list`, starting at index 1.
-- Each call yields the next element. Once an index holds nil, the call
-- returns no values (which ends a generic for loop).
local function iterator(list)
	local pos = 0;
	local function next_item()
		pos = pos + 1;
		local value = list[pos];
		if value == nil then
			return;
		end
		return value;
	end
	return next_item;
end
|
|
|
|
--- Metatable for result objects built by engine:execute_update().
-- Result fields: __affected (changes count), __rowcount, __data (rows).
--   #result / result:rowcount()  -> number of rows
--   result:affected()            -> stored changes count
--   result()                     -> iterator over the rows
local result_mt;
do
	local function row_count(self)
		return self.__rowcount;
	end

	result_mt = {
		__index = {
			affected = function(self)
				return self.__affected;
			end;
			rowcount = row_count;
		};
		__len = row_count;
		__call = function(self)
			return iterator(self.__data);
		end;
	};
end
|
|
|
|
--- Log `sql` at debug level, labelled `where`.
-- Newline+tab runs are collapsed to single spaces, and each "?" placeholder
-- is replaced by the matching argument. Strings are single-quoted with
-- embedded quotes doubled; other values go through tostring().
local function debugquery(where, sql, ...)
	local args = { ... };
	local n = 0;
	local function substitute()
		n = n + 1;
		local value = args[n];
		if type(value) == "string" then
			value = ("'%s'"):format((value:gsub("'", "''")));
		end
		return tostring(value);
	end
	local flat = sql:gsub("\n?\t+", " ");
	local rendered = flat:gsub("%?", substitute);
	log("debug", "[%s] %s", where, rendered);
end
|
|
|
|
--- Run a statement and collect all of its rows.
-- Connects first if needed. Reuses a cached prepared statement when one is
-- available and open (removing it from the cache while in use), otherwise
-- prepares a new one. Then binds the arguments and steps through every row.
-- @return result object (see result_mt): call it for a row iterator;
--   :affected() gives conn:changes(), :rowcount() the number of rows
-- Raises an error if connecting, preparing or binding fails.
function engine:execute_update(sql, ...)
	-- Make sure a connection exists; without one, self.prepared and
	-- self.conn below would be nil.
	local connected, conn_err = self:connect();
	if not connected then error(conn_err or "Failed to connect to database"); end

	local prepared = self.prepared;
	local stmt = prepared[sql];
	if stmt and stmt:isopen() then
		prepared[sql] = nil; -- Can't be used concurrently
	else
		local code;
		stmt, code = self.conn:prepare(sql);
		if not stmt then
			-- Include SQLite's own message, like engine:execute() does.
			local prep_err = sqlite_errors.new(code);
			prep_err.text = self.conn:errmsg();
			error(prep_err);
		end
	end

	local ret = stmt:bind_values(...);
	if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end

	local data = array();
	for row in stmt:rows() do
		data:push(array(row));
	end

	-- FIXME Error handling, BUSY, ERROR, MISUSE
	-- Only return the statement to the cache if it could be reset cleanly.
	if stmt:reset() == lsqlite3.OK then
		prepared[sql] = stmt;
	end

	local affected = self.conn:changes();
	return setmetatable({ __affected = affected; __rowcount = #data; __data = data }, result_mt);
end
|
|
|
|
--- Run a query and return an iterator over its rows (each row an array of
-- column values). This is execute_update() with only the iterator kept.
function engine:execute_query(sql, ...)
	local result = self:execute_update(sql, ...);
	return result();
end
|
|
|
|
-- Short verb aliases. select() yields a row iterator; the others return the
-- full result object from execute_update().
engine.select = engine.execute_query;
for _, verb in ipairs { "insert", "delete", "update" } do
	engine[verb] = engine.execute_update;
end
|
|
--- Wrap query method `f` so each call is first logged via debugquery() under
-- the label `name`, then forwarded unchanged (results included).
local function debugwrap(name, f)
	local function logged(self, sql, ...)
		debugquery(name, sql, ...);
		return f(self, sql, ...);
	end
	return logged;
end
|
|
--- Enable or disable SQL query logging for this engine instance.
-- When enabled, this instance's insert/select/delete/update are replaced by
-- logging wrappers. When disabled, the overrides are removed so the plain
-- methods from the shared `engine` prototype apply again.
-- Also sets self._debug, which makes table/index creation log their SQL.
-- @param enable truthy to enable logging, falsy to disable it
function engine:debug(enable)
	self._debug = enable;
	if enable then
		-- Install the wrappers on the instance, not the shared prototype.
		-- Otherwise toggling debug on one engine would turn logging on or
		-- off for every engine.
		self.insert = debugwrap("insert", engine.execute_update);
		self.select = debugwrap("select", engine.execute_query);
		self.delete = debugwrap("delete", engine.execute_update);
		self.update = debugwrap("update", engine.execute_update);
	else
		self.insert = nil;
		self.select = nil;
		self.delete = nil;
		self.update = nil;
	end
end
|
|
--- Execute a bare SQL command (used for BEGIN/COMMIT/ROLLBACK) on the
-- current connection.
-- @return true on success, or nil plus SQLite's error message
function engine:_(word)
	local conn = self.conn;
	if conn:exec(word) == lsqlite3.OK then
		return true;
	end
	return nil, conn:errmsg();
end
|
|
--- Run func(...) between BEGIN and COMMIT on this engine's connection.
-- Connects first if there is no connection.
-- * If func raises, the transaction is rolled back and false plus the error
--   (as passed through debug.traceback) is returned.
-- * If COMMIT fails, the transaction is rolled back and nil plus the error
--   message is returned.
-- @return true followed by up to three results of func on success
function engine:_transaction(func, ...)
	if not self.conn then
		local a,b = self:connect();
		if not a then return a,b; end
	end
	--assert(not self.__transaction, "Recursive transactions not allowed");
	local ok, err = self:_"BEGIN";
	if not ok then return ok, err; end
	self.__transaction = true;
	local success, a, b, c = xpcall(func, debug_traceback, ...);
	self.__transaction = nil;
	if success then
		log("debug", "SQL transaction success [%s]", tostring(func));
		local commit_ok, commit_err = self:_"COMMIT";
		if not commit_ok then
			-- A failed COMMIT (e.g. SQLITE_BUSY) can leave the transaction open.
			-- Roll it back so the next BEGIN on this connection can succeed.
			if self.conn then self:_"ROLLBACK"; end
			return commit_ok, commit_err; -- commit failed
		end
		return success, a, b, c;
	else
		log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
		if self.conn then self:_"ROLLBACK"; end
		return success, a;
	end
end
|
|
--- Run a function in a transaction, retrying once after reconnecting.
-- The retry happens only if the first attempt failed and the connection is
-- missing or closed. In that case self.conn is cleared and ondisconnect()
-- runs before the retry.
-- @return everything _transaction() returned: ok plus up to three results
function engine:transaction(...)
	-- Keep all results; _transaction() forwards up to three values from func.
	local ok, ret, b, c = self:_transaction(...);
	if not ok then
		local conn = self.conn;
		if not conn or not conn:isopen() then
			self.conn = nil;
			self:ondisconnect();
			ok, ret, b, c = self:_transaction(...);
		end
	end
	return ok, ret, b, c;
end
|
|
--- Issue CREATE [UNIQUE] INDEX IF NOT EXISTS for an Index object.
-- The index and table names (index.table is set by Table()) and the column
-- names are double-quoted as given.
-- @return whatever self:execute() returns
function engine:_create_index(index)
	local head = " INDEX IF NOT EXISTS \""..index.name.."\" ON \""..index.table.."\" (";
	local quoted = {};
	for i = 1, #index do
		quoted[i] = "\""..index[i].."\"";
	end
	local verb = index.unique and "CREATE UNIQUE" or "CREATE";
	local sql = verb..head..t_concat(quoted, ", ")..");";
	if self._debug then
		debugquery("create", sql);
	end
	return self:execute(sql);
end
|
|
--- Issue CREATE TABLE IF NOT EXISTS for a Table object, then create every
-- Index found in its definition. Results of index creation are not checked.
-- @return the CREATE TABLE result on success, or its failure values
function engine:_create_table(table)
	local head = "CREATE TABLE IF NOT EXISTS \""..table.name.."\" (";
	local columns = table.c;
	local count = #columns;
	local defs = {};
	for i, col in ipairs(columns) do
		local def = "\""..col.name.."\" "..col.type;
		if col.nullable == false then def = def.." NOT NULL"; end
		if col.primary_key == true then def = def.." PRIMARY KEY"; end
		if col.auto_increment == true then def = def.." AUTOINCREMENT"; end
		if i ~= count then def = def..", "; end
		defs[#defs + 1] = def;
	end
	local sql = head..t_concat(defs)..");";
	if self._debug then
		debugquery("create", sql);
	end

	local ok, err = self:execute(sql);
	if not ok then return ok, err; end

	for _, item in ipairs(table.__table__) do
		if is_index(item) then
			self:_create_index(item);
		end
	end
	return ok;
end
|
|
|
|
--- Check whether the database encoding is UTF-8 and, if so, set
-- self.charset = "utf8". This only reads `PRAGMA encoding`; it does not
-- change the database encoding.
-- @return the result of self:transaction() (true on success)
function engine:set_encoding()
	local function check_encoding()
		for row in self:select("PRAGMA encoding;") do
			if row[1] == "UTF-8" then
				self.charset = "utf8";
			end
		end
	end
	return self:transaction(check_encoding);
end
|
|
-- Metatable for engine instances: method lookups fall through to `engine`.
local engine_mt = {};
engine_mt.__index = engine;
|
|
|
|
--- Build a URL string from connection parameters using socket.url.build().
-- Mapping: driver -> scheme, username -> user, database -> path; password,
-- host and port keep their names.
local function db2uri(params)
	local url_parts = {
		scheme = params.driver;
		user = params.username;
		password = params.password;
		host = params.host;
		port = params.port;
		path = params.database;
	};
	return build_url(url_parts);
end
|
|
|
|
--- Create an engine object for the given connection parameters.
-- The first argument is ignored. onconnect/ondisconnect, when given, are
-- stored on the instance and override the engine's default no-op hooks.
local function create_engine(_, params, onconnect, ondisconnect)
	assert(params.driver == "SQLite3", "Only SQLite3 is supported without LuaDBI");
	local instance = {
		url = db2uri(params);
		params = params;
		onconnect = onconnect;
		ondisconnect = ondisconnect;
	};
	return setmetatable(instance, engine_mt);
end
|
|
|
|
-- Public interface of the module.
return {
	-- Schema-object predicates
	is_column = is_column;
	is_index = is_index;
	is_table = is_table;
	is_query = is_query;

	-- Schema construction helpers
	Integer = Integer;
	String = String;
	Column = Column;
	Table = Table;
	Index = Index;

	-- Engine creation and URL helper
	create_engine = create_engine;
	db2uri = db2uri;
};
|