Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added support for column and table names in repos #8

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/main/scala/com/augustnagro/magnum/ClickhouseDbType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,11 @@ object ClickhouseDbType extends DbType:
con: DbCon
): BatchUpdateResult =
throw UnsupportedOperationException()

def columns: AllColumns = AllColumns.fromSeq(eElemNamesSql)

def insertColumns: InsertColumns = InsertColumns.fromSeq(ecElemNamesSql)

def tableName: Repo.TableName = Repo.TableName(tableNameSql)

def idColumn: Repo.IdColumn = Repo.IdColumn(idName)
8 changes: 8 additions & 0 deletions src/main/scala/com/augustnagro/magnum/H2DbType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,11 @@ object H2DbType extends DbType:
) match
case Success(res) => res
case Failure(t) => throw SqlException(updateSql, entities, t)

def columns: AllColumns = AllColumns.fromSeq(eElemNamesSql)

def insertColumns: InsertColumns = InsertColumns.fromSeq(ecElemNamesSql)

def tableName: Repo.TableName = Repo.TableName(tableNameSql)

def idColumn: Repo.IdColumn = Repo.IdColumn(idName)
6 changes: 6 additions & 0 deletions src/main/scala/com/augustnagro/magnum/ImmutableRepo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,9 @@ open class ImmutableRepo[E, ID](using defaults: RepoDefaults[?, E, ID]):
*/
def findAllById(ids: Iterable[ID])(using DbCon): Vector[E] =
defaults.findAllById(ids)

def * : AllColumns = defaults.columns

def id: Repo.IdColumn = defaults.idColumn

def table: Repo.TableName = defaults.tableName
8 changes: 8 additions & 0 deletions src/main/scala/com/augustnagro/magnum/MySqlDbType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,11 @@ object MySqlDbType extends DbType:
) match
case Success(res) => res
case Failure(t) => throw SqlException(updateSql, entities, t)

def columns: AllColumns = AllColumns.fromSeq(eElemNamesSql)

def insertColumns: InsertColumns = InsertColumns.fromSeq(ecElemNamesSql)

def tableName: Repo.TableName = Repo.TableName(tableNameSql)

def idColumn: Repo.IdColumn = Repo.IdColumn(idName)
8 changes: 8 additions & 0 deletions src/main/scala/com/augustnagro/magnum/OracleDbType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,11 @@ object OracleDbType extends DbType:
) match
case Success(res) => res
case Failure(t) => throw SqlException(updateSql, entities, t)

def columns: AllColumns = AllColumns.fromSeq(eElemNamesSql)

def insertColumns: InsertColumns = InsertColumns.fromSeq(ecElemNamesSql)

def tableName: Repo.TableName = Repo.TableName(tableNameSql)

def idColumn: Repo.IdColumn = Repo.IdColumn(idName)
10 changes: 9 additions & 1 deletion src/main/scala/com/augustnagro/magnum/PostgresDbType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ object PostgresDbType extends DbType:
ps.executeUpdate()
) match
case Success(_) => ()
case Failure(t) => throw SqlException(updateSql, entity, t)
case Failure(t) => throw SqlException(updateSql, entity, t)

def updateAll(entities: Iterable[E])(using
con: DbCon
Expand All @@ -235,3 +235,11 @@ object PostgresDbType extends DbType:
) match
case Success(res) => res
case Failure(t) => throw SqlException(updateSql, entities, t)

def columns: AllColumns = AllColumns.fromSeq(eElemNamesSql)

def insertColumns: InsertColumns = InsertColumns.fromSeq(ecElemNamesSql)

def tableName: Repo.TableName = Repo.TableName(tableNameSql)

def idColumn: Repo.IdColumn = Repo.IdColumn(idName)
92 changes: 92 additions & 0 deletions src/main/scala/com/augustnagro/magnum/Renderables.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package com.augustnagro.magnum

import scala.collection.*

// template for all renderables

abstract class Renderables[+SELF <: IndexedSeq[String]](
private val strs: IArray[String],
private val factory: SpecificIterableFactory[String, SELF]
) extends IndexedSeq[String],
immutable.IndexedSeqOps[String, IndexedSeq, SELF],
immutable.StrictOptimizedSeqOps[String, IndexedSeq, SELF]:
self: SELF =>

def length: Int = strs.length

def apply(idx: Int): String =
if idx < 0 || length <= idx then throw new IndexOutOfBoundsException
strs(idx)

// Mandatory overrides of `fromSpecific`, `newSpecificBuilder`,
// and `empty`, from `IterableOps`
override protected def fromSpecific(coll: IterableOnce[String]): SELF =
factory.fromSpecific(coll)
override protected def newSpecificBuilder: mutable.Builder[String, SELF] =
factory.newBuilder
override def empty: SELF = factory.empty

