prosody/util/sql.lua
Kim Alvefur 654edd37bb util.sql: SQLCipher support
This enables use of encrypted databases if LuaDBI or LuaSQLite3 has been
linked against SQLCipher. Using `LD_PRELOAD` may work as well.

Requires SQLCipher >= 4.0.0 due to the use of UPSERT
2025-01-23 16:38:56 +01:00

390 lines
11 KiB
Lua

local setmetatable, getmetatable = setmetatable, getmetatable;
local ipairs = ipairs;
local tostring = tostring;
local type = type;
local assert, pcall, debug_traceback = assert, pcall, debug.traceback;
local xpcall = require "prosody.util.xpcall".xpcall;
local t_concat = table.concat;
local log = require "prosody.util.logger".init("sql");
local DBI = require "DBI";
-- This loads all available drivers while globals are unlocked
-- LuaDBI should be fixed to not set globals.
DBI.Drivers();
local build_url = require "socket.url".build;
local _ENV = nil;
-- luacheck: std none
-- Metatables used to tag schema objects; the is_* predicates test for them.
local column_mt = {};
local table_mt = {};
local query_mt = {};
--local op_mt = {};
local index_mt = {};
local function is_column(x) return getmetatable(x)==column_mt; end
local function is_index(x) return getmetatable(x)==index_mt; end
local function is_table(x) return getmetatable(x)==table_mt; end
local function is_query(x) return getmetatable(x)==query_mt; end

--- Tag a column definition table (e.g. { name = ..., type = ... }).
-- The same table is returned with column_mt attached.
local function Column(definition)
	return setmetatable(definition, column_mt);
end

--- Wrap a table definition: { name = ..., Column{...}..., Index{...}... }.
-- The returned object keeps the raw definition in `__table__` and exposes its
-- columns in `c`, both as a sequence (declaration order) and by column name.
-- Index entries are told which table they belong to (index.table).
local function Table(definition)
	local c = {};
	-- Count columns separately from the position in `definition`: Index
	-- entries may be interleaved with columns, and using their position
	-- would leave holes in `c`, which breaks ipairs()/# in _create_table().
	local ncols = 0;
	for _, col in ipairs(definition) do
		if is_column(col) then
			ncols = ncols + 1;
			c[ncols], c[col.name] = col, col;
		elseif is_index(col) then
			col.table = definition.name;
		end
	end
	return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
end

--- Tag an index definition table ({ name = ..., unique = ..., "col", ... }).
-- The same table is returned with index_mt attached.
local function Index(definition)
	return setmetatable(definition, index_mt);
end
-- Render a Table as 'Table{ name="...", <each column/index rendered>, ... }'.
function table_mt:__tostring()
	local definition = self.__table__;
	local parts = { 'name="'..definition.name..'"' };
	for i, item in ipairs(definition) do
		parts[i + 1] = tostring(item);
	end
	return 'Table{ '..t_concat(parts, ", ")..' }'
end
-- Methods available on every Table object.
local table_methods = {};
table_mt.__index = table_methods;
-- Create this table (and its indexes) using the given engine.
function table_methods:create(db_engine)
	return db_engine:_create_table(self);
end
-- Render a Column as 'Column{ name="...", type="..." }'.
function column_mt:__tostring()
	local col_name, col_type = self.name, self.type;
	return 'Column{ name="'..col_name..'", type="'..col_type..'" }'
