Skip to content

Commit

Permalink
fix: EXPOSED-432 CurrentDate default is generated as null in MariaDB
Browse files Browse the repository at this point in the history
  • Loading branch information
joc-a committed Jul 15, 2024
1 parent 835573e commit bea4d78
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,17 @@ object SchemaUtils {

is Function<*> -> {
var processed = processForDefaultValue(exp)
if (exp.columnType is IDateColumnType && (processed.startsWith("CURRENT_TIMESTAMP") || processed == "GETDATE()")) {
when (currentDialect) {
is SQLServerDialect -> processed = "getdate"
is MariaDBDialect -> processed = processed.lowercase()
if (exp.columnType is IDateColumnType) {
if (processed.startsWith("CURRENT_TIMESTAMP") || processed == "GETDATE()") {
when (currentDialect) {
is SQLServerDialect -> processed = "getdate"
is MariaDBDialect -> processed = processed.lowercase()
}
}
if (processed.trim('(').startsWith("CURRENT_DATE")) {
when (currentDialect) {
is MysqlDialect -> processed = "curdate()"
}
}
}
processed
Expand Down Expand Up @@ -313,7 +320,12 @@ object SchemaUtils {
val dataTypeProvider = currentDialect.dataTypeProvider
val redoColumns = existingTableColumns.mapValues { (col, existingCol) ->
val columnType = col.columnType
val incorrectNullability = existingCol.nullable != columnType.nullable
val colNullable = if (col.dbDefaultValue?.let { currentDialect.isAllowedAsColumnDefault(it) } == false) {
true // Treat a disallowed default value as null because that is what Exposed does with it
} else {
columnType.nullable
}
val incorrectNullability = existingCol.nullable != colNullable
// Exposed doesn't support changing sequences on columns
val incorrectAutoInc = existingCol.autoIncrement != columnType.isAutoInc && col.autoIncColumnType?.autoincSeq == null

Expand Down Expand Up @@ -350,7 +362,7 @@ object SchemaUtils {
*/
private fun isIncorrectDefault(dataTypeProvider: DataTypeProvider, columnMeta: ColumnMetadata, column: Column<*>): Boolean {
val isExistingColumnDefaultNull = columnMeta.defaultDbValue == null
val isDefinedColumnDefaultNull = column.dbDefaultValue == null ||
val isDefinedColumnDefaultNull = column.dbDefaultValue?.takeIf { currentDialect.isAllowedAsColumnDefault(it) } == null ||
(column.dbDefaultValue is LiteralOp<*> && (column.dbDefaultValue as? LiteralOp<*>)?.value == null)

return when {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.jetbrains.exposed.sql.vendors
import org.jetbrains.exposed.exceptions.UnsupportedByDialectException
import org.jetbrains.exposed.exceptions.throwUnsupportedException
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.Function
import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.math.BigDecimal

Expand Down Expand Up @@ -57,13 +58,22 @@ internal object MysqlDataTypeProvider : DataTypeProvider() {
override fun processForDefaultValue(e: Expression<*>): String = when {
e is LiteralOp<*> && e.columnType is JsonColumnMarker -> when {
currentDialect is MariaDBDialect -> super.processForDefaultValue(e)
(currentDialect as? MysqlDialect)?.isMysql8 == true -> "(${super.processForDefaultValue(e)})"
((currentDialect as? MysqlDialect)?.fullVersion ?: "0") >= "8.0.13" -> "(${super.processForDefaultValue(e)})"
else -> throw UnsupportedByDialectException(
"MySQL versions prior to 8.0.13 do not accept default values on JSON columns",
currentDialect
)
}

currentDialect is MariaDBDialect -> super.processForDefaultValue(e)
// The default value specified in a DEFAULT clause can be a literal constant or an expression. With one
// exception, enclose expression default values within parentheses to distinguish them from literal constant
// default values. The exception is that, for TIMESTAMP and DATETIME columns, you can specify the
// CURRENT_TIMESTAMP function as the default, without enclosing parentheses.
// https://dev.mysql.com/doc/refman/8.0/en/data-type-defaults.html#data-type-defaults-explicit
e is Function<*> && e.columnType is IDateColumnType && e.toString().startsWith("CURRENT_TIMESTAMP") ->
super.processForDefaultValue(e)
e is Function<*> && ((currentDialect as? MysqlDialect)?.fullVersion ?: "0") >= "8.0.13" ->
"(${super.processForDefaultValue(e)})"
else -> super.processForDefaultValue(e)
}

Expand Down Expand Up @@ -308,7 +318,10 @@ open class MysqlDialect : VendorDialect(dialectName, MysqlDataTypeProvider, Mysq

override fun isAllowedAsColumnDefault(e: Expression<*>): Boolean {
if (super.isAllowedAsColumnDefault(e)) return true
val acceptableDefaults = arrayOf("CURRENT_TIMESTAMP", "CURRENT_TIMESTAMP()", "NOW()", "CURRENT_TIMESTAMP(6)", "NOW(6)")
if ((currentDialect is MariaDBDialect && fullVersion >= "10.2.1") || (currentDialect !is MariaDBDialect && fullVersion >= "8.0.13")) {
return true
}
val acceptableDefaults = mutableListOf("CURRENT_TIMESTAMP", "CURRENT_TIMESTAMP()", "NOW()", "CURRENT_TIMESTAMP(6)", "NOW(6)")
return e.toString().trim() in acceptableDefaults && isFractionDateTimeSupported()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.jetbrains.exposed.sql.javatime
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.Function
import org.jetbrains.exposed.sql.vendors.H2Dialect
import org.jetbrains.exposed.sql.vendors.MariaDBDialect
import org.jetbrains.exposed.sql.vendors.MysqlDialect
import org.jetbrains.exposed.sql.vendors.SQLServerDialect
import org.jetbrains.exposed.sql.vendors.currentDialect
Expand Down Expand Up @@ -55,6 +56,7 @@ sealed class CurrentTimestampBase<T>(columnType: IColumnType<T & Any>) : Functio
object CurrentDate : Function<LocalDate>(JavaLocalDateColumnType.INSTANCE) {
override fun toQueryBuilder(queryBuilder: QueryBuilder) = queryBuilder {
+when (currentDialect) {
is MariaDBDialect -> "curdate()"
is MysqlDialect -> "CURRENT_DATE()"
is SQLServerDialect -> "GETDATE()"
else -> "CURRENT_DATE"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,17 +292,10 @@ class DefaultsTest : DatabaseTestsBase() {

@Test
fun testDefaultExpressions01() {
fun abs(value: Int) = object : ExpressionWithColumnType<Int>() {
override fun toQueryBuilder(queryBuilder: QueryBuilder) = queryBuilder { append("ABS($value)") }

override val columnType: IColumnType<Int> = IntegerColumnType()
}

val foo = object : IntIdTable("foo") {
val name = text("name")
val defaultDateTime = datetime("defaultDateTime").defaultExpression(CurrentDateTime)
val defaultDate = date("defaultDate").defaultExpression(CurrentDate)
val defaultInt = integer("defaultInteger").defaultExpression(abs(-100))
}

withTables(foo) {
Expand All @@ -313,7 +306,6 @@ class DefaultsTest : DatabaseTestsBase() {

assertEquals(today, result[foo.defaultDateTime].toLocalDate())
assertEquals(today, result[foo.defaultDate])
assertEquals(100, result[foo.defaultInt])
}
}

Expand Down Expand Up @@ -368,8 +360,7 @@ class DefaultsTest : DatabaseTestsBase() {
val foo = object : IntIdTable("foo") {
val name = text("name")
val defaultDate = date("default_date").defaultExpression(CurrentDate)
val defaultDateTime1 = datetime("default_date_time_1").defaultExpression(CurrentDateTime)
val defaultDateTime2 = datetime("default_date_time_2").defaultExpression(CurrentDateTime)
val defaultDateTime = datetime("default_date_time").defaultExpression(CurrentDateTime)
val defaultTimeStamp = timestamp("default_time_stamp").defaultExpression(CurrentTimestamp)
}

Expand All @@ -379,16 +370,7 @@ class DefaultsTest : DatabaseTestsBase() {

val actual = SchemaUtils.statementsRequiredToActualizeScheme(foo)

if (currentDialectTest is MysqlDialect) {
// MySQL and MariaDB do not support CURRENT_DATE as default
// so the column is created with a NULL marker, which correctly triggers 1 alter statement
val tableName = foo.nameInDatabaseCase()
val dateColumnName = foo.defaultDate.nameInDatabaseCase()
val alter = "ALTER TABLE $tableName MODIFY COLUMN $dateColumnName DATE NULL"
assertEquals(alter, actual.single())
} else {
assertTrue(actual.isEmpty())
}
assertTrue(actual.isEmpty())
} finally {
SchemaUtils.drop(foo)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.Json
import org.jetbrains.exposed.dao.id.IntIdTable
import org.jetbrains.exposed.dao.id.LongIdTable
import org.jetbrains.exposed.exceptions.UnsupportedByDialectException
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.SqlExpressionBuilder.between
Expand Down Expand Up @@ -593,6 +594,17 @@ class JavaTimeTests : DatabaseTestsBase() {
)
}
}

@Test
fun testCurrentDateAsDefaultExpression() {
val testTable = object : LongIdTable("test_table") {
val date: Column<LocalDate> = date("date").index().defaultExpression(CurrentDate)
}
withTables(testTable) {
val statements = SchemaUtils.statementsRequiredForDatabaseMigration(testTable)
assertTrue(statements.isEmpty())
}
}
}

fun <T : Temporal> assertEqualDateTime(d1: T?, d2: T?) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.jetbrains.exposed.sql.jodatime
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.Function
import org.jetbrains.exposed.sql.vendors.H2Dialect
import org.jetbrains.exposed.sql.vendors.MariaDBDialect
import org.jetbrains.exposed.sql.vendors.MysqlDialect
import org.jetbrains.exposed.sql.vendors.SQLServerDialect
import org.jetbrains.exposed.sql.vendors.currentDialect
Expand Down Expand Up @@ -38,6 +39,7 @@ object CurrentDateTime : Function<DateTime>(DateColumnType(true)) {
object CurrentDate : Function<DateTime>(DateColumnType(false)) {
override fun toQueryBuilder(queryBuilder: QueryBuilder) = queryBuilder {
+when (currentDialect) {
is MariaDBDialect -> "curdate()"
is MysqlDialect -> "CURRENT_DATE()"
is SQLServerDialect -> "GETDATE()"
else -> "CURRENT_DATE"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,10 @@ class JodaTimeDefaultsTest : DatabaseTestsBase() {

@Test
fun testDefaultExpressions01() {
fun abs(value: Int) = object : ExpressionWithColumnType<Int>() {
override fun toQueryBuilder(queryBuilder: QueryBuilder) = queryBuilder { append("ABS($value)") }

override val columnType: IColumnType<Int> = IntegerColumnType()
}

val foo = object : IntIdTable("foo") {
val name = text("name")
val defaultDateTime = datetime("defaultDateTime").defaultExpression(CurrentDateTime)
val defaultDate = date("defaultDate").defaultExpression(CurrentDate)
val defaultInt = integer("defaultInteger").defaultExpression(abs(-100))
}

withTables(foo) {
Expand All @@ -245,7 +238,6 @@ class JodaTimeDefaultsTest : DatabaseTestsBase() {

assertEquals(today, result[foo.defaultDateTime].withTimeAtStartOfDay())
assertEquals(today, result[foo.defaultDate])
assertEquals(100, result[foo.defaultInt])
}
}

Expand Down Expand Up @@ -428,6 +420,7 @@ class JodaTimeDefaultsTest : DatabaseTestsBase() {
fun testConsistentSchemeWithFunctionAsDefaultExpression() {
val foo = object : IntIdTable("foo") {
val name = text("name")
val defaultDate = date("default_date").defaultExpression(CurrentDate)
val defaultDateTime = datetime("defaultDateTime").defaultExpression(CurrentDateTime)
}

Expand All @@ -436,6 +429,7 @@ class JodaTimeDefaultsTest : DatabaseTestsBase() {
SchemaUtils.create(foo)

val actual = SchemaUtils.statementsRequiredToActualizeScheme(foo)

assertTrue(actual.isEmpty())
} finally {
SchemaUtils.drop(foo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import kotlinx.datetime.LocalTime
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.Function
import org.jetbrains.exposed.sql.vendors.H2Dialect
import org.jetbrains.exposed.sql.vendors.MariaDBDialect
import org.jetbrains.exposed.sql.vendors.MysqlDialect
import org.jetbrains.exposed.sql.vendors.SQLServerDialect
import org.jetbrains.exposed.sql.vendors.currentDialect
Expand Down Expand Up @@ -107,7 +108,8 @@ object CurrentTimestampWithTimeZone : CurrentTimestampBase<OffsetDateTime>(Kotli
object CurrentDate : Function<LocalDate>(KotlinLocalDateColumnType.INSTANCE) {
override fun toQueryBuilder(queryBuilder: QueryBuilder) = queryBuilder {
+when (currentDialect) {
is MysqlDialect -> "CURRENT_DATE()"
is MariaDBDialect -> "curdate()"
is MysqlDialect -> "CURRENT_DATE"
is SQLServerDialect -> "GETDATE()"
else -> "CURRENT_DATE"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,17 +290,10 @@ class DefaultsTest : DatabaseTestsBase() {

@Test
fun testDefaultExpressions01() {
fun abs(value: Int) = object : ExpressionWithColumnType<Int>() {
override fun toQueryBuilder(queryBuilder: QueryBuilder) = queryBuilder { append("ABS($value)") }

override val columnType: IColumnType<Int> = IntegerColumnType()
}

val foo = object : IntIdTable("foo") {
val name = text("name")
val defaultDateTime = datetime("defaultDateTime").defaultExpression(CurrentDateTime)
val defaultDate = date("defaultDate").defaultExpression(CurrentDate)
val defaultInt = integer("defaultInteger").defaultExpression(abs(-100))
}

withTables(foo) {
Expand All @@ -311,7 +304,6 @@ class DefaultsTest : DatabaseTestsBase() {

assertEquals(today, result[foo.defaultDateTime].date)
assertEquals(today, result[foo.defaultDate])
assertEquals(100, result[foo.defaultInt])
}
}

Expand Down Expand Up @@ -372,8 +364,7 @@ class DefaultsTest : DatabaseTestsBase() {
val foo = object : IntIdTable("foo") {
val name = text("name")
val defaultDate = date("default_date").defaultExpression(CurrentDate)
val defaultDateTime1 = datetime("default_date_time_1").defaultExpression(CurrentDateTime)
val defaultDateTime2 = datetime("default_date_time_2").defaultExpression(CurrentDateTime)
val defaultDateTime = datetime("default_date_time").defaultExpression(CurrentDateTime)
val defaultTimeStamp = timestamp("default_time_stamp").defaultExpression(CurrentTimestamp)
}

Expand All @@ -383,16 +374,7 @@ class DefaultsTest : DatabaseTestsBase() {

val actual = SchemaUtils.statementsRequiredToActualizeScheme(foo)

if (currentDialectTest is MysqlDialect) {
// MySQL and MariaDB do not support CURRENT_DATE as default
// so the column is created with a NULL marker, which correctly triggers 1 alter statement
val tableName = foo.nameInDatabaseCase()
val dateColumnName = foo.defaultDate.nameInDatabaseCase()
val alter = "ALTER TABLE $tableName MODIFY COLUMN $dateColumnName DATE NULL"
assertEquals(alter, actual.single())
} else {
assertTrue(actual.isEmpty())
}
assertTrue(actual.isEmpty())
} finally {
SchemaUtils.drop(foo)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import kotlinx.datetime.*
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import org.jetbrains.exposed.dao.id.IntIdTable
import org.jetbrains.exposed.dao.id.LongIdTable
import org.jetbrains.exposed.exceptions.UnsupportedByDialectException
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.SqlExpressionBuilder.between
Expand Down Expand Up @@ -610,6 +611,17 @@ class KotlinTimeTests : DatabaseTestsBase() {
)
}
}

@Test
fun testCurrentDateAsDefaultExpression() {
val testTable = object : LongIdTable("test_table") {
val date: Column<LocalDate> = date("date").index().defaultExpression(CurrentDate)
}
withTables(testTable) {
val statements = SchemaUtils.statementsRequiredForDatabaseMigration(testTable)
assertTrue(statements.isEmpty())
}
}
}

fun <T> assertEqualDateTime(d1: T?, d2: T?) {
Expand Down

0 comments on commit bea4d78

Please sign in to comment.