// Overloading of `appended`, `prepended`, `appendedAll`, `prependedAll`,
// `map`, `flatMap` and `concat` to return a `Columns` when possible
def concat(suffix: IterableOnce[String]): SELF =
strictOptimizedConcat(suffix, newSpecificBuilder)
inline final def ++(suffix: IterableOnce[String]): SELF = concat(suffix)
def appended(String: String): SELF =
(newSpecificBuilder ++= this += String).result()
def appendedAll(suffix: Iterable[String]): SELF =
strictOptimizedConcat(suffix, newSpecificBuilder)
def prepended(String: String): SELF =
(newSpecificBuilder += String ++= this).result()
def prependedAll(prefix: Iterable[String]): SELF =
(newSpecificBuilder ++= prefix ++= this).result()
def map(f: String => String): SELF =
strictOptimizedMap(newSpecificBuilder, f)
def flatMap(f: String => IterableOnce[String]): SELF =
strictOptimizedFlatMap(newSpecificBuilder, f)

override def iterator: Iterator[String] = new AbstractIterator[String]:
private var current = 0
def hasNext = current < self.length
def next(): String =
val elem = self(current)
current += 1
elem

// concrete classes

class AllColumns(private val cols: IArray[String])
extends Renderables[AllColumns](cols, AllColumns):
override def className = "Columns"
AugustNagro marked this conversation as resolved.
Show resolved Hide resolved
override def toString(): String = mkString(", ")

class InsertColumns(private val cols: IArray[String])
extends Renderables[InsertColumns](cols, InsertColumns):
override def className = "InsertColumns"
override def toString(): String = mkString("(", ", ", ")")

// Factories

class RenderablesFactory[A <: Renderables[A]](
creator: IterableOnce[String] => A
) extends SpecificIterableFactory[String, A]:

def fromSeq(buf: collection.Seq[String]): A =
creator(buf)
def empty: A = fromSeq(Seq.empty)

def newBuilder: mutable.Builder[String, A] =
mutable.ArrayBuffer.newBuilder[String].mapResult(fromSeq)

def fromSpecific(it: IterableOnce[String]): A = it match
case seq: collection.Seq[String] => fromSeq(seq)
case _ => creator(it)

object AllColumns
extends RenderablesFactory[AllColumns](it =>
new AllColumns(IArray.from(it))
)

object InsertColumns
extends RenderablesFactory[InsertColumns](it =>
new InsertColumns(IArray.from(it))
)
15 changes: 15 additions & 0 deletions src/main/scala/com/augustnagro/magnum/Repo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,18 @@ open class Repo[EC, E, ID](using defaults: RepoDefaults[EC, E, ID])
/** Update all entities */
def updateAll(entities: Iterable[E])(using DbCon): BatchUpdateResult =
defaults.updateAll(entities)

def insertColumns: InsertColumns = defaults.insertColumns

object Repo:

trait Identifier

opaque type TableName <: Identifier = String & Identifier
opaque type IdColumn <: Identifier = String & Identifier

object TableName:
private[magnum] def apply(s: String): TableName = s.asInstanceOf[TableName]

object IdColumn:
private[magnum] def apply(s: String): IdColumn = s.asInstanceOf[IdColumn]
7 changes: 6 additions & 1 deletion src/main/scala/com/augustnagro/magnum/RepoDefaults.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ trait RepoDefaults[EC, E, ID]:
def insertAllReturning(entityCreators: Iterable[EC])(using DbCon): Vector[E]
def update(entity: E)(using DbCon): Unit
def updateAll(entities: Iterable[E])(using DbCon): BatchUpdateResult
// metadata to help with query building
def columns: AllColumns
def insertColumns: InsertColumns
def tableName: Repo.TableName
def idColumn: Repo.IdColumn

object RepoDefaults:

Expand Down Expand Up @@ -76,7 +81,7 @@ object RepoDefaults:
}
}) =>
val tableNameSql = sqlTableNameAnnot[E] match {
case Some(sqlName) =>
case Some(sqlName) =>
'{ $sqlName.name }
case None =>
val tableName = Expr(Type.valueOfConstant[eLabel].get.toString)
Expand Down
8 changes: 8 additions & 0 deletions src/main/scala/com/augustnagro/magnum/SqliteDbType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,11 @@ object SqliteDbType extends DbType:
) match
case Success(res) => res
case Failure(t) => throw SqlException(updateSql, entities, t)

def columns: AllColumns = AllColumns.fromSeq(eElemNamesSql)

def insertColumns: InsertColumns = InsertColumns.fromSeq(ecElemNamesSql)

def tableName: Repo.TableName = Repo.TableName(tableNameSql)