end
-- Render an Index as 'Index{ name="...", "col1", "col2" }', escaping
-- backslashes and double quotes inside the column names.
function index_mt:__tostring()
	local parts = { 'Index{ name="'..self.name..'"' };
	for i = 1, #self do
		local escaped = self[i]:gsub("[\\\"]", "\\%1");
		parts[#parts + 1] = '"'..escaped..'"';
	end
	return t_concat(parts, ", ")..' }';
end
-- Prototype holding the engine methods below; instances returned by
-- create_engine() find these methods through engine_mt's __index.
local engine = {};
--- Open the database connection described by self.params, unless one is
-- already open.
-- After connecting, autocommit is disabled and the prepared-statement cache
-- is reset. For the SQLite3 driver, a configured password is applied as the
-- SQLCipher key (PRAGMA key) before any other statement. Then set_encoding()
-- and the onconnect() hook run.
-- @return true on success, or a falsy value plus an error
function engine:connect()
	if self.conn then return true; end
	local params = self.params;
	assert(params.driver, "no driver")
	log("debug", "Connecting to [%s] %s...", params.driver, params.database);
	local ok, dbh, err = pcall(DBI.Connect,
		params.driver, params.database,
		params.username, params.password,
		params.host, params.port
	);
	if not ok then return ok, dbh; end
	if not dbh then return nil, err; end
	dbh:autocommit(false); -- don't commit automatically
	self.conn = dbh;
	self.prepared = {};
	-- Only SQLite3 (when linked against SQLCipher) takes the password as an
	-- encryption key. Other drivers already used it to authenticate in
	-- DBI.Connect() above, and PRAGMA is SQLite-only syntax, so sending it
	-- there would make every password-protected connection fail.
	if params.password and params.driver == "SQLite3" then
		-- dbh:quote() is relied on to escape the value; the surrounding
		-- single quotes are added here.
		ok, err = self:execute(("PRAGMA key='%s'"):format(dbh:quote(params.password)));
		if not ok then
			return ok, err;
		end
	end
	ok, err = self:set_encoding();
	if not ok then
		return ok, err;
	end
	-- The hook may return nothing; only an explicit false aborts.
	ok, err = self:onconnect();
	if ok == false then
		return ok, err;
	end
	return true;
end
-- Default hook run at the end of connect(); returning false (plus an error)
-- makes connect() fail. An engine may supply its own via create_engine().
function engine.onconnect(_)
end
-- Default hook run by transaction() after it drops a dead connection and
-- before retrying. An engine may supply its own via create_engine().
function engine.ondisconnect(_)
end
-- Adapt SQL written with double-quoted identifiers to the active driver.
-- For MySQL every '"' (including any inside string literals) becomes '`';
-- other drivers get the SQL unchanged.
function engine:prepquery(sql)
	if self.params.driver ~= "MySQL" then
		return sql;
	end
	return (sql:gsub("\"", "`"));
end
-- Execute `sql` with the given bind values as a prepared statement cached in
-- self.prepared, connecting first if needed.
-- Returns the executed statement, or a falsy value plus an error.
function engine:execute(sql, ...)
	local connected, connect_err = self:connect();
	if not connected then
		return connected, connect_err;
	end
	local cache = self.prepared;
	local query = self:prepquery(sql);
	local stmt = cache[query];
	if not stmt then
		local prepare_err;
		stmt, prepare_err = self.conn:prepare(query);
		if not stmt then
			return stmt, prepare_err;
		end
		cache[query] = stmt;
	end
	local executed, exec_err = stmt:execute(...);
	if not executed then
		return executed, exec_err;
	end
	return stmt;
end
-- Methods of the result object returned by execute_update(); both delegate
-- to the wrapped statement handle stored in `__stmt`.
local result_methods = {};
function result_methods:affected() return self.__stmt:affected(); end
function result_methods:rowcount() return self.__stmt:rowcount(); end
local result_mt = { __index = result_methods };
-- Log `sql` at debug level, tagged with `where`. Each "?" placeholder is
-- replaced by the matching bind value (strings single-quoted with embedded
-- quotes doubled, anything else via tostring), and newline/tab indentation
-- is collapsed to a single space.
local function debugquery(where, sql, ...)
	local args = { ... };
	local n = 0;
	local function substitute()
		n = n + 1;
		local value = args[n];
		if type(value) == "string" then
			value = ("'%s'"):format(value:gsub("'", "''"));
		end
		return tostring(value);
	end
	local flattened = sql:gsub("\n?\t+", " ");
	local rendered = flattened:gsub("%?", substitute);
	log("debug", "[%s] %s", where, rendered);
end
--- Run a row-returning query and buffer the whole result set.
-- A fresh statement is prepared per call (not cached in self.prepared) and is
-- closed before returning, including when execution fails.
-- Raises (via assert) on prepare or execute failure; uses self.conn directly,
-- so a connection must already be open (e.g. inside transaction()).
-- @param sql query text, adapted by prepquery()
-- @param ... values bound to the "?" placeholders
-- @return iterator yielding each buffered row in order
function engine:execute_query(sql, ...)
	sql = self:prepquery(sql);
	local stmt = assert(self.conn:prepare(sql));
	local ok, err = stmt:execute(...);
	if not ok then
		-- Release the statement before propagating; previously it stayed
		-- open until garbage-collected. assert() re-raises the same value
		-- the old assert(stmt:execute(...)) did (error/global access is
		-- unavailable here because _ENV is nil).
		stmt:close();
		assert(ok, err);
	end
	local result = {};
	for row in stmt:rows() do result[#result + 1] = row; end
	stmt:close();
	local i = 0;
	return function() i=i+1; return result[i]; end;
end
-- Execute a data-modifying statement through the per-connection statement
-- cache (self.prepared). Raises (via assert) on prepare/execute failure; uses
-- self.conn directly, so a connection must already be open.
-- Returns a result object offering affected() and rowcount().
function engine:execute_update(sql, ...)
	local query = self:prepquery(sql);
	local stmt = self.prepared[query];
	if not stmt then
		stmt = assert(self.conn:prepare(query));
		self.prepared[query] = stmt;
	end
	assert(stmt:execute(...));
	return setmetatable({ __stmt = stmt }, result_mt);
end
-- Default (non-logging) query helpers; engine:debug() may install logging
-- wrappers in their place.
engine.select = engine.execute_query;
engine.insert, engine.delete, engine.update =
	engine.execute_update, engine.execute_update, engine.execute_update;
-- Wrap query method `f` so that each call first logs its SQL and bind
-- values under the label `name`, then forwards all arguments and results.
local function debugwrap(name, f)
	local function wrapped(self, sql, ...)
		debugquery(name, sql, ...);
		return f(self, sql, ...);
	end
	return wrapped;
end
--- Enable or disable SQL logging for this engine.
-- When enabled, insert/select/delete/update are replaced on this instance by
-- wrappers that log each query via debugquery(). Disabling removes them, so
-- lookups fall back to the plain helpers on the shared `engine` prototype.
-- The wrappers are installed per instance: previously they were written to
-- the shared prototype, so toggling one engine changed every engine.
-- @param enable truthy to log queries, falsy to stop
function engine:debug(enable)
	self._debug = enable;
	if enable then
		self.insert = debugwrap("insert", engine.execute_update);
		self.select = debugwrap("select", engine.execute_query);
		self.delete = debugwrap("delete", engine.execute_update);
		self.update = debugwrap("update", engine.execute_update);
	else
		self.insert, self.select, self.delete, self.update = nil, nil, nil, nil;
	end
end
-- Message handler for the xpcall in _transaction(): capture a traceback
-- (from level 3), log it at debug level, and return both the traceback and
-- the original error value in a table.
local function handleerr(e)
	local tb = debug_traceback(e, 3);
	log("debug", "Error in SQL transaction: %s", tb);
	local packed = { traceback = tb };
	packed.err = e;
	return packed;
end
--- Run func(...) within one transaction on the current connection.
-- Connects first when there is no connection (returning connect()'s failure
-- values if that fails). func runs under xpcall (prosody.util.xpcall) with
-- handleerr as message handler.
-- On success the connection is committed; returns true plus up to three of
-- func's return values, or, if the commit fails, a falsy value plus an error
-- message. If func raises, the connection (if still present) is rolled back
-- and false plus the raised error value is returned.
-- NOTE(review): a failed commit is not followed by a rollback here — confirm
-- this is intended.
function engine:_transaction(func, ...)
if not self.conn then
local ok, err = self:connect();
if not ok then return ok, err; end
end
--assert(not self.__transaction, "Recursive transactions not allowed");
log("debug", "SQL transaction begin [%s]", func);
-- Mark a transaction as in progress only while func runs.
self.__transaction = true;
local success, a, b, c = xpcall(func, handleerr, ...);
self.__transaction = nil;
if success then
log("debug", "SQL transaction success [%s]", func);
local ok, err = self.conn:commit();
-- LuaDBI doesn't actually return an error message here, just a boolean
if not ok then return ok, err or "commit failed"; end
return success, a, b, c;
else
-- `a` is the table built by handleerr: { err = ..., traceback = ... }
log("debug", "SQL transaction failure [%s]: %s", func, a.err);
if self.conn then self.conn:rollback(); end
return success, a.err;
end
end
--- Run a transaction via _transaction(func, ...), retrying once when it fails
-- and the connection looks dead (no connection, or conn:ping() fails).
-- Before the retry the connection is dropped and ondisconnect() is called;
-- _transaction() then reconnects because self.conn is nil.
-- A failure that remains after this is logged at "error" level.
-- Returns the same values as _transaction().
function engine:transaction(...)
local ok, ret, b, c = self:_transaction(...);
if not ok then
local conn = self.conn;
if not conn or not conn:ping() then
log("debug", "Database connection was closed. Will reconnect and retry.");
self.conn = nil;
self:ondisconnect();
-- (...) truncates the varargs to the first one, the transaction function
log("debug", "Retrying SQL transaction [%s]", (...));
ok, ret, b, c = self:_transaction(...);
log("debug", "SQL transaction retry %s", ok and "succeeded" or "failed");
else
log("debug", "SQL connection is up, so not retrying");
end
if not ok then
log("error", "Error in SQL transaction: %s", ret);
end
end
return ok, ret, b, c;
end
-- Issue CREATE [UNIQUE] INDEX for an Index() definition whose `table` field
-- names the target table (Table() fills it in). Non-MySQL drivers get
-- IF NOT EXISTS; on MySQL every indexed column gets a (20) prefix length,
-- presumably because MySQL needs one to index TEXT columns — confirm.
-- Returns whatever self:execute() returns.
function engine:_create_index(index)
	local mysql = self.params.driver == "MySQL";
	local quoted = {};
	for i = 1, #index do
		quoted[i] = "\""..index[i].."\"";
	end
	local sql = "CREATE"
		..(index.unique and " UNIQUE" or "")
		.." INDEX"
		..(mysql and "" or " IF NOT EXISTS")
		.." \""..index.name.."\" ON \""..index.table.."\" ("
		..t_concat(quoted, ", ")
		..");";
	if mysql then
		sql = sql:gsub("\"([,)])", "\"(20)%1");
	end
	if self._debug then
		debugquery("create", sql);
	end
	return self:execute(sql);
end
--- Create a table (CREATE TABLE IF NOT EXISTS) from a Table() object, then
-- create every Index found in its definition via _create_index().
-- Column types are adapted per driver: MEDIUMTEXT becomes TEXT outside
-- MySQL; auto_increment columns become BIGSERIAL on PostgreSQL, and get
-- AUTO_INCREMENT on MySQL or AUTOINCREMENT on SQLite3.
-- On MySQL the table is given self.charset as character set with the
-- matching "_bin" collation (self.charset is set by set_encoding()).
-- @param table object returned by Table(); its `c` lists columns in order
-- @return the executed statement on success, or a falsy value plus an error
-- NOTE(review): results of _create_index() are ignored, so index creation
-- failures are not reported. MySQL indexes get no IF NOT EXISTS, so this
-- presumably lets re-creating an existing index pass silently — confirm.
function engine:_create_table(table)
local sql = "CREATE TABLE \""..table.name.."\" (";
do
-- Applied for every driver.
sql = sql:gsub("^CREATE TABLE", "%1 IF NOT EXISTS");
end
for i,col in ipairs(table.c) do
local col_type = col.type;
if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then
col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
end
if col.auto_increment == true and self.params.driver == "PostgreSQL" then
col_type = "BIGSERIAL";
end
sql = sql.."\""..col.name.."\" "..col_type;
if col.nullable == false then sql = sql.." NOT NULL"; end
if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
if col.auto_increment == true then
if self.params.driver == "MySQL" then
sql = sql.." AUTO_INCREMENT";
elseif self.params.driver == "SQLite3" then
sql = sql.." AUTOINCREMENT";
end
end
-- Separator between column definitions (none after the last one).
if i ~= #table.c then sql = sql..", "; end
end
sql = sql.. ");"
if self.params.driver == "MySQL" then
-- Insert the table-level charset/collation before the final ';'.
sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
end
if self._debug then
debugquery("create", sql);
end
local success,err = self:execute(sql);
if not success then return success,err; end
for _, v in ipairs(table.__table__) do
if is_index(v) then
self:_create_index(v);
end
end
return success;
end
--- Configure the connection for UTF-8 and record the charset in self.charset.
-- SQLite3: reads PRAGMA encoding and sets self.charset = "utf8" only if the
-- database reports "UTF-8"; returns the result of that transaction.
-- Other drivers: issues SET NAMES with "utf8", or on MySQL with the utf8*
-- character set having the largest MAXLEN in information_schema (falling
-- back to "utf8") plus a COLLATE "<charset>_bin" clause. On MySQL it then
-- checks that character_set_client actually equals that charset.
-- @return true on success, or a falsy value plus an error message
function engine:set_encoding() -- to UTF-8
local driver = self.params.driver;
if driver == "SQLite3" then
return self:transaction(function()
for encoding in self:select"PRAGMA encoding;" do
if encoding[1] == "UTF-8" then
self.charset = "utf8";
end
end
end);
end
local set_names_query = "SET NAMES '%s';"
local charset = "utf8";
if driver == "MySQL" then
-- The outcome of this transaction is not checked; on failure `charset`
-- simply keeps its "utf8" default.
self:transaction(function()
for row in self:select[[
SELECT "CHARACTER_SET_NAME"
FROM "information_schema"."CHARACTER_SETS"
WHERE "CHARACTER_SET_NAME" LIKE 'utf8%'
ORDER BY MAXLEN DESC LIMIT 1;
]] do
charset = row and row[1] or charset;
end
end);
set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin"));
end
self.charset = charset;
log("debug", "Using encoding '%s' for database connection", charset);
local ok, err = self:transaction(function() return self:execute(set_names_query:format(charset)); end);
if not ok then
return ok, err;
end
if driver == "MySQL" then
-- Verify the server really applied the requested client charset.
local ok, actual_charset = self:transaction(function ()
return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
end);
if not ok then
return false, "Failed to detect connection encoding";
end
local charset_ok = true;
for row in actual_charset do
if row[2] ~= charset then
log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset);
charset_ok = false;
end
end
if not charset_ok then
return false, "Failed to set connection encoding";
end
end
return true;
end
-- Metatable for engine instances: fields not set on the instance are looked
-- up in the `engine` prototype.
local engine_mt = { __index = engine };
-- Describe connection parameters as a URL using socket.url's build():
-- driver as scheme, database as path, plus user/password/host/port.
-- NOTE(review): the resulting URL embeds params.password.
local function db2uri(params)
	local parts = {
		scheme = params.driver;
		user = params.username;
		password = params.password;
		host = params.host;
		port = params.port;
		path = params.database;
	};
	return build_url(parts);
end
-- Create an engine for `params` (first argument, the module table when
-- called with ':', is ignored). Optional onconnect/ondisconnect hooks
-- override the defaults on the engine prototype.
local function create_engine(_, params, onconnect, ondisconnect)
	local instance = {
		url = db2uri(params);
		params = params;
		onconnect = onconnect;
		ondisconnect = ondisconnect;
	};
	return setmetatable(instance, engine_mt);
end
-- Public API of this module.
local sql_module = {
	-- schema object predicates
	is_column = is_column;
	is_index = is_index;
	is_table = is_table;
	is_query = is_query;
	-- schema object constructors
	Column = Column;
	Table = Table;
	Index = Index;
	-- database engines
	create_engine = create_engine;
	db2uri = db2uri;
};
return sql_module;