Skip to content

Commit

Permalink
safe SqlQuery | add limit + exists
Browse files Browse the repository at this point in the history
Signed-off-by: George Lemon <[email protected]>
  • Loading branch information
georgelemon committed Apr 25, 2024
1 parent 1f2311d commit b5c2a0c
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 46 deletions.
18 changes: 12 additions & 6 deletions src/enimsql/model.nim
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ macro initTable*(model: untyped) =
var
execCall = newCall(ident"exec", ident"dbcon")
sqlCall = newCall ident"sql"
add sqlCall, newLit sql(createStmt, table)
var values: seq[string]
add sqlCall, newLit sql(createStmt, table, values)
add execCall, sqlCall
add result, execCall
else:
Expand All @@ -310,7 +311,8 @@ macro tryCreate*[M](model: typedesc[M], then: untyped) =
var
execCall = newCall(ident"tryExec", ident"dbcon")
sqlCall = newCall ident"sql"
add sqlCall, newLit sql(createStmt, table)
var values: seq[string]
add sqlCall, newLit sql(createStmt, table, values)
add execCall, sqlCall
add result,
nnkBlockStmt.newTree(
Expand All @@ -330,15 +332,17 @@ macro delete*[M](model: typedesc[M]) =
result = newStmtList()
let table = getTableName($model)
var dropStmt: Query = newDropStmt()
exec sql(dropStmt, table)
var values: seq[string]
exec sql(dropStmt, table, values)

macro delete*(model: untyped, then: untyped) =
## Delete a table represented by `Model`
checkModelExists(model)
result = newStmtList()
let table = getTableName($model)
var dropStmt: Query = newDropStmt()
tryExec sql(dropStmt, table)
var values: seq[string]
tryExec sql(dropStmt, table, values)

template drop*(model: untyped) =
## Delete a table represented by `Model`
Expand All @@ -350,7 +354,8 @@ macro clear*(model: untyped, then: untyped) =
result = newStmtList()
let table = getTableName($model)
var clearStmt: Query = newClearStmt()
executeSQL sql(clearStmt, table):
var values: seq[string]
executeSQL sql(clearStmt, table, values):
then

macro insertRow*(model: untyped, row: untyped,
Expand Down Expand Up @@ -380,7 +385,8 @@ macro insertRow*(model: untyped, row: untyped,
var
callExec = newCall(ident"tryInsertID", ident"dbcon")
callSql = newCall ident"sql"
add callSql, newLit sql(insertStmt, table)
sqlValues: seq[string]
add callSql, newLit sql(insertStmt, table, sqlValues)
add callExec, callSql
for v in values:
add callExec, v
Expand Down
140 changes: 100 additions & 40 deletions src/enimsql/private/query.nim
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type
ntWhere
ntWhereLike
ntWhereExistsStmt
ntOr
ntInc
ntDec
ntSet
Expand All @@ -26,6 +27,7 @@ type
ntInfix
ntComp
ntField
ntLimit

SQLOperator* = enum
EQ = "="
Expand Down Expand Up @@ -81,17 +83,22 @@ type
selectTable*: string
selectCondition*: Query # ntWhere
selectOrder*: seq[(string, Order)]
selectQueryFilters*: seq[Query]
of ntInsert, ntUpsert:
insertFields*: OrderedTable[string, string]
insertReturn*: Query # ntReturn node
of ntInfix:
infixOp: SQLOperator
infixLeft, infixRight: string
of ntWhere:
whereBranches*: seq[Query] # ntInfix
whereBranches*: seq[Query] # ntInfix
of ntOr:
orBranches*: seq[Query] # ntInfix
of ntUpdate, ntUpdateAll:
updateFields: seq[(SQLColumn, string)]
updateCondition*: Query
of ntLimit:
limitNumber: int
of ntDrop:
discard
of ntTruncate:
Expand Down Expand Up @@ -149,6 +156,8 @@ proc newWhereStmt*: Query = Query(nt: ntWhere)
proc newSelectStmt*: Query = Query(nt: ntSelect)
proc newUpdateStmt*: Query = Query(nt: ntUpdate)
proc newUpdateAllStmt*: Query = Query(nt: ntUpdateAll)
proc newOrStmt*: Query = Query(nt: ntOr)
proc newLimitFilter*(i: int): Query = Query(nt: ntLimit, limitNumber: i)

proc newInfixExpr*(lhs, rhs: string, op: SQLOperator): Query =
result = Query(nt: ntInfix)
Expand Down Expand Up @@ -217,7 +226,7 @@ proc q*(key: string): string {.compileTime.} =
else:
stmtSqlite[key]

proc sql*(node: Query, k: string): string =
proc sql*(node: Query, k: string, values: var seq[string]): string =
## Transform given SQL Query to stringified SQL
case node.nt:
of ntCreate:
Expand Down Expand Up @@ -264,33 +273,54 @@ proc sql*(node: Query, k: string): string =
else:
add result, node.selectColumns.mapIt(indent(it, 1)).join(",")
add result, indent("FROM", 1)
add result, indent(node.selectTable, 1)
add result, node.selectTable.escape.indent(1)
if node.selectCondition != nil:
add result, sql(node.selectCondition, k)
add result, sql(node.selectCondition, k, values)
if node.selectOrder.len > 0:
add result, indent(q("orderby") %
join(node.selectOrder.mapIt(it[0] & indent($(it[1]), 1)), ","), 1)
for filter in node.selectQueryFilters:
add result, indent(sql(filter, k, values), 1)
of ntWhere:
for branch in node.whereBranches:
add result, q("where").indent(1)
var val: string
when nimvm:
val =
case StaticSchema[k].tColumns[branch.infixLeft].cType
of Boolean, Int, Numeric, Money, Serial:
branch.infixRight
else:
"'" & branch.infixRight & "'"
else:
val =
case Models[k].tColumns[branch.infixLeft].cType
of Boolean, Int, Numeric, Money, Serial:
branch.infixRight
else:
"'" & branch.infixRight & "'"
add result, indent(branch.infixLeft, 1)
add result, indent($branch.infixOp, 1)
add result, indent(val, 1)
case branch.nt
of ntInfix:
add result, q("where").indent(1)
# var val: string
# when nimvm:
# val =
# case StaticSchema[k].tColumns[branch.infixLeft].cType
# of Boolean, Int, Numeric, Money, Serial:
# # branch.infixRight
# "'" & branch.infixRight & "'"
# else:
# "'" & branch.infixRight & "'"
# else:
# val =
# case Models[k].tColumns[branch.infixLeft].cType
# of Boolean, Int, Numeric, Money, Serial:
# "'" & branch.infixRight & "'"
# else:
# "'" & branch.infixRight & "'"
add result, branch.infixLeft.escape.indent(1)
add result, indent($(branch.infixOp), 1)
# add result, branch.infixRight.escape(prefix = "'", suffix = "'").indent(1)
add result, indent("?", 1)
add values, branch.infixRight
of ntOr:
add result, sql(branch, k, values)
else: discard
of ntOr:
result = indent($OR, 1)
for branch in node.orBranches:
case branch.nt
of ntInfix:
add result, indent(branch.infixLeft.escape, 1)
add result, indent($branch.infixOp, 1)
add result, indent("?", 1)
add values, branch.infixRight
else: discard # todo
# add result, sql(node, k)
of ntUpdate, ntUpdateAll:
result = q("update") % k
var updates: seq[string]
Expand All @@ -306,7 +336,7 @@ proc sql*(node: Query, k: string): string =
case node.nt
of ntUpdate:
if likely(node.updateCondition != nil):
add result, sql(node.updateCondition, k)
add result, sql(node.updateCondition, k, values)
else: discard # ntUpdateAll
of ntDrop:
result = q("drop") % k
Expand All @@ -321,27 +351,30 @@ proc sql*(node: Query, k: string): string =
let total =
if node.insertFields.len == 0: 0
else: node.insertFields.len - 1
var cols, values = indent("(", 1)
var cols, vals = indent("(", 1)
for k, v in node.insertFields:
add cols, k
add values, "?"
add vals, "?"
add values, v
if i != total:
add cols, "," & spaces(1)
add values, "," & spaces(1)
add vals, "," & spaces(1)
inc i
add cols, indent(")", 0)
add values, indent(")", 0)
add vals, indent(")", 0)
add result, cols
add result, indent("VALUES", 1)
add result, values
add result, vals
# case node.nt
# of ntUpsert:
# add result, "ON CONFLICT($1)" % "id"
# else: discard
setLen(cols, 0)
setLen(values, 0)
setLen(vals, 0)
# if node.insertReturn != nil:
# add result, sql(node.insertReturn, k)
of ntLimit:
result = indent("LIMIT " & $(node.limitNumber), 1)
of ntReturn:
result = indent(q("returning") % [node.returnColName], 1)
else: discard
Expand Down Expand Up @@ -377,7 +410,8 @@ proc create*(models: SchemaTable,
for k, col in schemaTable:
add createStmt.createColumns, col
if createStmt.createColumns.len != 0:
result = SQLQuery sql(createStmt, tableName)
var values: seq[string]
result = SQLQuery sql(createStmt, tableName, values)
models[modelName] = (tableName, schemaTable)

proc add*(schema: Schema, name: string,
Expand Down Expand Up @@ -475,17 +509,21 @@ proc where*(q: QueryBuilder, key: string,
q[1].selectCondition = whereStmt
of ntUpdate:
q[1].updateCondition = whereStmt
of ntOr:
add q[1].orBranches, whereStmt.whereBranches
else:
raise newException(EnimsqlQueryDefect,
"Invalid use of where statement for " & $q[1].nt)
"Invalid use of `WHERE` statement for " & $q[1].nt)

proc where*(q: QueryBuilder, key, val: string): QueryBuilder {.discardable.} =
result = q.where(key, EQ, val)

proc orWhere*(q: QueryBuilder, handle: proc(q: QueryBuilder)): QueryBuilder =
## Use the `orWhere` proc to join a clause to the
## query using the `or` operator
handle(q)
var q2 = (q[0], newOrStmt())
handle(q2)
add q[1].selectCondition.whereBranches, q2[1]
result = q

proc update*(model: Model,
Expand Down Expand Up @@ -534,11 +572,27 @@ proc orderBy*(q: QueryBuilder, key: string, order: Order = Order.ASC): QueryBuil
checkColumn key:
add q[1].selectOrder, (key, order)

proc limit*(q: QueryBuilder, i: int): QueryBuilder =
## Add a `LIMIT` filter to current `QueryBuilder`
add q[1].selectQueryFilters, newLimitFilter(i)
result = q

template exists*(m: Model, colName, expectValue: string): bool =
## Search in the current table if `colName` contains `expectValue`
# https://stackoverflow.com/questions/8149596/check-value-if-exists-in-column
# todo a nice way to check entries by value
var values: seq[string]
let x = m.select.where(colName, expectValue).limit(1)
let q = sql(x[1], m.tName, values)
let res = dbcon.getRow(SQLQuery("SELECT EXISTS (" & q & ")"), values)
len(res) > 0

template getAll*(q: QueryBuilder): untyped =
## Execute the query and returns a `Collection`
## instance with the available results
var rows = getAllRows(dbcon,
SQLQuery(sql(q[1], q[0].tName)))
var values: seq[string]
let x = SQLQuery(sql(q[1], q[0].tName, values))
var rows = getAllRows(dbcon, x, values)
var results = initCollection[SQLValue]()
if rows.len > 0:
if q[1].selectColumns.len > 0:
Expand Down Expand Up @@ -585,28 +639,34 @@ macro `@`*(x: untyped): untyped =
template getAll*(q: QueryBuilder, T: typedesc): untyped =
## Execute the query and returns a collection of objects `T`.
## This works only for a `Model` defined at compile-time
let results = getAllRows(dbcon,
SQLQuery(sql(q[1], q[0].tName)))
var values: seq[string]
let x = SQLQuery(sql(q[1], q[0].tName))
let results = getAllRows(dbcon, x, values)
var collections: seq[T]
for res in results:
add collections, initModel(T, res)
collections

template exec*(q: QueryBuilder): untyped =
## Use it inside a `withDB` context to execute a query
var values: seq[string]
case q[1].nt
of ntUpdate:
assert q[1].updateCondition != nil
dbcon.exec(SQLQuery sql(q[1], q[0].tName))
let x = SQLQuery(sql(q[1], q[0].tName, values))
dbcon.exec(x, values)
of ntInsert:
dbcon.exec(SQLQuery sql(q[1], q[0].tName), q[1].insertFields.values.toSeq)
let x = SQLQuery(sql(q[1], q[0].tName, values))
dbcon.exec(x, values) # q[1].insertFields.values.toSeq
else: discard # todo other final checks before executing the query

template execGet*(q: QueryBuilder, pk = "id"): untyped =
## Use it inside a `withDB` context to execute an `INSERT` query.
## Use a different `pk` name if your primary key is not named `id`.
var values: seq[string]
assert q[1].nt == ntInsert
dbcon.tryInsert(SQLQuery sql(q[1], q[0].tName), pk,
dbcon.tryInsert(
SQLQuery(sql(q[1], q[0].tName, values)), pk,
q[1].insertFields.values.toSeq)

template exec*(q: SQLQuery): untyped =
Expand Down

0 comments on commit b5c2a0c

Please sign in to comment.