def idColumn: Repo.IdColumn = Repo.IdColumn(idName)
33 changes: 24 additions & 9 deletions src/main/scala/com/augustnagro/magnum/util.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,27 @@ private def sqlImpl(sc: Expr[StringContext], args: Expr[Seq[Any]])(using
// val stringExprs: Seq[Expr[String]] = sc match
// case '{ StringContext(${ Varargs(strings) }: _*) } => strings

val paramsExpr = Expr.ofSeq(argsExprs)
val questionVarargs = Varargs(Vector.fill(argsExprs.size)(Expr("?")))
val queryExpr = '{ $sc.s($questionVarargs: _*) }
val interpolatedVarargs = Varargs(argsExprs.map {
case '{ $arg: Renderables[_] } => '{ $arg.toString() }
case '{ $arg: Repo.Identifier } => '{ $arg.toString() }
case '{ $arg: tp } =>
val codecExpr = summonWriter[tp]
'{ $codecExpr.queryRepr }
})

val paramExprs = argsExprs.filter {
case '{ $arg: Renderables[_] } => false
case '{ $arg: Repo.Identifier } => false
case _ => true
}

val queryExpr = '{ $sc.s($interpolatedVarargs: _*) }
val exprParams = Expr.ofSeq(paramExprs)

'{
val argValues = $args
val argValues = $exprParams
val writer: FragWriter = (ps: PreparedStatement, pos: Int) => {
${ sqlWriter('{ ps }, '{ pos }, '{ argValues }, argsExprs, '{ 0 }) }
${ sqlWriter('{ ps }, '{ pos }, '{ argValues }, paramExprs, '{ 0 }) }
}
Frag($queryExpr, argValues, writer)
}
Expand Down Expand Up @@ -108,12 +121,14 @@ private def sqlWriter(
private def summonWriter[T: Type](using Quotes): Expr[DbCodec[T]] =
import quotes.reflect.*

Expr.summon[DbCodec[T]]
Expr
.summon[DbCodec[T]]
.orElse(
TypeRepr.of[T].widen.asType match
case '[tpe] => Expr.summon[DbCodec[tpe]].map(codec =>
'{ $codec.asInstanceOf[DbCodec[T]] }
)
case '[tpe] =>
Expr
.summon[DbCodec[tpe]]
.map(codec => '{ $codec.asInstanceOf[DbCodec[T]] })
)
.getOrElse:
report.info(
Expand Down
35 changes: 32 additions & 3 deletions src/test/scala/ClickHouseTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ class ClickHouseTests extends FunSuite, TestContainersFixtures:
color: Color
) derives DbCodec

val carRepo = ImmutableRepo[Car, UUID]
object carRepo extends ImmutableRepo[Car, UUID]:
def customSelect(speed: Int): Vector[Car] =
val sql = sql"SELECT ${*} FROM $table WHERE top_speed > $speed"
connect(ds()):
sql.query[Car].run()
AugustNagro marked this conversation as resolved.
Show resolved Hide resolved

val allCars = Vector(
Car(
Expand Down Expand Up @@ -122,7 +126,8 @@ class ClickHouseTests extends FunSuite, TestContainersFixtures:
val vin = Some(124)
val cars =
sql"select * from car where vin = $vin"
.query[Car].run()
.query[Car]
.run()
assertEquals(cars, allCars.filter(_.vinNumber == vin))

test("tuple select"):
Expand All @@ -136,6 +141,10 @@ class ClickHouseTests extends FunSuite, TestContainersFixtures:
connect(ds()):
assertEquals(carRepo.findAll.last.vinNumber, None)

test("custom select using identifiers"):
connect(ds()):
assertEquals(carRepo.customSelect(211), Vector(allCars(1)))

/*
Repo Tests
*/
Expand All @@ -149,7 +158,12 @@ class ClickHouseTests extends FunSuite, TestContainersFixtures:
created: OffsetDateTime
) derives DbCodec

val personRepo = Repo[Person, Person, UUID]
object personRepo extends Repo[Person, Person, UUID]:
def customInsert(p: Person): Unit =
val sql =
sql"INSERT INTO $table $insertColumns VALUES ($p)"
connect(ds()):
sql.update.run()
AugustNagro marked this conversation as resolved.
Show resolved Hide resolved

test("delete"):
connect(ds()):
Expand Down Expand Up @@ -322,6 +336,21 @@ class ClickHouseTests extends FunSuite, TestContainersFixtures:
val count = transact(ds())(personRepo.count)
assertEquals(count, 8L)

test("custom insert"):
val person = Person(
id = UUID.randomUUID,
firstName = Some("John"),
lastName = "Smith",
isAdmin = false,
created = OffsetDateTime.of(2020, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC)
)

connect(ds()):
personRepo.customInsert(person)

assertEquals(personRepo.count, 9L)
assertEquals(personRepo.findById(person.id).get, person)

val clickHouseContainer = ForAllContainerFixture(
ClickHouseContainer
.Def(dockerImageName =
Expand Down
Loading