From a78af0cf89a6804e38ffe191e1778dd86b1f7394 Mon Sep 17 00:00:00 2001 From: Katrix Date: Sun, 26 May 2024 21:32:14 +0200 Subject: [PATCH] Add trig functions --- .github/workflows/makesite.yml | 2 +- .../sql/implementations/H2Platform.scala | 23 +++-- .../sql/implementations/MySqlPlatform.scala | 7 +- .../implementations/PostgresPlatform.scala | 12 ++- .../sql/implementations/SqlitePlatform.scala | 12 ++- .../value/SqlHyperbolicTrigFunctions.scala | 31 +++++++ .../platform/sql/value/SqlTrigFunctions.scala | 25 +++++ .../dataprism/sharedast/AstRenderer.scala | 7 ++ .../dataprism/sharedast/SharedSqlAst.scala | 7 ++ .../sharedast/SqliteAstRenderer.scala | 21 ++++- .../scala/dataprism/PlatformMathSuite.scala | 92 ++++++++++--------- .../scala/dataprism/jdbc/h2/H2MathSuite.scala | 7 ++ .../jdbc/mariadb/MariaDbMathSuite.scala | 10 +- .../jdbc/mysql57/MySql57MathSuite.scala | 11 +-- .../jdbc/mysql8/MySql8MathSuite.scala | 10 +- .../jdbc/postgres/PostgresMathSuite.scala | 11 +++ .../jdbc/sqlite/SqliteFunSuite.scala | 15 +-- .../jdbc/sqlite/SqliteMathSuite.scala | 12 ++- .../skunk/testsuites/PostgresMathSuite.scala | 11 +++ 19 files changed, 225 insertions(+), 101 deletions(-) create mode 100644 common/src/main/scala/dataprism/platform/sql/value/SqlHyperbolicTrigFunctions.scala create mode 100644 common/src/main/scala/dataprism/platform/sql/value/SqlTrigFunctions.scala diff --git a/.github/workflows/makesite.yml b/.github/workflows/makesite.yml index 629b4a3e..cafdd512 100644 --- a/.github/workflows/makesite.yml +++ b/.github/workflows/makesite.yml @@ -25,5 +25,5 @@ jobs: env: REPO: self BRANCH: gh-pages - FOLDER: docs/target/scala-3.3.1/unidoc + FOLDER: docs/target/scala-3.3.2/unidoc GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/common/src/main/scala/dataprism/platform/sql/implementations/H2Platform.scala b/common/src/main/scala/dataprism/platform/sql/implementations/H2Platform.scala index c4600848..4a1f0d90 100644 --- a/common/src/main/scala/dataprism/platform/sql/implementations/H2Platform.scala +++ b/common/src/main/scala/dataprism/platform/sql/implementations/H2Platform.scala @@ -1,10 +1,15 @@ package dataprism.platform.sql.implementations -import dataprism.platform.sql.value.SqlBitwiseOps +import dataprism.platform.sql.value.{SqlBitwiseOps, SqlHyperbolicTrigFunctions, SqlTrigFunctions} import dataprism.platform.sql.{DefaultCompleteSql, DefaultSqlOperations} import dataprism.sharedast.H2AstRenderer -trait H2Platform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps { +trait H2Platform + extends DefaultCompleteSql, + DefaultSqlOperations, + SqlBitwiseOps, + SqlTrigFunctions, + SqlHyperbolicTrigFunctions { platform => override type CastType[A] = Type[A] @@ -35,12 +40,12 @@ trait H2Platform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps f: MapUpdateReturning[Table, From, Res] ): (Table, From) => Res = f - given bitwiseByte: SqlBitwise[Byte] = SqlBitwise.defaultInstance - given bitwiseOptByte: SqlBitwise[Option[Byte]] = SqlBitwise.defaultInstance - given bitwiseShort: SqlBitwise[Short] = SqlBitwise.defaultInstance + given bitwiseByte: SqlBitwise[Byte] = SqlBitwise.defaultInstance + given bitwiseOptByte: SqlBitwise[Option[Byte]] = SqlBitwise.defaultInstance + given bitwiseShort: SqlBitwise[Short] = SqlBitwise.defaultInstance given bitwiseOptShort: SqlBitwise[Option[Short]] = SqlBitwise.defaultInstance - given bitwiseInt: SqlBitwise[Int] = SqlBitwise.defaultInstance - given bitwiseOptInt: SqlBitwise[Option[Int]] = SqlBitwise.defaultInstance + given bitwiseInt: SqlBitwise[Int] = SqlBitwise.defaultInstance + given bitwiseOptInt: SqlBitwise[Option[Int]] = SqlBitwise.defaultInstance type Api <: H2Api @@ -51,8 +56,8 @@ trait H2Platform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps lazy val sqlRenderer: H2AstRenderer[Codec] = new H2AstRenderer[Codec](AnsiTypes, [A] => (codec: Codec[A]) => codec.name) - type DbMath = SimpleSqlDbMath - object DbMath extends SimpleSqlDbMath + type DbMath = SimpleSqlDbMath & SqlTrigMath & SqlHyperbolicTrigMath + object DbMath extends SimpleSqlDbMath, SqlTrigMath, SqlHyperbolicTrigMath type DbValue[A] = SqlDbValue[A] override protected def sqlDbValueLift[A]: Lift[SqlDbValue[A], DbValue[A]] = Lift.subtype diff --git a/common/src/main/scala/dataprism/platform/sql/implementations/MySqlPlatform.scala b/common/src/main/scala/dataprism/platform/sql/implementations/MySqlPlatform.scala index 3b8c2ec0..619c1b17 100644 --- a/common/src/main/scala/dataprism/platform/sql/implementations/MySqlPlatform.scala +++ b/common/src/main/scala/dataprism/platform/sql/implementations/MySqlPlatform.scala @@ -1,8 +1,9 @@ package dataprism.platform.sql.implementations +import dataprism.platform.sql.value.SqlTrigFunctions import dataprism.platform.sql.{DefaultCompleteSql, DefaultSqlOperations} -trait MySqlPlatform extends DefaultCompleteSql, DefaultSqlOperations { platform => +trait MySqlPlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlTrigFunctions { platform => override type InFilterCapability = Unit override type InMapCapability = Unit @@ -33,8 +34,8 @@ trait MySqlPlatform extends DefaultCompleteSql, DefaultSqlOperations { platform export platform.{given DeleteUsingCapability, given LateralJoinCapability} } - type DbMath = SimpleSqlDbMath - object DbMath extends SimpleSqlDbMath + type DbMath = SimpleSqlDbMath & SqlTrigMath + object DbMath extends SimpleSqlDbMath, SqlTrigMath type DbValue[A] = SqlDbValue[A] override protected def sqlDbValueLift[A]: Lift[SqlDbValue[A], DbValue[A]] = Lift.subtype diff --git a/common/src/main/scala/dataprism/platform/sql/implementations/PostgresPlatform.scala b/common/src/main/scala/dataprism/platform/sql/implementations/PostgresPlatform.scala index e07b683f..2ba6b38e 100644 --- a/common/src/main/scala/dataprism/platform/sql/implementations/PostgresPlatform.scala +++ b/common/src/main/scala/dataprism/platform/sql/implementations/PostgresPlatform.scala @@ -1,13 +1,13 @@ package dataprism.platform.sql.implementations import cats.syntax.all.* -import dataprism.platform.sql.value.SqlBitwiseOps +import dataprism.platform.sql.value.{SqlBitwiseOps, SqlHyperbolicTrigFunctions, SqlTrigFunctions} import dataprism.platform.sql.{DefaultCompleteSql, DefaultSqlOperations} import dataprism.sharedast.{PostgresAstRenderer, SqlExpr} import dataprism.sql.* //noinspection SqlNoDataSourceInspection, ScalaUnusedSymbol -trait PostgresPlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps { platform => +trait PostgresPlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps, SqlTrigFunctions, SqlHyperbolicTrigFunctions { platform => override type InFilterCapability = Unit override type InMapCapability = Unit @@ -36,6 +36,10 @@ trait PostgresPlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitw given ExceptCapability with {} given IntersectCapability with {} + + given ASinhCapability with {} + given ACoshCapability with {} + given ATanhCapability with {} override type MapUpdateReturning[Table, From, Res] = (Table, From) => Res override protected def contramapUpdateReturning[Table, From, Res]( @@ -65,8 +69,8 @@ trait PostgresPlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitw override def castTypeName: String = t.name override def castTypeType: Type[A] = t - type DbMath = SimpleSqlDbMath - object DbMath extends SimpleSqlDbMath + type DbMath = SimpleSqlDbMath & SqlTrigMath & SqlHyperbolicTrigMath + object DbMath extends SimpleSqlDbMath, SqlTrigMath, SqlHyperbolicTrigMath type DbValue[A] = SqlDbValue[A] override protected def sqlDbValueLift[A]: Lift[SqlDbValue[A], DbValue[A]] = Lift.subtype diff --git a/common/src/main/scala/dataprism/platform/sql/implementations/SqlitePlatform.scala b/common/src/main/scala/dataprism/platform/sql/implementations/SqlitePlatform.scala index b8e70136..2719b704 100644 --- a/common/src/main/scala/dataprism/platform/sql/implementations/SqlitePlatform.scala +++ b/common/src/main/scala/dataprism/platform/sql/implementations/SqlitePlatform.scala @@ -1,10 +1,10 @@ package dataprism.platform.sql.implementations -import dataprism.platform.sql.value.SqlBitwiseOps +import dataprism.platform.sql.value.{SqlBitwiseOps, SqlHyperbolicTrigFunctions, SqlTrigFunctions} import dataprism.platform.sql.{DefaultCompleteSql, DefaultSqlOperations} import dataprism.sharedast.{SqlExpr, SqliteAstRenderer} -trait SqlitePlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps { platform => +trait SqlitePlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwiseOps, SqlTrigFunctions, SqlHyperbolicTrigFunctions { platform => override type CastType[A] = Type[A] @@ -37,6 +37,10 @@ trait SqlitePlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwis given ExceptCapability with {} given IntersectCapability with {} + given ASinhCapability with {} + given ACoshCapability with {} + given ATanhCapability with {} + override protected def generateDeleteAlias: Boolean = false override protected def generateUpdateAlias: Boolean = false @@ -60,8 +64,8 @@ trait SqlitePlatform extends DefaultCompleteSql, DefaultSqlOperations, SqlBitwis lazy val sqlRenderer: SqliteAstRenderer[Codec] = new SqliteAstRenderer[Codec](AnsiTypes, [A] => (codec: Codec[A]) => codec.name) - type DbMath = SimpleSqlDbMath - object DbMath extends SimpleSqlDbMath + type DbMath = SimpleSqlDbMath & SqlTrigMath & SqlHyperbolicTrigMath + object DbMath extends SimpleSqlDbMath, SqlTrigMath, SqlHyperbolicTrigMath type DbValue[A] = SqlDbValue[A] override protected def sqlDbValueLift[A]: Lift[SqlDbValue[A], DbValue[A]] = Lift.subtype diff --git a/common/src/main/scala/dataprism/platform/sql/value/SqlHyperbolicTrigFunctions.scala b/common/src/main/scala/dataprism/platform/sql/value/SqlHyperbolicTrigFunctions.scala new file mode 100644 index 00000000..f9623be1 --- /dev/null +++ b/common/src/main/scala/dataprism/platform/sql/value/SqlHyperbolicTrigFunctions.scala @@ -0,0 +1,31 @@ +package dataprism.platform.sql.value + +import dataprism.sharedast.SqlExpr + +trait SqlHyperbolicTrigFunctions extends SqlDbValuesBase { + + type DbMath <: SqlHyperbolicTrigMath + + trait ASinhCapability + trait ACoshCapability + trait ATanhCapability + + trait SqlHyperbolicTrigMath: + def acosh(v: DbValue[Double])(using ACoshCapability): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.ACosh, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) + + def asinh(v: DbValue[Double])(using ASinhCapability): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.ASinh, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) + + def atanh(v: DbValue[Double])(using ATanhCapability): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.ATanh, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) + + def cosh(v: DbValue[Double]): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.Cosh, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) + + def sinh(v: DbValue[Double]): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.Sinh, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) + + def tanh(v: DbValue[Double]): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.Tanh, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) +} diff --git a/common/src/main/scala/dataprism/platform/sql/value/SqlTrigFunctions.scala b/common/src/main/scala/dataprism/platform/sql/value/SqlTrigFunctions.scala new file mode 100644 index 00000000..26e1cbb6 --- /dev/null +++ b/common/src/main/scala/dataprism/platform/sql/value/SqlTrigFunctions.scala @@ -0,0 +1,25 @@ +package dataprism.platform.sql.value + +import dataprism.sharedast.SqlExpr + +trait SqlTrigFunctions extends SqlDbValuesBase { + + type DbMath <: SqlTrigMath + trait SqlTrigMath: + def acos(v: DbValue[Double]): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.ACos, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) + def asin(v: DbValue[Double]): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.ASin, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) + def atan(v: DbValue[Double]): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.ATan, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) + def atan2(v1: DbValue[Double], v2: DbValue[Double]): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.ATan2, Seq(v1.unsafeAsAnyDbVal, v2.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) + def cos(v: DbValue[Double]): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.Cos, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) + def cot(v: DbValue[Double]): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.Cot, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) + def sin(v: DbValue[Double]): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.Sin, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) + def tan(v: DbValue[Double]): DbValue[Double] = + Impl.function(SqlExpr.FunctionName.Tan, Seq(v.unsafeAsAnyDbVal), AnsiTypes.doublePrecision) +} diff --git a/common/src/main/scala/dataprism/sharedast/AstRenderer.scala b/common/src/main/scala/dataprism/sharedast/AstRenderer.scala index 1c833d8d..06b5317a 100644 --- a/common/src/main/scala/dataprism/sharedast/AstRenderer.scala +++ b/common/src/main/scala/dataprism/sharedast/AstRenderer.scala @@ -89,6 +89,13 @@ class AstRenderer[Codec[_]](ansiTypes: AnsiTypes[Codec], getCodecTypeName: [A] = case SqlExpr.FunctionName.Sin => normal("sin") case SqlExpr.FunctionName.Tan => normal("tan") + case SqlExpr.FunctionName.ACosh => normal("acosh") + case SqlExpr.FunctionName.ASinh => normal("asinh") + case SqlExpr.FunctionName.ATanh => normal("atanh") + case SqlExpr.FunctionName.Cosh => normal("cosh") + case SqlExpr.FunctionName.Sinh => normal("sinh") + case SqlExpr.FunctionName.Tanh => normal("tanh") + case SqlExpr.FunctionName.Abs => normal("abs") case SqlExpr.FunctionName.Avg => normal("avg") case SqlExpr.FunctionName.Count => normal("count") diff --git a/common/src/main/scala/dataprism/sharedast/SharedSqlAst.scala b/common/src/main/scala/dataprism/sharedast/SharedSqlAst.scala index 2c50eb7e..5c385121 100644 --- a/common/src/main/scala/dataprism/sharedast/SharedSqlAst.scala +++ b/common/src/main/scala/dataprism/sharedast/SharedSqlAst.scala @@ -80,6 +80,13 @@ object SqlExpr { case Sin case Tan + case ACosh + case ASinh + case ATanh + case Cosh + case Sinh + case Tanh + case Greatest case Least diff --git a/common/src/main/scala/dataprism/sharedast/SqliteAstRenderer.scala b/common/src/main/scala/dataprism/sharedast/SqliteAstRenderer.scala index 782ed7aa..a5c092fa 100644 --- a/common/src/main/scala/dataprism/sharedast/SqliteAstRenderer.scala +++ b/common/src/main/scala/dataprism/sharedast/SqliteAstRenderer.scala @@ -19,7 +19,11 @@ class SqliteAstRenderer[Codec[_]](ansiTypes: AnsiTypes[Codec], getCodecTypeName: case SqlExpr.BinaryOperation.Concat => sql"(${renderExpr(lhs)} || ${renderExpr(rhs)})" case _ => super.renderBinaryOp(lhs, rhs, op, tpe) - override protected def renderFunctionCall(call: SqlExpr.FunctionName, args: Seq[SqlExpr[Codec]], tpe: String): SqlStr[Codec] = + override protected def renderFunctionCall( + call: SqlExpr.FunctionName, + args: Seq[SqlExpr[Codec]], + tpe: String + ): SqlStr[Codec] = inline def rendered = args.map(renderExpr).intercalate(sql", ") inline def normal(f: String): SqlStr[Codec] = sql"${SqlStr.const(f)}($rendered)" @@ -28,7 +32,8 @@ class SqliteAstRenderer[Codec[_]](ansiTypes: AnsiTypes[Codec], getCodecTypeName: val slidingArgs = args.sliding(99).toSeq val (handled, notHandleds) = slidingArgs.splitAt(99) val handledArgs = handled.map(args => SqlExpr.FunctionCall(call, args, tpe)) - val notHandledArg = if notHandleds.nonEmpty then Seq(SqlExpr.FunctionCall(call, notHandleds.flatten, tpe)) else Nil + val notHandledArg = + if notHandleds.nonEmpty then Seq(SqlExpr.FunctionCall(call, notHandleds.flatten, tpe)) else Nil val allArgs = handledArgs ++ notHandledArg @@ -40,7 +45,17 @@ class SqliteAstRenderer[Codec[_]](ansiTypes: AnsiTypes[Codec], getCodecTypeName: case SqlExpr.FunctionName.Greatest => subdivided("max") case SqlExpr.FunctionName.Min => normal("min") case SqlExpr.FunctionName.Least => subdivided("min") - case _ => super.renderFunctionCall(call, args, tpe) + + case SqlExpr.FunctionName.Cot => + renderExpr( + SqlExpr.BinOp( + SqlExpr.FunctionCall(SqlExpr.FunctionName.Cos, args, tpe), + SqlExpr.FunctionCall(SqlExpr.FunctionName.Sin, args, tpe), + SqlExpr.BinaryOperation.Divide, + tpe + ) + ) + case _ => super.renderFunctionCall(call, args, tpe) override protected def renderFrom(from: SelectAst.From[Codec]): SqlStr[Codec] = from match case SelectAst.From.FromQuery(_, _, true) => throw new SQLException("H2 does not support lateral") diff --git a/common/src/test/scala/dataprism/PlatformMathSuite.scala b/common/src/test/scala/dataprism/PlatformMathSuite.scala index 0ec3a07e..b86e6f91 100644 --- a/common/src/test/scala/dataprism/PlatformMathSuite.scala +++ b/common/src/test/scala/dataprism/PlatformMathSuite.scala @@ -4,12 +4,14 @@ import cats.Show import cats.effect.IO import cats.syntax.all.* import dataprism.platform.sql.SqlQueryPlatform +import dataprism.platform.sql.value.SqlTrigFunctions import org.scalacheck.Gen import org.scalacheck.cats.implicits.* import weaver.{Expectations, Log, SourceLocation} trait PlatformMathSuite[Codec0[_], Platform <: SqlQueryPlatform { type Codec[A] = Codec0[A] }] - extends PlatformFunSuite[Codec0, Platform], WithOptionalMath: + extends PlatformFunSuite[Codec0, Platform], + WithOptionalMath: import platform.Api.* import platform.{AnsiTypes, name} @@ -29,8 +31,7 @@ trait PlatformMathSuite[Codec0[_], Platform <: SqlQueryPlatform { type Codec[A] case class TypeInfo[A]( tpe: Type[A], gen: Gen[A], - epsilon: A, - isNan: A => Boolean, + areEqual: (A, A) => Expectations, transform2: Gen[(A, A)] => Gen[(A, A)] = identity[Gen[(A, A)]] )(using val numeric: Numeric[A], val show: Show[A]) @@ -40,50 +41,37 @@ trait PlatformMathSuite[Codec0[_], Platform <: SqlQueryPlatform { type Codec[A] name: String, dbLevel: CastType[A] => DbValue[A], valueLevel: () => A - ): Unit = typeLogTest(name, castType.castTypeType): log => + )(using SourceLocation): Unit = typeLogTest(name, castType.castTypeType): log => given Log[IO] = log - import typeInfo.given - import Numeric.Implicits.* - import Ordering.Implicits.* logQuery(Select(Query.of(dbLevel(castType)))) .flatMap(_.runOne[F]) - .map: r => - val v = valueLevel() - expect((v - r <= typeInfo.epsilon) || (typeInfo.isNan(v) && typeInfo.isNan(r))) + .map(r => typeInfo.areEqual(valueLevel(), r)) def functionTest1[A: SqlNumeric]( name: String, typeInfo: TypeInfo[A], dbLevel: DbValue[A] => DbValue[A], valueLevel: A => A - ): Unit = typeLogTest(name, typeInfo.tpe): log => + )(using SourceLocation): Unit = typeLogTest(name, typeInfo.tpe): log => given Log[IO] = log import typeInfo.given - import Numeric.Implicits.* - import Ordering.Implicits.* configuredForall(typeInfo.gen): (a: A) => logQuery(Select(Query.of(dbLevel(a.as(typeInfo.tpe))))) .flatMap(_.runOne[F]) - .map: r => - val v = valueLevel(a) - expect((v - r <= typeInfo.epsilon) || (typeInfo.isNan(v) && typeInfo.isNan(r))) + .map(r => typeInfo.areEqual(valueLevel(a), r)) def functionTest2[A: SqlNumeric]( name: String, typeInfo: TypeInfo[A], dbLevel: (DbValue[A], DbValue[A]) => DbValue[A], valueLevel: (A, A) => A - ): Unit = typeLogTest(name, typeInfo.tpe): log => + )(using SourceLocation): Unit = typeLogTest(name, typeInfo.tpe): log => given Log[IO] = log import typeInfo.given - import Numeric.Implicits.* - import Ordering.Implicits.* configuredForall(typeInfo.transform2((typeInfo.gen, typeInfo.gen).tupled)): (a: A, b: A) => logQuery(Select(Query.of(dbLevel(a.as(typeInfo.tpe), b.as(typeInfo.tpe))))) .flatMap(_.runOne[F]) - .map: r => - val v = valueLevel(a, b) - expect((v - r <= typeInfo.epsilon) || (typeInfo.isNan(v) && typeInfo.isNan(r))) + .map(r => typeInfo.areEqual(valueLevel(a, b), r)) def testPow[A: SqlNumeric](typeInfo: TypeInfo[A], op: (A, A) => A): Unit = functionTest2("pow", typeInfo, DbMath.pow, op) @@ -133,13 +121,43 @@ trait PlatformMathSuite[Codec0[_], Platform <: SqlQueryPlatform { type Codec[A] .runOne[F] .map(r => expect(true)) - private val longGen = Gen.choose(-10000L, 10000L) + private def doubleGen = Gen.choose(-10000D, 10000D) + + def doubleTypeInfo: TypeInfo[Double] = TypeInfo( + AnsiTypes.doublePrecision, + Gen.choose(-10000D, 10000D), + (v: Double, r: Double) => { + val epsilon = 0.00001 + import java.lang.Double as JDouble + + expect( + Math.abs(v - r) <= epsilon + || (JDouble.isNaN(v) && JDouble.isNaN(r)) + || (JDouble.isInfinite(v) && JDouble.isInfinite(r) && Math.signum(v) == Math.signum(r)) + || (Math.getExponent(v) == Math.getExponent(r) + && (v / Math.pow(2, Math.getExponent(v))) - (r / Math.pow(2, Math.getExponent(r))) <= epsilon) //Compare the exponents and mantissa seperately. Sadly it seems like we do sometimes reach this case + ) + } + ) + + def doublePositiveTypeInfo: TypeInfo[Double] = doubleTypeInfo.copy(gen = Gen.choose(0, 10000D)) + + def testTrigFunctions[A](platformWithTrig: platform.type & SqlTrigFunctions): Unit = + functionTest1("acos", doubleTypeInfo.copy(gen = Gen.choose(-1, 1)), platformWithTrig.DbMath.acos, Math.acos) + functionTest1("asin", doubleTypeInfo.copy(gen = Gen.choose(-1, 1)), platformWithTrig.DbMath.asin, Math.asin) + functionTest1("atan", doubleTypeInfo, platformWithTrig.DbMath.atan, Math.atan) + functionTest2("atan2", doubleTypeInfo, platformWithTrig.DbMath.atan2, Math.atan2) + functionTest1("cos", doubleTypeInfo, platformWithTrig.DbMath.cos, Math.cos) + functionTest1("cot", doubleTypeInfo, platformWithTrig.DbMath.cot, (x: Double) => Math.cos(x) / Math.sin(x)) + functionTest1("sin", doubleTypeInfo, platformWithTrig.DbMath.sin, Math.sin) + functionTest1("tan", doubleTypeInfo, platformWithTrig.DbMath.tan, Math.tan) - val longTypeInfo: TypeInfo[Long] = TypeInfo( + private def longGen = Gen.choose(-10000L, 10000L) + + def longTypeInfo: TypeInfo[Long] = TypeInfo( AnsiTypes.bigint, Gen.choose(-10000L, 10000L), - 0, - _ => false + (v, r) => expect(v == r) ) protected type LongLikeCastType @@ -169,23 +187,12 @@ trait PlatformMathSuite[Codec0[_], Platform <: SqlQueryPlatform { type Codec[A] testPi(longCastType, longLikeTypeInfo, () => doubleToLongLikeCastType(Math.PI)) testRandom(longCastType) - private val doubleGen = Gen.choose(-10000D, 10000D) - protected type DoubleLikeCastType protected def doubleCastType: CastType[DoubleLikeCastType] protected def doubleLikeTypeInfo: TypeInfo[DoubleLikeCastType] protected def doubleToDoubleLikeCastType(d: Double): DoubleLikeCastType protected given doubleLikeCastTypeSqlNumeric: SqlNumeric[DoubleLikeCastType] - val doubleTypeInfo: TypeInfo[Double] = TypeInfo( - AnsiTypes.doublePrecision, - Gen.choose(-10000D, 10000D), - 0.00001, - java.lang.Double.isNaN - ) - - val doublePositiveTypeInfo: TypeInfo[Double] = doubleTypeInfo.copy(gen = Gen.choose(0, 10000D)) - testPow( doubleTypeInfo.copy( gen = Gen.choose(-10D, 10D), @@ -216,20 +223,19 @@ trait PlatformMathSuite[Codec0[_], Platform <: SqlQueryPlatform { type Codec[A] testPi(doubleCastType, doubleLikeTypeInfo, () => doubleToDoubleLikeCastType(Math.PI)) testRandom(doubleCastType) - val decimalTypeInfo: TypeInfo[BigDecimal] = TypeInfo( + def decimalTypeInfo: TypeInfo[BigDecimal] = TypeInfo( AnsiTypes.decimalN(15, 9), Gen.choose(BigDecimal(-10000D), BigDecimal(10000D)), - 0.00001, - _ => false + (v, r) => expect((v - r).abs < BigDecimal(0.00001)) ) - val decimalPositiveTypeInfo: TypeInfo[BigDecimal] = + def decimalPositiveTypeInfo: TypeInfo[BigDecimal] = decimalTypeInfo.copy(gen = Gen.choose(BigDecimal(0), BigDecimal(10000))) import spire.implicits.* // Hard to test - //testPow( + // testPow( // decimalTypeInfo.copy( // gen = Gen.choose(BigDecimal(-10), BigDecimal(10)), // transform2 = _ => { @@ -245,7 +251,7 @@ trait PlatformMathSuite[Codec0[_], Platform <: SqlQueryPlatform { type Codec[A] // (a: BigDecimal, b: BigDecimal) => { // if b.isWhole then a.pow(b.toInt) else spire.math.exp(a.log * b) // } - //) + // ) testSqrt(decimalPositiveTypeInfo, (_: BigDecimal).sqrt) testAbs(decimalTypeInfo, (_: BigDecimal).abs) testCeil(decimalTypeInfo, (_: BigDecimal).setScale(0, BigDecimal.RoundingMode.CEILING)) diff --git a/jdbc/src/test/scala/dataprism/jdbc/h2/H2MathSuite.scala b/jdbc/src/test/scala/dataprism/jdbc/h2/H2MathSuite.scala index 57d02a53..d6c5a0f4 100644 --- a/jdbc/src/test/scala/dataprism/jdbc/h2/H2MathSuite.scala +++ b/jdbc/src/test/scala/dataprism/jdbc/h2/H2MathSuite.scala @@ -4,8 +4,15 @@ import dataprism.jdbc.platform.H2JdbcPlatform import dataprism.jdbc.sql import dataprism.jdbc.sql.{JdbcAnsiTypes, JdbcCodec} import dataprism.{PlatformSaneMathSuite, jdbc} +import spire.math.Real object H2MathSuite extends H2FunSuite, PlatformSaneMathSuite[JdbcCodec, H2JdbcPlatform] { override protected def longCastType: JdbcAnsiTypes.TypeOf[Long] = JdbcAnsiTypes.bigint override protected def doubleCastType: JdbcAnsiTypes.TypeOf[Double] = JdbcAnsiTypes.doublePrecision + + testTrigFunctions(platform) + + functionTest1("sinh", doubleTypeInfo, platform.DbMath.sinh, Math.sinh) + functionTest1("cosh", doubleTypeInfo, platform.DbMath.cosh, Math.cosh) + functionTest1("tanh", doubleTypeInfo, platform.DbMath.tanh, Math.tanh) } diff --git a/jdbc/src/test/scala/dataprism/jdbc/mariadb/MariaDbMathSuite.scala b/jdbc/src/test/scala/dataprism/jdbc/mariadb/MariaDbMathSuite.scala index 09bf3918..595f2f35 100644 --- a/jdbc/src/test/scala/dataprism/jdbc/mariadb/MariaDbMathSuite.scala +++ b/jdbc/src/test/scala/dataprism/jdbc/mariadb/MariaDbMathSuite.scala @@ -3,7 +3,6 @@ package dataprism.jdbc.mariadb import dataprism.PlatformMathSuite import dataprism.jdbc.platform.MariaDbJdbcPlatform import dataprism.jdbc.sql.{JdbcCodec, MySqlJdbcTypeCastable, MySqlJdbcTypes} -import org.scalacheck.Gen object MariaDbMathSuite extends MariaDbFunSuite, PlatformMathSuite[JdbcCodec, MariaDbJdbcPlatform] { import platform.Api.* @@ -17,12 +16,9 @@ object MariaDbMathSuite extends MariaDbFunSuite, PlatformMathSuite[JdbcCodec, Ma override protected type DoubleLikeCastType = BigDecimal override protected def doubleCastType: MySqlJdbcTypeCastable[BigDecimal] = MySqlJdbcTypes.castType.decimalN(15, 9) - override protected def doubleLikeTypeInfo: TypeInfo[BigDecimal] = TypeInfo( - MySqlJdbcTypes.decimal, - Gen.choose(BigDecimal(-10000), BigDecimal(10000)), - BigDecimal(0.000000001), - _ => false - ) + override protected def doubleLikeTypeInfo: TypeInfo[BigDecimal] = decimalTypeInfo override protected def doubleToDoubleLikeCastType(d: Double): BigDecimal = BigDecimal.decimal(d) override protected def doubleLikeCastTypeSqlNumeric: SqlNumeric[BigDecimal] = platform.sqlNumericBigDecimal + + testTrigFunctions(platform) } diff --git a/jdbc/src/test/scala/dataprism/jdbc/mysql57/MySql57MathSuite.scala b/jdbc/src/test/scala/dataprism/jdbc/mysql57/MySql57MathSuite.scala index a28b9638..67df8ffa 100644 --- a/jdbc/src/test/scala/dataprism/jdbc/mysql57/MySql57MathSuite.scala +++ b/jdbc/src/test/scala/dataprism/jdbc/mysql57/MySql57MathSuite.scala @@ -1,10 +1,8 @@ package dataprism.jdbc.mysql57 import dataprism.PlatformMathSuite -import dataprism.jdbc.mariadb.MariaDbMathSuite.platform import dataprism.jdbc.platform.MySql57JdbcPlatform import dataprism.jdbc.sql.{JdbcCodec, MySqlJdbcTypeCastable, MySqlJdbcTypes} -import org.scalacheck.Gen object MySql57MathSuite extends MySql57FunSuite, PlatformMathSuite[JdbcCodec, MySql57JdbcPlatform] { import platform.Api.* @@ -18,12 +16,9 @@ object MySql57MathSuite extends MySql57FunSuite, PlatformMathSuite[JdbcCodec, My override protected type DoubleLikeCastType = BigDecimal override protected def doubleCastType: MySqlJdbcTypeCastable[BigDecimal] = MySqlJdbcTypes.castType.decimalN(15, 9) - override protected def doubleLikeTypeInfo: TypeInfo[BigDecimal] = TypeInfo( - MySqlJdbcTypes.decimal, - Gen.choose(BigDecimal(-10000), BigDecimal(10000)), - BigDecimal(0.000000001), - _ => false - ) + override protected def doubleLikeTypeInfo: TypeInfo[BigDecimal] = decimalTypeInfo override protected def doubleToDoubleLikeCastType(d: Double): BigDecimal = BigDecimal.decimal(d) override protected def doubleLikeCastTypeSqlNumeric: SqlNumeric[BigDecimal] = platform.sqlNumericBigDecimal + + testTrigFunctions(platform) } diff --git a/jdbc/src/test/scala/dataprism/jdbc/mysql8/MySql8MathSuite.scala b/jdbc/src/test/scala/dataprism/jdbc/mysql8/MySql8MathSuite.scala index 2b82f768..1ab6dcc8 100644 --- a/jdbc/src/test/scala/dataprism/jdbc/mysql8/MySql8MathSuite.scala +++ b/jdbc/src/test/scala/dataprism/jdbc/mysql8/MySql8MathSuite.scala @@ -3,7 +3,6 @@ package dataprism.jdbc.mysql8 import dataprism.PlatformMathSuite import dataprism.jdbc.platform.MySql8JdbcPlatform import dataprism.jdbc.sql.{JdbcCodec, MySqlJdbcTypeCastable, MySqlJdbcTypes} -import org.scalacheck.Gen object MySql8MathSuite extends MySql8FunSuite, PlatformMathSuite[JdbcCodec, MySql8JdbcPlatform] { import platform.Api.* @@ -17,12 +16,9 @@ object MySql8MathSuite extends MySql8FunSuite, PlatformMathSuite[JdbcCodec, MySq override protected type DoubleLikeCastType = BigDecimal override protected def doubleCastType: MySqlJdbcTypeCastable[BigDecimal] = MySqlJdbcTypes.castType.decimalN(15, 9) - override protected def doubleLikeTypeInfo: TypeInfo[BigDecimal] = TypeInfo( - MySqlJdbcTypes.decimal, - Gen.choose(BigDecimal(-10000), BigDecimal(10000)), - BigDecimal(0.000000001), - _ => false - ) + override protected def doubleLikeTypeInfo: TypeInfo[BigDecimal] = decimalTypeInfo override protected def doubleToDoubleLikeCastType(d: Double): BigDecimal = BigDecimal.decimal(d) override protected def doubleLikeCastTypeSqlNumeric: SqlNumeric[BigDecimal] = platform.sqlNumericBigDecimal + + testTrigFunctions(platform) } diff --git a/jdbc/src/test/scala/dataprism/jdbc/postgres/PostgresMathSuite.scala b/jdbc/src/test/scala/dataprism/jdbc/postgres/PostgresMathSuite.scala index d5b56a5c..bfc0c911 100644 --- a/jdbc/src/test/scala/dataprism/jdbc/postgres/PostgresMathSuite.scala +++ b/jdbc/src/test/scala/dataprism/jdbc/postgres/PostgresMathSuite.scala @@ -3,10 +3,21 @@ package dataprism.jdbc.postgres import dataprism.PlatformSaneMathSuite import dataprism.jdbc.platform.PostgresJdbcPlatform import dataprism.jdbc.sql.{JdbcAnsiTypes, JdbcCodec} +import org.scalacheck.Gen +import spire.math.Real object PostgresMathSuite extends PostgresFunSuite, PlatformSaneMathSuite[JdbcCodec, PostgresJdbcPlatform] { override def maxParallelism: Int = 10 override protected def longCastType: JdbcAnsiTypes.TypeOf[Long] = JdbcAnsiTypes.bigint override protected def doubleCastType: JdbcAnsiTypes.TypeOf[Double] = JdbcAnsiTypes.doublePrecision + + testTrigFunctions(platform) + + functionTest1("sinh", doubleTypeInfo, platform.DbMath.sinh, Math.sinh) + functionTest1("cosh", doubleTypeInfo, platform.DbMath.cosh, Math.cosh) + functionTest1("tanh", doubleTypeInfo, platform.DbMath.tanh, Math.tanh) + functionTest1("asinh", doubleTypeInfo, platform.DbMath.asinh, a => Real.asinh(Real(a)).toDouble) + functionTest1("acosh", doubleTypeInfo.copy(gen = Gen.choose(1D, 10000D)), platform.DbMath.acosh, a => Real.acosh(Real(a)).toDouble) + functionTest1("atanh", doubleTypeInfo.copy(gen = Gen.choose(-1D, 1D)), platform.DbMath.atanh, a => Real.atanh(Real(a)).toDouble) } diff --git a/jdbc/src/test/scala/dataprism/jdbc/sqlite/SqliteFunSuite.scala b/jdbc/src/test/scala/dataprism/jdbc/sqlite/SqliteFunSuite.scala index c1089017..008d255e 100644 --- a/jdbc/src/test/scala/dataprism/jdbc/sqlite/SqliteFunSuite.scala +++ b/jdbc/src/test/scala/dataprism/jdbc/sqlite/SqliteFunSuite.scala @@ -1,28 +1,21 @@ package dataprism.jdbc.sqlite -import java.nio.file.{Files, Paths} - import cats.effect.{IO, Resource} import dataprism.PlatformFunSuite import dataprism.PlatformFunSuite.DbToTest import dataprism.jdbc.CatsDataSourceDb import dataprism.jdbc.platform.SqliteJdbcPlatform import dataprism.jdbc.sql.JdbcCodec -import org.sqlite.SQLiteDataSource abstract class SqliteFunSuite extends PlatformFunSuite[JdbcCodec, SqliteJdbcPlatform](SqliteJdbcPlatform) { self => def dbToTest: DbToTest = DbToTest.Sqlite override def sharedResource: Resource[IO, DbType] = { - val simpleName = self.getClass.getSimpleName + val simpleName = self.getClass.getSimpleName val simplerName = if simpleName.endsWith("$") then simpleName.substring(0, simpleName.length - 1) else simpleName - Resource.make(IO { - val ds = org.sqlite.SQLiteDataSource() - ds.setUrl(s"jdbc:sqlite:$simplerName.db") - CatsDataSourceDb[IO](ds) - }) { _ => - IO.blocking(Files.deleteIfExists(Paths.get(s"./$simplerName.db"))) - } + val ds = org.sqlite.SQLiteDataSource() + ds.setUrl(s"jdbc:sqlite::memory:") + Resource.pure(CatsDataSourceDb[IO](ds)) } } diff --git a/jdbc/src/test/scala/dataprism/jdbc/sqlite/SqliteMathSuite.scala b/jdbc/src/test/scala/dataprism/jdbc/sqlite/SqliteMathSuite.scala index d32c9da3..f1045b68 100644 --- a/jdbc/src/test/scala/dataprism/jdbc/sqlite/SqliteMathSuite.scala +++ b/jdbc/src/test/scala/dataprism/jdbc/sqlite/SqliteMathSuite.scala @@ -1,11 +1,21 @@ package dataprism.jdbc.sqlite -import cats.effect.IO import dataprism.PlatformSaneMathSuite import dataprism.jdbc.platform.SqliteJdbcPlatform import dataprism.jdbc.sql.{JdbcAnsiTypes, JdbcCodec} +import org.scalacheck.Gen +import spire.math.Real object SqliteMathSuite extends SqliteFunSuite, PlatformSaneMathSuite[JdbcCodec, SqliteJdbcPlatform] { override protected def longCastType: JdbcAnsiTypes.TypeOf[Long] = JdbcAnsiTypes.bigint override protected def doubleCastType: JdbcAnsiTypes.TypeOf[Double] = JdbcAnsiTypes.doublePrecision + + testTrigFunctions(platform) + + functionTest1("sinh", doubleTypeInfo, platform.DbMath.sinh, Math.sinh) + functionTest1("cosh", doubleTypeInfo, platform.DbMath.cosh, Math.cosh) + functionTest1("tanh", doubleTypeInfo, platform.DbMath.tanh, Math.tanh) + functionTest1("asinh", doubleTypeInfo, platform.DbMath.asinh, a => Real.asinh(Real(a)).toDouble) + functionTest1("acosh", doubleTypeInfo.copy(gen = Gen.choose(1D, 10000D)), platform.DbMath.acosh, a => Real.acosh(Real(a)).toDouble) + functionTest1("atanh", doubleTypeInfo.copy(gen = Gen.choose(-1D, 1D)), platform.DbMath.atanh, a => Real.atanh(Real(a)).toDouble) } diff --git a/skunk/src/test/scala/dataprism/skunk/testsuites/PostgresMathSuite.scala b/skunk/src/test/scala/dataprism/skunk/testsuites/PostgresMathSuite.scala index 48a35adf..32a61b45 100644 --- a/skunk/src/test/scala/dataprism/skunk/testsuites/PostgresMathSuite.scala +++ b/skunk/src/test/scala/dataprism/skunk/testsuites/PostgresMathSuite.scala @@ -3,7 +3,9 @@ package dataprism.skunk.testsuites import dataprism.PlatformSaneMathSuite import dataprism.skunk.platform.PostgresSkunkPlatform import dataprism.skunk.sql.SkunkAnsiTypes +import org.scalacheck.Gen import skunk.Codec +import spire.math.Real object PostgresMathSuite extends PostgresFunSuite, PlatformSaneMathSuite[Codec, PostgresSkunkPlatform] { // override def maxParallelism: Int = 10 @@ -13,4 +15,13 @@ object PostgresMathSuite extends PostgresFunSuite, PlatformSaneMathSuite[Codec, override protected def longCastType: SkunkAnsiTypes.TypeOf[Long] = SkunkAnsiTypes.bigint override protected def doubleCastType: SkunkAnsiTypes.TypeOf[Double] = SkunkAnsiTypes.doublePrecision + + testTrigFunctions(platform) + + functionTest1("sinh", doubleTypeInfo, platform.DbMath.sinh, Math.sinh) + functionTest1("cosh", doubleTypeInfo, platform.DbMath.cosh, Math.cosh) + functionTest1("tanh", doubleTypeInfo, platform.DbMath.tanh, Math.tanh) + functionTest1("asinh", doubleTypeInfo, platform.DbMath.asinh, a => Real.asinh(Real(a)).toDouble) + functionTest1("acosh", doubleTypeInfo.copy(gen = Gen.choose(1D, 10000D)), platform.DbMath.acosh, a => Real.acosh(Real(a)).toDouble) + functionTest1("atanh", doubleTypeInfo.copy(gen = Gen.choose(-1D, 1D)), platform.DbMath.atanh, a => Real.atanh(Real(a)).toDouble) }