mirror of
https://github.com/bjc/prosody.git
synced 2025-04-01 20:27:39 +03:00
This was copied from util.sqlite3 which only deals with SQLite3, but here Postgres or MySQL are also possibilities, which wouldn't support this.
390 lines
11 KiB
Lua
390 lines
11 KiB
Lua
|
|
local setmetatable, getmetatable = setmetatable, getmetatable;
|
|
local ipairs = ipairs;
|
|
local tostring = tostring;
|
|
local type = type;
|
|
local assert, pcall, debug_traceback = assert, pcall, debug.traceback;
|
|
local xpcall = require "prosody.util.xpcall".xpcall;
|
|
local t_concat = table.concat;
|
|
local log = require "prosody.util.logger".init("sql");
|
|
|
|
local DBI = require "DBI";
|
|
-- This loads all available drivers while globals are unlocked
|
|
-- LuaDBI should be fixed to not set globals.
|
|
DBI.Drivers();
|
|
local build_url = require "socket.url".build;
|
|
|
|
local _ENV = nil;
|
|
-- luacheck: std none
|
|
|
|
local column_mt = {};
|
|
local table_mt = {};
|
|
local query_mt = {};
|
|
--local op_mt = {};
|
|
local index_mt = {};
|
|
|
|
local function is_column(x) return getmetatable(x)==column_mt; end
|
|
local function is_index(x) return getmetatable(x)==index_mt; end
|
|
local function is_table(x) return getmetatable(x)==table_mt; end
|
|
local function is_query(x) return getmetatable(x)==query_mt; end
|
|
|
|
local function Column(definition)
|
|
return setmetatable(definition, column_mt);
|
|
end
|
|
local function Table(definition)
|
|
local c = {}
|
|
for i,col in ipairs(definition) do
|
|
if is_column(col) then
|
|
c[i], c[col.name] = col, col;
|
|
elseif is_index(col) then
|
|
col.table = definition.name;
|
|
end
|
|
end
|
|
return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
|
|
end
|
|
local function Index(definition)
|
|
return setmetatable(definition, index_mt);
|
|
end
|
|
|
|
function table_mt:__tostring()
|
|
local s = { 'name="'..self.__table__.name..'"' }
|
|
for _, col in ipairs(self.__table__) do
|
|
s[#s+1] = tostring(col);
|
|
end
|
|
return 'Table{ '..t_concat(s, ", ")..' }'
|
|
end
|
|
table_mt.__index = {};
|
|
function table_mt.__index:create(engine)
|
|
return engine:_create_table(self);
|
|
end
|
|
function column_mt:__tostring()
|
|
return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
|
|
end
|
|
function index_mt:__tostring()
|
|
local s = 'Index{ name="'..self.name..'"';
|
|
for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end
|
|
return s..' }';
|
|
-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
|
|
end
|
|
|
|
local engine = {};
|
|
function engine:connect()
|
|
if self.conn then return true; end
|
|
|
|
local params = self.params;
|
|
assert(params.driver, "no driver")
|
|
log("debug", "Connecting to [%s] %s...", params.driver, params.database);
|
|
local ok, dbh, err = pcall(DBI.Connect,
|
|
params.driver, params.database,
|
|
params.username, params.password,
|
|
params.host, params.port
|
|
);
|
|
if not ok then return ok, dbh; end
|
|
if not dbh then return nil, err; end
|
|
dbh:autocommit(false); -- don't commit automatically
|
|
self.conn = dbh;
|
|
self.prepared = {};
|
|
if params.driver == "SQLite3" and params.password then
|
|
local ok, err = self:execute(("PRAGMA key='%s'"):format(dbh:quote(params.password)));
|
|
if not ok then
|
|
return ok, err;
|
|
end
|
|
end
|
|
local ok, err = self:set_encoding();
|
|
if not ok then
|
|
return ok, err;
|
|
end
|
|
local ok, err = self:onconnect();
|
|
if ok == false then
|
|
return ok, err;
|
|
end
|
|
return true;
|
|
end
|
|
function engine:onconnect() -- luacheck: ignore 212/self
|
|
-- Override from create_engine()
|
|
end
|
|
function engine:ondisconnect() -- luacheck: ignore 212/self
|
|
-- Override from create_engine()
|
|
end
|
|
|
|
function engine:prepquery(sql)
|
|
if self.params.driver == "MySQL" then
|
|
sql = sql:gsub("\"", "`");
|
|
end
|
|
return sql;
|
|
end
|
|
|
|
function engine:execute(sql, ...)
|
|
local success, err = self:connect();
|
|
if not success then return success, err; end
|
|
local prepared = self.prepared;
|
|
|
|
sql = self:prepquery(sql);
|
|
local stmt = prepared[sql];
|
|
if not stmt then
|
|
local err;
|
|
stmt, err = self.conn:prepare(sql);
|
|
if not stmt then return stmt, err; end
|
|
prepared[sql] = stmt;
|
|
end
|
|
|
|
-- luacheck: ignore 411/success
|
|
local success, err = stmt:execute(...);
|
|
if not success then return success, err; end
|
|
return stmt;
|
|
end
|
|
|
|
local result_mt = { __index = {
|
|
affected = function(self) return self.__stmt:affected(); end;
|
|
rowcount = function(self) return self.__stmt:rowcount(); end;
|
|
} };
|
|
|
|
local function debugquery(where, sql, ...)
|
|
local i = 0; local a = {...}
|
|
sql = sql:gsub("\n?\t+", " ");
|
|
log("debug", "[%s] %s", where, (sql:gsub("%?", function ()
|
|
i = i + 1;
|
|
local v = a[i];
|
|
if type(v) == "string" then
|
|
v = ("'%s'"):format(v:gsub("'", "''"));
|
|
end
|
|
return tostring(v);
|
|
end)));
|
|
end
|
|
|
|
function engine:execute_query(sql, ...)
|
|
sql = self:prepquery(sql);
|
|
local stmt = assert(self.conn:prepare(sql));
|
|
assert(stmt:execute(...));
|
|
local result = {};
|
|
for row in stmt:rows() do result[#result + 1] = row; end
|
|
stmt:close();
|
|
local i = 0;
|
|
return function() i=i+1; return result[i]; end;
|
|
end
|
|
function engine:execute_update(sql, ...)
|
|
sql = self:prepquery(sql);
|
|
local prepared = self.prepared;
|
|
local stmt = prepared[sql];
|
|
if not stmt then
|
|
stmt = assert(self.conn:prepare(sql));
|
|
prepared[sql] = stmt;
|
|
end
|
|
assert(stmt:execute(...));
|
|
return setmetatable({ __stmt = stmt }, result_mt);
|
|
end
|
|
engine.insert = engine.execute_update;
|
|
engine.select = engine.execute_query;
|
|
engine.delete = engine.execute_update;
|
|
engine.update = engine.execute_update;
|
|
local function debugwrap(name, f)
|
|
return function (self, sql, ...)
|
|
debugquery(name, sql, ...)
|
|
return f(self, sql, ...)
|
|
end
|
|
end
|
|
function engine:debug(enable)
|
|
self._debug = enable;
|
|
if enable then
|
|
engine.insert = debugwrap("insert", engine.execute_update);
|
|
engine.select = debugwrap("select", engine.execute_query);
|
|
engine.delete = debugwrap("delete", engine.execute_update);
|
|
engine.update = debugwrap("update", engine.execute_update);
|
|
else
|
|
engine.insert = engine.execute_update;
|
|
engine.select = engine.execute_query;
|
|
engine.delete = engine.execute_update;
|
|
engine.update = engine.execute_update;
|
|
end
|
|
end
|
|
local function handleerr(err)
|
|
local trace = debug_traceback(err, 3);
|
|
log("debug", "Error in SQL transaction: %s", trace);
|
|
return { err = err, traceback = trace };
|
|
end
|
|
function engine:_transaction(func, ...)
|
|
if not self.conn then
|
|
local ok, err = self:connect();
|
|
if not ok then return ok, err; end
|
|
end
|
|
--assert(not self.__transaction, "Recursive transactions not allowed");
|
|
log("debug", "SQL transaction begin [%s]", func);
|
|
self.__transaction = true;
|
|
local success, a, b, c = xpcall(func, handleerr, ...);
|
|
self.__transaction = nil;
|
|
if success then
|
|
log("debug", "SQL transaction success [%s]", func);
|
|
local ok, err = self.conn:commit();
|
|
-- LuaDBI doesn't actually return an error message here, just a boolean
|
|
if not ok then return ok, err or "commit failed"; end
|
|
return success, a, b, c;
|
|
else
|
|
log("debug", "SQL transaction failure [%s]: %s", func, a.err);
|
|
if self.conn then self.conn:rollback(); end
|
|
return success, a.err;
|
|
end
|
|
end
|
|
function engine:transaction(...)
|
|
local ok, ret, b, c = self:_transaction(...);
|
|
if not ok then
|
|
local conn = self.conn;
|
|
if not conn or not conn:ping() then
|
|
log("debug", "Database connection was closed. Will reconnect and retry.");
|
|
self.conn = nil;
|
|
self:ondisconnect();
|
|
log("debug", "Retrying SQL transaction [%s]", (...));
|
|
ok, ret, b, c = self:_transaction(...);
|
|
log("debug", "SQL transaction retry %s", ok and "succeeded" or "failed");
|
|
else
|
|
log("debug", "SQL connection is up, so not retrying");
|
|
end
|
|
if not ok then
|
|
log("error", "Error in SQL transaction: %s", ret);
|
|
end
|
|
end
|
|
return ok, ret, b, c;
|
|
end
|
|
function engine:_create_index(index)
|
|
local sql = "CREATE INDEX \""..index.name.."\" ON \""..index.table.."\" (";
|
|
if self.params.driver ~= "MySQL" then
|
|
sql = sql:gsub("^CREATE INDEX", "%1 IF NOT EXISTS");
|
|
end
|
|
for i=1,#index do
|
|
sql = sql.."\""..index[i].."\"";
|
|
if i ~= #index then sql = sql..", "; end
|
|
end
|
|
sql = sql..");"
|
|
if self.params.driver == "MySQL" then
|
|
sql = sql:gsub("\"([,)])", "\"(20)%1");
|
|
end
|
|
if index.unique then
|
|
sql = sql:gsub("^CREATE", "CREATE UNIQUE");
|
|
end
|
|
if self._debug then
|
|
debugquery("create", sql);
|
|
end
|
|
return self:execute(sql);
|
|
end
|
|
function engine:_create_table(table)
|
|
local sql = "CREATE TABLE \""..table.name.."\" (";
|
|
do
|
|
sql = sql:gsub("^CREATE TABLE", "%1 IF NOT EXISTS");
|
|
end
|
|
for i,col in ipairs(table.c) do
|
|
local col_type = col.type;
|
|
if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then
|
|
col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
|
|
end
|
|
if col.auto_increment == true and self.params.driver == "PostgreSQL" then
|
|
col_type = "BIGSERIAL";
|
|
end
|
|
sql = sql.."\""..col.name.."\" "..col_type;
|
|
if col.nullable == false then sql = sql.." NOT NULL"; end
|
|
if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
|
|
if col.auto_increment == true then
|
|
if self.params.driver == "MySQL" then
|
|
sql = sql.." AUTO_INCREMENT";
|
|
elseif self.params.driver == "SQLite3" then
|
|
sql = sql.." AUTOINCREMENT";
|
|
end
|
|
end
|
|
if i ~= #table.c then sql = sql..", "; end
|
|
end
|
|
sql = sql.. ");"
|
|
if self.params.driver == "MySQL" then
|
|
sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
|
|
end
|
|
if self._debug then
|
|
debugquery("create", sql);
|
|
end
|
|
local success,err = self:execute(sql);
|
|
if not success then return success,err; end
|
|
for _, v in ipairs(table.__table__) do
|
|
if is_index(v) then
|
|
self:_create_index(v);
|
|
end
|
|
end
|
|
return success;
|
|
end
|
|
function engine:set_encoding() -- to UTF-8
|
|
local driver = self.params.driver;
|
|
if driver == "SQLite3" then
|
|
return self:transaction(function()
|
|
for encoding in self:select"PRAGMA encoding;" do
|
|
if encoding[1] == "UTF-8" then
|
|
self.charset = "utf8";
|
|
end
|
|
end
|
|
end);
|
|
end
|
|
local set_names_query = "SET NAMES '%s';"
|
|
local charset = "utf8";
|
|
if driver == "MySQL" then
|
|
self:transaction(function()
|
|
for row in self:select[[
|
|
SELECT "CHARACTER_SET_NAME"
|
|
FROM "information_schema"."CHARACTER_SETS"
|
|
WHERE "CHARACTER_SET_NAME" LIKE 'utf8%'
|
|
ORDER BY MAXLEN DESC LIMIT 1;
|
|
]] do
|
|
charset = row and row[1] or charset;
|
|
end
|
|
end);
|
|
set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin"));
|
|
end
|
|
self.charset = charset;
|
|
log("debug", "Using encoding '%s' for database connection", charset);
|
|
local ok, err = self:transaction(function() return self:execute(set_names_query:format(charset)); end);
|
|
if not ok then
|
|
return ok, err;
|
|
end
|
|
|
|
if driver == "MySQL" then
|
|
local ok, actual_charset = self:transaction(function ()
|
|
return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
|
|
end);
|
|
if not ok then
|
|
return false, "Failed to detect connection encoding";
|
|
end
|
|
local charset_ok = true;
|
|
for row in actual_charset do
|
|
if row[2] ~= charset then
|
|
log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset);
|
|
charset_ok = false;
|
|
end
|
|
end
|
|
if not charset_ok then
|
|
return false, "Failed to set connection encoding";
|
|
end
|
|
end
|
|
|
|
return true;
|
|
end
|
|
local engine_mt = { __index = engine };
|
|
|
|
local function db2uri(params)
|
|
return build_url{
|
|
scheme = params.driver,
|
|
user = params.username,
|
|
password = params.password,
|
|
host = params.host,
|
|
port = params.port,
|
|
path = params.database,
|
|
};
|
|
end
|
|
|
|
local function create_engine(_, params, onconnect, ondisconnect)
|
|
return setmetatable({ url = db2uri(params); params = params; onconnect = onconnect; ondisconnect = ondisconnect }, engine_mt);
|
|
end
|
|
|
|
return {
|
|
is_column = is_column;
|
|
is_index = is_index;
|
|
is_table = is_table;
|
|
is_query = is_query;
|
|
Column = Column;
|
|
Table = Table;
|
|
Index = Index;
|
|
create_engine = create_engine;
|
|
db2uri = db2uri;
|
|
};
|