Skip to content

Commit

Permalink
Add trig functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrix committed May 26, 2024
1 parent af14b6f commit a78af0c
Show file tree
Hide file tree
Showing 19 changed files with 225 additions and 101 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/makesite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ jobs:
env:
REPO: self
BRANCH: gh-pages
FOLDER: docs/target/scala-3.3.1/unidoc
FOLDER: docs/target/scala-3.3.2/unidoc
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package dataprism.platform.sql.implementations

import dataprism.platform.sql.value.SqlBitwiseOps
import dataprism.platform.sql.value.{SqlBitwiseOps, SqlHyperbolicTrigFunctions, SqlTrigFunctions}
import dataprism.platform.sql.{DefaultCompleteSql, DefaultSqlOperations}
import dataprism.sharedast.H2AstRenderer

trait H2Platform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps {
trait H2Platform
extends DefaultCompleteSql,
DefaultSqlOperations,
SqlBitwiseOps,
SqlTrigFunctions,
SqlHyperbolicTrigFunctions {
platform =>

override type CastType[A] = Type[A]
Expand Down Expand Up @@ -35,12 +40,12 @@ trait H2Platform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps
f: MapUpdateReturning[Table, From, Res]
): (Table, From) => Res = f

given bitwiseByte: SqlBitwise[Byte] = SqlBitwise.defaultInstance
given bitwiseOptByte: SqlBitwise[Option[Byte]] = SqlBitwise.defaultInstance
given bitwiseShort: SqlBitwise[Short] = SqlBitwise.defaultInstance
given bitwiseByte: SqlBitwise[Byte] = SqlBitwise.defaultInstance
given bitwiseOptByte: SqlBitwise[Option[Byte]] = SqlBitwise.defaultInstance
given bitwiseShort: SqlBitwise[Short] = SqlBitwise.defaultInstance
given bitwiseOptShort: SqlBitwise[Option[Short]] = SqlBitwise.defaultInstance
given bitwiseInt: SqlBitwise[Int] = SqlBitwise.defaultInstance
given bitwiseOptInt: SqlBitwise[Option[Int]] = SqlBitwise.defaultInstance
given bitwiseInt: SqlBitwise[Int] = SqlBitwise.defaultInstance
given bitwiseOptInt: SqlBitwise[Option[Int]] = SqlBitwise.defaultInstance

type Api <: H2Api

Expand All @@ -51,8 +56,8 @@ trait H2Platform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps
lazy val sqlRenderer: H2AstRenderer[Codec] =
new H2AstRenderer[Codec](AnsiTypes, [A] => (codec: Codec[A]) => codec.name)

type DbMath = SimpleSqlDbMath
object DbMath extends SimpleSqlDbMath
type DbMath = SimpleSqlDbMath & SqlTrigMath & SqlHyperbolicTrigMath
object DbMath extends SimpleSqlDbMath, SqlTrigMath, SqlHyperbolicTrigMath

type DbValue[A] = SqlDbValue[A]
override protected def sqlDbValueLift[A]: Lift[SqlDbValue[A], DbValue[A]] = Lift.subtype
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package dataprism.platform.sql.implementations

import dataprism.platform.sql.value.SqlTrigFunctions
import dataprism.platform.sql.{DefaultCompleteSql, DefaultSqlOperations}

trait MySqlPlatform extends DefaultCompleteSql, DefaultSqlOperations { platform =>
trait MySqlPlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlTrigFunctions { platform =>

override type InFilterCapability = Unit
override type InMapCapability = Unit
Expand Down Expand Up @@ -33,8 +34,8 @@ trait MySqlPlatform extends DefaultCompleteSql, DefaultSqlOperations { platform
export platform.{given DeleteUsingCapability, given LateralJoinCapability}
}

type DbMath = SimpleSqlDbMath
object DbMath extends SimpleSqlDbMath
type DbMath = SimpleSqlDbMath & SqlTrigMath
object DbMath extends SimpleSqlDbMath, SqlTrigMath

type DbValue[A] = SqlDbValue[A]
override protected def sqlDbValueLift[A]: Lift[SqlDbValue[A], DbValue[A]] = Lift.subtype
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package dataprism.platform.sql.implementations

import cats.syntax.all.*
import dataprism.platform.sql.value.SqlBitwiseOps
import dataprism.platform.sql.value.{SqlBitwiseOps, SqlHyperbolicTrigFunctions, SqlTrigFunctions}
import dataprism.platform.sql.{DefaultCompleteSql, DefaultSqlOperations}
import dataprism.sharedast.{PostgresAstRenderer, SqlExpr}
import dataprism.sql.*

//noinspection SqlNoDataSourceInspection, ScalaUnusedSymbol
trait PostgresPlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps { platform =>
trait PostgresPlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps, SqlTrigFunctions, SqlHyperbolicTrigFunctions { platform =>

override type InFilterCapability = Unit
override type InMapCapability = Unit
Expand Down Expand Up @@ -36,6 +36,10 @@ trait PostgresPlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitw

given ExceptCapability with {}
given IntersectCapability with {}

given ASinhCapability with {}
given ACoshCapability with {}
given ATanhCapability with {}

override type MapUpdateReturning[Table, From, Res] = (Table, From) => Res
override protected def contramapUpdateReturning[Table, From, Res](
Expand Down Expand Up @@ -65,8 +69,8 @@ trait PostgresPlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitw
override def castTypeName: String = t.name
override def castTypeType: Type[A] = t

type DbMath = SimpleSqlDbMath
object DbMath extends SimpleSqlDbMath
type DbMath = SimpleSqlDbMath & SqlTrigMath & SqlHyperbolicTrigMath
object DbMath extends SimpleSqlDbMath, SqlTrigMath, SqlHyperbolicTrigMath

type DbValue[A] = SqlDbValue[A]
override protected def sqlDbValueLift[A]: Lift[SqlDbValue[A], DbValue[A]] = Lift.subtype
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package dataprism.platform.sql.implementations

import dataprism.platform.sql.value.SqlBitwiseOps
import dataprism.platform.sql.value.{SqlBitwiseOps, SqlHyperbolicTrigFunctions, SqlTrigFunctions}
import dataprism.platform.sql.{DefaultCompleteSql, DefaultSqlOperations}
import dataprism.sharedast.{SqlExpr, SqliteAstRenderer}

trait SqlitePlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps { platform =>
trait SqlitePlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps, SqlTrigFunctions, SqlHyperbolicTrigFunctions { platform =>

override type CastType[A] = Type[A]

Expand Down Expand Up @@ -37,6 +37,10 @@ trait SqlitePlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwis
given ExceptCapability with {}
given IntersectCapability with {}

given ASinhCapability with {}
given ACoshCapability with {}
given ATanhCapability with {}

override protected def generateDeleteAlias: Boolean = false
override protected def generateUpdateAlias: Boolean = false

Expand All @@ -60,8 +64,8 @@ trait SqlitePlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwis
lazy val sqlRenderer: SqliteAstRenderer[Codec] =
new SqliteAstRenderer[Codec](AnsiTypes, [A] => (codec: Codec[A]) => codec.name)

type DbMath = SimpleSqlDbMath
object DbMath extends SimpleSqlDbMath
type DbMath = SimpleSqlDbMath & SqlTrigMath & SqlHyperbolicTrigMath
object DbMath extends SimpleSqlDbMath, SqlTrigMath, SqlHyperbolicTrigMath

type DbValue[A] = SqlDbValue[A]
override protected def sqlDbValueLift[A]: Lift[SqlDbValue[A], DbValue[A]] = Lift.subtype
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package dataprism.platform.sql.value

import dataprism.sharedast.SqlExpr

trait SqlHyperbolicTrigFunctions extends SqlDbValuesBase {

type DbMath <: SqlHyperbolicTrigMath

trait ASinhCapability
trait ACoshCapability
trait ATanhCapability

trait SqlHyperbolicTrigMath:
def acosh(v: DbValue[Double])(using ACoshCapability): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.ACosh, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)

def asinh(v: DbValue[Double])(using ASinhCapability): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.ASinh, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)

def atanh(v: DbValue[Double])(using ATanhCapability): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.ATanh, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)

def cosh(v: DbValue[Double]): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.Cosh, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)

def sinh(v: DbValue[Double]): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.Sinh, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)

def tanh(v: DbValue[Double]): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.Tanh, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package dataprism.platform.sql.value

import dataprism.sharedast.SqlExpr

trait SqlTrigFunctions extends SqlDbValuesBase {

type DbMath <: SqlTrigMath
trait SqlTrigMath:
def acos(v: DbValue[Double]): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.ACos, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)
def asin(v: DbValue[Double]): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.ASin, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)
def atan(v: DbValue[Double]): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.ATan, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)
def atan2(v1: DbValue[Double], v2: DbValue[Double]): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.ATan2, Seq(v1.unsafeAsAnyDbVal, v2.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)
def cos(v: DbValue[Double]): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.Cos, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)
def cot(v: DbValue[Double]): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.Cot, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)
def sin(v: DbValue[Double]): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.Sin, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)
def tan(v: DbValue[Double]): DbValue[Double] =
Impl.function(SqlExpr.FunctionName.Tan, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision)
}
7 changes: 7 additions & 0 deletions common/src/main/scala/dataprism/sharedast/AstRenderer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ class AstRenderer[Codec[_]](ansiTypes: AnsiTypes[Codec], getCodecTypeName: [A] =
case SqlExpr.FunctionName.Sin => normal("sin")
case SqlExpr.FunctionName.Tan => normal("tan")

case SqlExpr.FunctionName.ACosh => normal("acosh")
case SqlExpr.FunctionName.ASinh => normal("asinh")
case SqlExpr.FunctionName.ATanh => normal("atanh")
case SqlExpr.FunctionName.Cosh => normal("cosh")
case SqlExpr.FunctionName.Sinh => normal("sinh")
case SqlExpr.FunctionName.Tanh => normal("tanh")

case SqlExpr.FunctionName.Abs => normal("abs")
case SqlExpr.FunctionName.Avg => normal("avg")
case SqlExpr.FunctionName.Count => normal("count")
Expand Down
7 changes: 7 additions & 0 deletions common/src/main/scala/dataprism/sharedast/SharedSqlAst.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ object SqlExpr {
case Sin
case Tan

case ACosh
case ASinh
case ATanh
case Cosh
case Sinh
case Tanh

case Greatest
case Least

Expand Down
21 changes: 18 additions & 3 deletions common/src/main/scala/dataprism/sharedast/SqliteAstRenderer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ class SqliteAstRenderer[Codec[_]](ansiTypes: AnsiTypes[Codec], getCodecTypeName:
case SqlExpr.BinaryOperation.Concat => sql"(${renderExpr(lhs)} || ${renderExpr(rhs)})"
case _ => super.renderBinaryOp(lhs, rhs, op, tpe)

override protected def renderFunctionCall(call: SqlExpr.FunctionName, args: Seq[SqlExpr[Codec]], tpe: String): SqlStr[Codec] =
override protected def renderFunctionCall(
call: SqlExpr.FunctionName,
args: Seq[SqlExpr[Codec]],
tpe: String
): SqlStr[Codec] =
inline def rendered = args.map(renderExpr).intercalate(sql", ")
inline def normal(f: String): SqlStr[Codec] = sql"${SqlStr.const(f)}($rendered)"

Expand All @@ -28,7 +32,8 @@ class SqliteAstRenderer[Codec[_]](ansiTypes: AnsiTypes[Codec], getCodecTypeName:
val slidingArgs = args.sliding(99).toSeq
val (handled, notHandleds) = slidingArgs.splitAt(99)
val handledArgs = handled.map(args => SqlExpr.FunctionCall(call, args, tpe))
val notHandledArg = if notHandleds.nonEmpty then Seq(SqlExpr.FunctionCall(call, notHandleds.flatten, tpe)) else Nil
val notHandledArg =
if notHandleds.nonEmpty then Seq(SqlExpr.FunctionCall(call, notHandleds.flatten, tpe)) else Nil

val allArgs = handledArgs ++ notHandledArg

Expand All @@ -40,7 +45,17 @@ class SqliteAstRenderer[Codec[_]](ansiTypes: AnsiTypes[Codec], getCodecTypeName:
case SqlExpr.FunctionName.Greatest => subdivided("max")
case SqlExpr.FunctionName.Min => normal("min")
case SqlExpr.FunctionName.Least => subdivided("min")
case _ => super.renderFunctionCall(call, args, tpe)

case SqlExpr.FunctionName.Cot =>
renderExpr(
SqlExpr.BinOp(
SqlExpr.FunctionCall(SqlExpr.FunctionName.Cos, args, tpe),
SqlExpr.FunctionCall(SqlExpr.FunctionName.Sin, args, tpe),
SqlExpr.BinaryOperation.Divide,
tpe
)
)
case _ => super.renderFunctionCall(call, args, tpe)

override protected def renderFrom(from: SelectAst.From[Codec]): SqlStr[Codec] = from match
case SelectAst.From.FromQuery(_, _, true) => throw new SQLException("H2 does not support lateral")
Expand Down
Loading

0 comments on commit a78af0c

Please sign in to comment.