diff --git a/lua/sqlite/assert.lua b/lua/sqlite/assert.lua index d15f5734..83f8fc45 100644 --- a/lua/sqlite/assert.lua +++ b/lua/sqlite/assert.lua @@ -13,6 +13,8 @@ local errors = { missing_db_object = "%s's db object is not set. set it with `%s:set_db(db)` and try again.", outdated_schema = "`%s` does not exists in {`%s`}, schema is outdateset `self.db.tbl_schemas[table_name]` or reload", auto_alter_more_less_keys = "schema defined ~= db schema. Please drop `%s` table first or set ensure to false.", + miss_match_pk_type = "Primary key ('%s') is of type '%s', '%s' can't be used to access %s table. RECEIVED KEY: %s", + no_primary_key = "%s has no primary key.", } for key, value in pairs(errors) do @@ -80,4 +82,22 @@ M.auto_alter_should_have_equal_len = function(len_new, len_old, tname) end end +M.should_match_pk_type = function(name, kt, pk, key) + local knotstr = kt ~= "string" + local knotnum = kt ~= "number" + local pt = pk.type + if not pk or not pk.type then + return error(errors.no_primary_key:format(name)) + end + + if + kt ~= "boolean" + and (knotstr and (pt == "string" or pt == "text") or knotnum and (pt == "number" or pt == "integer")) + then + return error(errors.miss_match_pk_type:format(pk.name, pk.type, kt, name, key)) + end + + return true +end + return M diff --git a/lua/sqlite/db.lua b/lua/sqlite/db.lua index e0728723..74de6b31 100644 --- a/lua/sqlite/db.lua +++ b/lua/sqlite/db.lua @@ -591,15 +591,25 @@ function sqlite.db:select(tbl_name, spec, schema) spec = spec or {} spec.select = spec.keys and spec.keys or spec.select + local select = p.select(tbl_name, spec) + local st = "" - local stmt = s:parse(self.conn, p.select(tbl_name, spec)) - s.each(stmt, function() - table.insert(ret, s.kv(stmt)) + local stmt = s:parse(self.conn, select, tbl_name) + stmt:each(function() + table.insert(ret, stmt:kv()) end) - s.reset(stmt) - if s.finalize(stmt) then + if tbl_name == "todos_indexer" then + st = stmt:expand() + end + + stmt:reset() + if stmt:finalize() then self.modified = false end + if tbl_name == "todos_indexer" and spec.id == 3 then + error(st) + end + return p.post_select(ret, schema) end) end @@ -643,6 +653,10 @@ function sqlite.db:table(tbl_name, opts) return self:tbl(tbl_name, opts) end +function sqlite.db:last_insert_rowid() + return tonumber(clib.last_insert_rowid(self.conn)) +end + ---Sqlite functions sugar wrappers. See `sql/strfun` sqlite.db.lib = require "sqlite.strfun" diff --git a/lua/sqlite/defs.lua b/lua/sqlite/defs.lua index d9fabcc6..fc24ccce 100644 --- a/lua/sqlite/defs.lua +++ b/lua/sqlite/defs.lua @@ -588,7 +588,7 @@ ffi.cdef [[ ]] ---@class sqlite3 @sqlite3 db object ----@class sqlite_blob @sqlite3 blob object +---@class sqlite_blob* @sqlite3 blob object M.to_str = function(ptr, len) if ptr == nil then diff --git a/lua/sqlite/helpers.lua b/lua/sqlite/helpers.lua index 4be35b6c..74e0d758 100644 --- a/lua/sqlite/helpers.lua +++ b/lua/sqlite/helpers.lua @@ -66,6 +66,7 @@ end ---@param o sqlite_tbl ---@return any M.run = function(func, o) + func = func or function() end a.should_have_db_object(o.db, o.name) local exec = function() local valid_schema = o.tbl_schema and next(o.tbl_schema) ~= nil @@ -95,6 +96,7 @@ M.run = function(func, o) o.db_schema = o.db:schema(o.name) end + rawset(o, "last_id", o.db:last_insert_rowid()) --- Run wrapped function return func() end diff --git a/lua/sqlite/stmt.lua b/lua/sqlite/stmt.lua index 6277a16a..9b0dec10 100644 --- a/lua/sqlite/stmt.lua +++ b/lua/sqlite/stmt.lua @@ -34,10 +34,17 @@ function sqlstmt:parse(conn, str) assert( code == flags.ok, - ("sqlite.lua: sql statement parse, , stmt: `%s`, err: `(`%s`)`"):format(o.str, clib.last_errmsg(o.conn)) + string.format( + "sqlite.lua\n(parse error): `%s` code == %d\nstatement == '%s'", + clib.to_str(clib.errmsg(self.conn)), + code, + self.str + ) ) + o.pstmt = pstmt[0] return o + end ---Resets the parsed statement. required for parsed statements to be re-executed. diff --git a/lua/sqlite/tbl.lua b/lua/sqlite/tbl.lua index 12e4f4a9..16d7b83a 100644 --- a/lua/sqlite/tbl.lua +++ b/lua/sqlite/tbl.lua @@ -11,6 +11,7 @@ local u = require "sqlite.utils" local h = require "sqlite.helpers" +local indexer = require "sqlite.tbl.indexer" local sqlite = {} ---@class sqlite_tbl @Main sql table class @@ -20,6 +21,8 @@ local sqlite = {} sqlite.tbl = {} sqlite.tbl.__index = sqlite.tbl +-- TODO: Add examples to index access in sqlite.tbl.new + ---Create new |sqlite_tbl| object. This object encouraged to be extend and ---modified by the user. overwritten method can be still accessed via ---pre-appending `__` e.g. redefining |sqlite_tbl:get|, result in @@ -52,27 +55,19 @@ sqlite.tbl.__index = sqlite.tbl ---@return sqlite_tbl function sqlite.tbl.new(name, schema, db) schema = schema or {} - - local t = setmetatable({ + schema = u.if_nil(schema.schema, schema) + ---@type sqlite_tbl + local tbl = setmetatable({ db = db, name = name, - tbl_schema = u.if_nil(schema.schema, schema), + tbl_schema = schema, }, sqlite.tbl) if db then - h.run(function() end, t) + h.run(nil, tbl) end - return setmetatable({}, { - __index = function(_, key, ...) - if type(key) == "string" then - key = key:sub(1, 2) == "__" and key:sub(3, -1) or key - if t[key] then - return t[key] - end - end - end, - }) + return indexer(tbl) end ---Create or change table schema. If no {schema} is given, @@ -468,6 +463,12 @@ function sqlite.tbl:set_db(db) self.db = db end +function sqlite.tbl:last_id() + h.run(function() + rawset(self, "last_id", self.db:last_insert_rowid()) + end, self) +end + sqlite.tbl = setmetatable(sqlite.tbl, { __call = function(_, ...) return sqlite.tbl.new(...) diff --git a/lua/sqlite/tbl/extend.lua b/lua/sqlite/tbl/extend.lua deleted file mode 100644 index b9a696f0..00000000 --- a/lua/sqlite/tbl/extend.lua +++ /dev/null @@ -1,113 +0,0 @@ ----@type sqlite_tblext -local tbl = {} - ----Create or change table schema. If no {schema} is given, ----then it return current the used schema if it exists or empty table otherwise. ----On change schema it returns boolean indecting success. ----@param schema table: table schema definition ----@return table table | boolean ----@usage `tbl.schema()` get project table schema. ----@usage `tbl.schema({...})` mutate project table schema ----@todo do alter when updating the schema instead of droping it completely -tbl.schema = function(schema) end - ----Remove table from database, if the table is already drooped then it returns false. ----@usage `todos:drop()` drop todos table content. ----@see DB:drop ----@return boolean -tbl.drop = function() end - ----Predicate that returns true if the table is empty. ----@usage `if todos:empty() then echo "no more todos, you are free :D" end` ----@return boolean -tbl.empty = function() end - ----Predicate that returns true if the table exists. ----@usage `if not goals:exists() then error("I'm disappointed in you ") end` ----@return boolean -tbl.exists = function() end - ----Query the table and return results. ----@param query sqlite_query_select ----@return table ----@usage `tbl.get()` get a list of all rows in project table. ----@usage `tbl.get({ where = { status = "pending", client = "neovim" }})` ----@usage `tbl.get({ where = { status = "done" }, limit = 5})` get the last 5 done projects ----@see DB:select -tbl.get = function(query) end - ----Get the current number of rows in the table ----@return number -tbl.count = function() end - ----Get first match. ----@param where table: where key values ----@return nil or row ----@usage `tbl.where{id = 1}` ----@see DB:select -tbl.where = function(where) end - ----Iterate over table rows and execute {func}. ----Returns true only when rows is not emtpy. ----@param func function: func(row) ----@param query sqlite_query_select ----@usage `let query = { where = { status = "pending"}, contains = { title = "fix*" } }` ----@usage `tbl.each(function(row) print(row.title) end, query)` ----@return boolean -tbl.each = function(func, query) end - ----Create a new table from iterating over {self.name} rows with {func}. ----@param func function: func(row) ----@param query sqlite_query_select ----@usage `let query = { where = { status = "pending"}, contains = { title = "fix*" } }` ----@usage `local t = todos.map(function(row) return row.title end, query)` ----@return table[] -tbl.map = function(func, query) end - ----Sorts a table in-place using a transform. Values are ranked in a custom order of the results of ----running `transform (v)` on all values. `transform` may also be a string name property sort by. ----`comp` is a comparison function. Adopted from Moses.lua ----@param query sqlite_query_select ----@param transform function: a `transform` function to sort elements. Defaults to @{identity} ----@param comp function: a comparison function, defaults to the `<` operator ----@return table[] ----@usage `local res = tbl.sort({ where = {id = {32,12,35}}})` return rows sort by id ----@usage `local res = tbl.sort({ where = {id = {32,12,35}}}, "age")` return rows sort by age ----@usage `local res = tbl.sort({where = { ... }}, "age", function(a, b) return a > b end)` with custom function -tbl.sort = function(query, transform, comp) end - ----Same functionalities as |DB:insert()| ----@param rows table: a row or a group of rows ----@see DB:insert ----@usage `tbl.insert { title = "stop writing examples :D" }` insert single item. ----@usage `tbl.insert { { ... }, { ... } }` insert multiple items ----@return integer: last inserted id -tbl.insert = function(rows) end - ----Same functionalities as |DB:delete()| ----@param where sqlite_query_delete: key value pairs to filter rows to delete ----@see DB:delete ----@return boolean ----@usage `todos.remove()` remove todos table content. ----@usage `todos.remove{ project = "neovim" }` remove all todos where project == "neovim". ----@usage `todos.remove{{project = "neovim"}, {id = 1}}` remove all todos where project == "neovim" or id =1 -tbl.remove = function(where) end - ----Same functionalities as |DB:update()| ----@param specs sqlite_query_update ----@see DB:update ----@return boolean -tbl.update = function(specs) end - ----replaces table content with {rows} ----@param rows table: a row or a group of rows ----@see DB:delete ----@see DB:insert ----@return boolean -tbl.replace = function(rows) end - ----Set db object for the table. ----@param db sqlite_db -tbl.set_db = function(db) end - -return tbl diff --git a/lua/sqlite/tbl/indexer.lua b/lua/sqlite/tbl/indexer.lua new file mode 100644 index 00000000..f3e74448 --- /dev/null +++ b/lua/sqlite/tbl/indexer.lua @@ -0,0 +1,168 @@ +local a = require "sqlite.assert" +local u = require "sqlite.utils" +---Get primary key from sqlite_tbl schema +---@param tbl sqlite_schema_dict +local get_primary_key = function(tbl) + local pk + for k, v in pairs(tbl or {}) do + if type(v) == "table" and v.primary then + pk = v + if not pk.type then + pk.type = pk[1] + end + pk.name = k + break + elseif type(v) == "boolean" and k ~= "ensure" then + pk = { + name = k, + type = "integer", + primary = true, + } + break + end + end + + return pk +end + +---TODO: remove after using value should rawset and rawget in helpers.run, sqlite.tbl +local is_part_of_tbl_object = function(key) + return key == "db" + or key == "tbl_exists" + or key == "db_schema" + or key == "tbl_name" + or key == "tbl_schema" + or key == "mtime" + or key == "has_content" + or key == "_name" +end + +local resolve_meta_key = function(kt, key) + return kt == "string" and (key:sub(1, 2) == "__" and key:sub(3, -1) or key) or key +end + +---Used to extend a table row to be able to manipulate the row values +---@param tbl sqlite_tbl +---@param pk sqlite_schema_key +---@return function(row, altkey):table +local tbl_row_extender = function(tbl, pk) + return function(row, reqkey) + row = row or {} + local mt = { + __newindex = function(_, key, val) + tbl:update { -- TODO: maybe return inserted row?? + where = { [pk.name] = (row[pk.name] or reqkey) }, + set = { [key] = val }, + } + row = {} + end, + __index = function(_, key) + if key == "values" then + return row + end + if not row[key] then + local res = tbl:where { [pk.name] = row[pk.name] } or {} + row = res + return res[key] + else + return row[key] + end + end, + } + + return setmetatable({}, mt) + end +end + +local sep_query_and_where = function(q, keys) + local kv = { where = {} } + for k, v in pairs(q) do + if u.contains(keys, k) then + kv.where[k] = v + else + kv[k] = v + end + end + + if next(kv.where) == nil then + kv.where = nil + end + return kv +end + +---Print errors to the user +---@param func function +-- local sc = function(func) +-- local ok, val = xpcall(func, function(msg) +-- print(msg) +-- end) +-- return ok and val +-- end + +return function(tbl) + local pk = get_primary_key(tbl.tbl_schema) + local extend = tbl_row_extender(tbl, pk) + local tbl_keys = u.keys(tbl.tbl_schema) + local mt = {} + + mt.__index = function(_, arg) + if is_part_of_tbl_object(arg) then + return tbl[arg] + end + + local kt = type(arg) + local skey = resolve_meta_key(kt, arg) + + if not pk or (kt == "string" and tbl[skey]) then + return tbl[skey] + end + + if kt == "string" or kt == "number" and pk then + a.should_match_pk_type(tbl.name, kt, pk, arg) + return extend( + tbl:where { + [pk.name] = arg, + }, + arg + ) + end + + return kt == "table" and tbl:get(sep_query_and_where(arg, tbl_keys)) + end + + mt.__newindex = function(o, arg, val) + if is_part_of_tbl_object(arg) then + tbl[arg] = val + return + end + + local kt, vt = type(arg), type(val) + + if not pk or (vt == "function" and kt == "string") then + rawset(o, arg, val) + return + end + + if vt == "nil" and kt == "string" or kt == "number" then + tbl:remove { [pk.name] = arg } + end + + if vt == "table" and kt == "table" then + local q = sep_query_and_where(arg, tbl_keys) + q.set = val + tbl:update(q) + return + end + + if vt == "table" and pk then + a.should_match_pk_type(tbl.name, kt, pk, arg) + if arg == 0 or arg == true or arg == "" then + return tbl:insert(val) + else + return tbl:update { where = { [pk.name] = arg }, set = val } + end + end + end + + return setmetatable({ _config = {}, _state = {} }, mt) +end diff --git a/lua/sqlite/utils.lua b/lua/sqlite/utils.lua index d35aeabf..2f149208 100644 --- a/lua/sqlite/utils.lua +++ b/lua/sqlite/utils.lua @@ -59,6 +59,15 @@ M.okeys = function(t) return r end +M.contains = function(tbl, val) + for _, v in ipairs(tbl) do + if v == val then + return true + end + end + return false +end + M.opairs = (function() local __gen_order_index = function(t) local orderedIndex = {} diff --git a/test/auto/tbl_spec.lua b/test/auto/tbl_spec.lua index b756782a..16752cfa 100644 --- a/test/auto/tbl_spec.lua +++ b/test/auto/tbl_spec.lua @@ -972,5 +972,160 @@ describe("sqlite.tbl", function() end) end) end) + + describe(":index access", function() + local db_path = ":memory:" or "/tmp/idx_db" + -- vim.loop.fs_unlink(db_path) + db = sql:open(db_path) + + describe("string_index:", function() + local kv = tbl("kvpair", { + key = { "text", primary = true, required = true, default = "none" }, + len = "integer", + }, db) + + it("access/insert-to table using primary key", function() + kv.a = { len = 1 } + eq({ key = "a", len = 1 }, kv.a) + end) + + it("access/update a row field len", function() + -- eq({}, kv.a) + kv.a = { len = 1 } + kv.a.len = 2 + eq(2, kv:where({ len = 2 }).len, "should have been set") + eq(2, kv.a.len, "should have been set") + kv.a.len = 3 + eq({ key = "a", len = 3 }, kv.a, "should return values") + end) + + it("remove a row using primary key", function() + kv.a = nil + eq(nil, kv:where { key = "a" }, "should be empty") + eq({}, kv.a, "should be empty") + end) + + it("sets a row field len without creating the row first", function() + kv["some key with spaces :D"].len = 4 + eq(kv["some key with spaces :D"], { key = "some key with spaces :D", len = 4 }) + kv["some key with spaces :D"] = nil + end) + + it("query using index", function() + kv.a.len, kv.b.len, kv.c.len = 1, 2, 3 + eq( + { + { key = "a", len = 1 }, + { key = "b", len = 2 }, + }, + kv[{ + where = { len = { 1, 2, 3 } }, + order_by = { asc = { "key", "len" } }, + limit = 2, + }] + ) + end) + it("bulk update", function() + kv[{ len = { 1, 2, 3 } }] = { len = 10 } + eq( + { + { key = "a", len = 10 }, + { key = "b", len = 10 }, + }, + kv[{ + order_by = { asc = { "key" } }, + limit = 2, + }] + ) + end) + + it("insert with 0 or true to skip the primary key value.", function() + kv[true] = { len = 5 } + eq(5, kv.none.len) + kv[""] = { len = 6 } + eq({ key = "none", len = 6 }, kv:where { len = 6 }) + end) + end) + + describe("number_index", function() + local t = tbl("number_idx", { id = true, name = "integer" }, db) + + it("passes string_index tests", function() + t[1] = { name = "sam" } + eq({ id = 1, name = "sam" }, t[1]) + eq("sam", t:where({ id = 1 }).name, "should have been set") + + t[2].name = "John" + eq({ id = 2, name = "John" }, t[2]) + eq("John", t:where({ id = 2 }).name, "should have been set") + eq("John", t[2].name, "should have been set") + + t[2] = nil + eq(nil, t:where { id = 2 }, "should be empty") + eq({}, t[2], "should be empty") + + t[1].name, t[2].name, t[2].name = "sam", "tami", "ram" + eq( + { + { id = 1, name = "sam" }, + { id = 2, name = "tami" }, + }, + t[{ + where = { name = { "sam", "tami", "ram" } }, + order_by = { asc = { "id" } }, + limit = 2, + }] + ) + t[{ id = { 1, 2, 3 } }] = { name = "none" } + eq( + { + { id = 1, name = "none" }, + { id = 2, name = "none" }, + }, + t[{ + order_by = { asc = { "id" } }, + limit = 2, + }] + ) + end) + end) + + describe("Relationships", function() + local todos = tbl("todos_indexer", { + id = true, + title = "text", + project = { + reference = "projects.title", + required = true, + on_delete = "cascade", + on_update = "cascade", + }, + }, db) + + local projects = tbl("projects", { + title = { type = "text", primary = true, required = true, unique = true }, + deadline = { "date", default = db.lib.date "now" }, + }, db) + + it("create new table with default values", function() + projects.neovim = {} + eq(true, projects.neovim.deadline == os.date "!%Y-%m-%d") + projects["sqlite"] = {} + --- TODO: if you have sqilte.lua todos[2] return empty table + end) + + it("fails if foregin key doesn't exists", function() + eq( + false, + pcall(function() + todos[2].project = "ram" + end) + ) + end) + end) + + -- vim.loop.fs_unlink(db_path) + end) + clean() end)