From b5c2a0c077c503295b76ba6b147841076a3a7c0b Mon Sep 17 00:00:00 2001 From: George Lemon Date: Thu, 25 Apr 2024 04:37:22 +0300 Subject: [PATCH] safe SqlQuery | add `limit` + `exists` Signed-off-by: George Lemon --- src/enimsql/model.nim | 18 +++-- src/enimsql/private/query.nim | 140 ++++++++++++++++++++++++---------- 2 files changed, 112 insertions(+), 46 deletions(-) diff --git a/src/enimsql/model.nim b/src/enimsql/model.nim index aeb2963..d91be29 100644 --- a/src/enimsql/model.nim +++ b/src/enimsql/model.nim @@ -290,7 +290,8 @@ macro initTable*(model: untyped) = var execCall = newCall(ident"exec", ident"dbcon") sqlCall = newCall ident"sql" - add sqlCall, newLit sql(createStmt, table) + var values: seq[string] + add sqlCall, newLit sql(createStmt, table, values) add execCall, sqlCall add result, execCall else: @@ -310,7 +311,8 @@ macro tryCreate*[M](model: typedesc[M], then: untyped) = var execCall = newCall(ident"tryExec", ident"dbcon") sqlCall = newCall ident"sql" - add sqlCall, newLit sql(createStmt, table) + var values: seq[string] + add sqlCall, newLit sql(createStmt, table, values) add execCall, sqlCall add result, nnkBlockStmt.newTree( @@ -330,7 +332,8 @@ macro delete*[M](model: typedesc[M]) = result = newStmtList() let table = getTableName($model) var dropStmt: Query = newDropStmt() - exec sql(dropStmt, table) + var values: seq[string] + exec sql(dropStmt, table, values) macro delete*(model: untyped, then: untyped) = ## Delete a table represented by `Model` @@ -338,7 +341,8 @@ macro delete*(model: untyped, then: untyped) = result = newStmtList() let table = getTableName($model) var dropStmt: Query = newDropStmt() - tryExec sql(dropStmt, table) + var values: seq[string] + tryExec sql(dropStmt, table, values) template drop*(model: untyped) = ## Delete a table represented by `Model` @@ -350,7 +354,8 @@ macro clear*(model: untyped, then: untyped) = result = newStmtList() let table = getTableName($model) var clearStmt: Query = newClearStmt() - executeSQL sql(clearStmt, table): + var values: seq[string] + executeSQL sql(clearStmt, table, values): then macro insertRow*(model: untyped, row: untyped, @@ -380,7 +385,8 @@ macro insertRow*(model: untyped, row: untyped, var callExec = newCall(ident"tryInsertID", ident"dbcon") callSql = newCall ident"sql" - add callSql, newLit sql(insertStmt, table) + sqlValues: seq[string] + add callSql, newLit sql(insertStmt, table, sqlValues) add callExec, callSql for v in values: add callExec, v diff --git a/src/enimsql/private/query.nim b/src/enimsql/private/query.nim index 38d0a40..87755b7 100644 --- a/src/enimsql/private/query.nim +++ b/src/enimsql/private/query.nim @@ -18,6 +18,7 @@ type ntWhere ntWhereLike ntWhereExistsStmt + ntOr ntInc ntDec ntSet @@ -26,6 +27,7 @@ type ntInfix ntComp ntField + ntLimit SQLOperator* = enum EQ = "=" @@ -81,6 +83,7 @@ type selectTable*: string selectCondition*: Query # ntWhere selectOrder*: seq[(string, Order)] + selectQueryFilters*: seq[Query] of ntInsert, ntUpsert: insertFields*: OrderedTable[string, string] insertReturn*: Query # ntReturn node @@ -88,10 +91,14 @@ type infixOp: SQLOperator infixLeft, infixRight: string of ntWhere: - whereBranches*: seq[Query] # ntInfix + whereBranches*: seq[Query] # ntInfix + of ntOr: + orBranches*: seq[Query] # ntInfix of ntUpdate, ntUpdateAll: updateFields: seq[(SQLColumn, string)] updateCondition*: Query + of ntLimit: + limitNumber: int of ntDrop: discard of ntTruncate: @@ -149,6 +156,8 @@ proc newWhereStmt*: Query = Query(nt: ntWhere) proc newSelectStmt*: Query = Query(nt: ntSelect) proc newUpdateStmt*: Query = Query(nt: ntUpdate) proc newUpdateAllStmt*: Query = Query(nt: ntUpdateAll) +proc newOrStmt*: Query = Query(nt: ntOr) +proc newLimitFilter*(i: int): Query = Query(nt: ntLimit, limitNumber: i) proc newInfixExpr*(lhs, rhs: string, op: SQLOperator): Query = result = Query(nt: ntInfix) @@ -217,7 +226,7 @@ proc q*(key: string): string {.compileTime.} = else: stmtSqlite[key] -proc sql*(node: Query, k: string): string = +proc sql*(node: Query, k: string, values: var seq[string]): string = ## Transform given SQL Query to stringified SQL case node.nt: of ntCreate: @@ -264,33 +273,54 @@ proc sql*(node: Query, k: string): string = else: add result, node.selectColumns.mapIt(indent(it, 1)).join(",") add result, indent("FROM", 1) - add result, indent(node.selectTable, 1) + add result, node.selectTable.escape.indent(1) if node.selectCondition != nil: - add result, sql(node.selectCondition, k) + add result, sql(node.selectCondition, k, values) if node.selectOrder.len > 0: add result, indent(q("orderby") % join(node.selectOrder.mapIt(it[0] & indent($(it[1]), 1)), ","), 1) + for filter in node.selectQueryFilters: + add result, indent(sql(filter, k, values), 1) of ntWhere: for branch in node.whereBranches: - add result, q("where").indent(1) - var val: string - when nimvm: - val = - case StaticSchema[k].tColumns[branch.infixLeft].cType - of Boolean, Int, Numeric, Money, Serial: - branch.infixRight - else: - "'" & branch.infixRight & "'" - else: - val = - case Models[k].tColumns[branch.infixLeft].cType - of Boolean, Int, Numeric, Money, Serial: - branch.infixRight - else: - "'" & branch.infixRight & "'" - add result, indent(branch.infixLeft, 1) - add result, indent($branch.infixOp, 1) - add result, indent(val, 1) + case branch.nt + of ntInfix: + add result, q("where").indent(1) + # var val: string + # when nimvm: + # val = + # case StaticSchema[k].tColumns[branch.infixLeft].cType + # of Boolean, Int, Numeric, Money, Serial: + # # branch.infixRight + # "'" & branch.infixRight & "'" + # else: + # "'" & branch.infixRight & "'" + # else: + # val = + # case Models[k].tColumns[branch.infixLeft].cType + # of Boolean, Int, Numeric, Money, Serial: + # "'" & branch.infixRight & "'" + # else: + # "'" & branch.infixRight & "'" + add result, branch.infixLeft.escape.indent(1) + add result, indent($(branch.infixOp), 1) + # add result, branch.infixRight.escape(prefix = "'", suffix = "'").indent(1) + add result, indent("?", 1) + add values, branch.infixRight + of ntOr: + add result, sql(branch, k, values) + else: discard + of ntOr: + result = indent($OR, 1) + for branch in node.orBranches: + case branch.nt + of ntInfix: + add result, indent(branch.infixLeft.escape, 1) + add result, indent($branch.infixOp, 1) + add result, indent("?", 1) + add values, branch.infixRight + else: discard # todo + # add result, sql(node, k) of ntUpdate, ntUpdateAll: result = q("update") % k var updates: seq[string] @@ -306,7 +336,7 @@ proc sql*(node: Query, k: string): string = case node.nt of ntUpdate: if likely(node.updateCondition != nil): - add result, sql(node.updateCondition, k) + add result, sql(node.updateCondition, k, values) else: discard # ntUpdateAll of ntDrop: result = q("drop") % k @@ -321,27 +351,30 @@ proc sql*(node: Query, k: string): string = let total = if node.insertFields.len == 0: 0 else: node.insertFields.len - 1 - var cols, values = indent("(", 1) + var cols, vals = indent("(", 1) for k, v in node.insertFields: add cols, k - add values, "?" + add vals, "?" + add values, v if i != total: add cols, "," & spaces(1) - add values, "," & spaces(1) + add vals, "," & spaces(1) inc i add cols, indent(")", 0) - add values, indent(")", 0) + add vals, indent(")", 0) add result, cols add result, indent("VALUES", 1) - add result, values + add result, vals # case node.nt # of ntUpsert: # add result, "ON CONFLICT($1)" % "id" # else: discard setLen(cols, 0) - setLen(values, 0) + setLen(vals, 0) # if node.insertReturn != nil: # add result, sql(node.insertReturn, k) + of ntLimit: + result = indent("LIMIT " & $(node.limitNumber), 1) of ntReturn: result = indent(q("returning") % [node.returnColName], 1) else: discard @@ -377,7 +410,8 @@ proc create*(models: SchemaTable, for k, col in schemaTable: add createStmt.createColumns, col if createStmt.createColumns.len != 0: - result = SQLQuery sql(createStmt, tableName) + var values: seq[string] + result = SQLQuery sql(createStmt, tableName, values) models[modelName] = (tableName, schemaTable) proc add*(schema: Schema, name: string, @@ -475,9 +509,11 @@ proc where*(q: QueryBuilder, key: string, q[1].selectCondition = whereStmt of ntUpdate: q[1].updateCondition = whereStmt + of ntOr: + add q[1].orBranches, whereStmt.whereBranches else: raise newException(EnimsqlQueryDefect, - "Invalid use of where statement for " & $q[1].nt) + "Invalid use of `WHERE` statement for " & $q[1].nt) proc where*(q: QueryBuilder, key, val: string): QueryBuilder {.discardable.} = result = q.where(key, EQ, val) @@ -485,7 +521,9 @@ proc where*(q: QueryBuilder, key, val: string): QueryBuilder {.discardable.} = proc orWhere*(q: QueryBuilder, handle: proc(q: QueryBuilder)): QueryBuilder = ## Use the `orWhere` proc to join a clause to the ## query using the `or` operator - handle(q) + var q2 = (q[0], newOrStmt()) + handle(q2) + add q[1].selectCondition.whereBranches, q2[1] result = q proc update*(model: Model, @@ -534,11 +572,27 @@ proc orderBy*(q: QueryBuilder, key: string, order: Order = Order.ASC): QueryBuil checkColumn key: add q[1].selectOrder, (key, order) +proc limit*(q: QueryBuilder, i: int): QueryBuilder = + ## Add a `LIMIT` filter to current `QueryBuilder` + add q[1].selectQueryFilters, newLimitFilter(i) + result = q + +template exists*(m: Model, colName, expectValue: string): bool = + ## Search in the current table if `colName` contains `expectValue` + # https://stackoverflow.com/questions/8149596/check-value-if-exists-in-column + # todo a nice way to check entries by value + var values: seq[string] + let x = m.select.where(colName, expectValue).limit(1) + let q = sql(x[1], m.tName, values) + let res = dbcon.getRow(SQLQuery("SELECT EXISTS (" & q & ")"), values) + len(res) > 0 + template getAll*(q: QueryBuilder): untyped = ## Execute the query and returns a `Collection` ## instance with the available results - var rows = getAllRows(dbcon, - SQLQuery(sql(q[1], q[0].tName))) + var values: seq[string] + let x = SQLQuery(sql(q[1], q[0].tName, values)) + var rows = getAllRows(dbcon, x, values) var results = initCollection[SQLValue]() if rows.len > 0: if q[1].selectColumns.len > 0: @@ -585,8 +639,9 @@ macro `@`*(x: untyped): untyped = template getAll*(q: QueryBuilder, T: typedesc): untyped = ## Execute the query and returns a collection of objects `T`. ## This works only for a `Model` defined at compile-time - let results = getAllRows(dbcon, - SQLQuery(sql(q[1], q[0].tName))) + var values: seq[string] + let x = SQLQuery(sql(q[1], q[0].tName)) + let results = getAllRows(dbcon, x, values) var collections: seq[T] for res in results: add collections, initModel(T, res) @@ -594,19 +649,24 @@ template getAll*(q: QueryBuilder, T: typedesc): untyped = template exec*(q: QueryBuilder): untyped = ## Use it inside a `withDB` context to execute a query + var values: seq[string] case q[1].nt of ntUpdate: assert q[1].updateCondition != nil - dbcon.exec(SQLQuery sql(q[1], q[0].tName)) + let x = SQLQuery(sql(q[1], q[0].tName, values)) + dbcon.exec(x, values) of ntInsert: - dbcon.exec(SQLQuery sql(q[1], q[0].tName), q[1].insertFields.values.toSeq) + let x = SQLQuery(sql(q[1], q[0].tName, values)) + dbcon.exec(x, values) # q[1].insertFields.values.toSeq else: discard # todo other final checks before executing the query template execGet*(q: QueryBuilder, pk = "id"): untyped = ## Use it inside a `withDB` context to execute an `INSERT` query. ## Use a different `pk` name if your primary key is not named `id`. + var values: seq[string] assert q[1].nt == ntInsert - dbcon.tryInsert(SQLQuery sql(q[1], q[0].tName), pk, + dbcon.tryInsert( + SQLQuery(sql(q[1], q[0].tName, values)), pk, q[1].insertFields.values.toSeq) template exec*(q: SQLQuery): untyped =