Skip to content

Commit

Permalink
[SPARK-46841][SQL] Add collation support for ICU locales and collatio…
Browse files Browse the repository at this point in the history
…n specifiers

### What changes were proposed in this pull request?

Languages and localization for collations are supported by the ICU library. The collation naming format is as follows:
```
<2-letter language code>[_<4-letter script>][_<3-letter country code>][_specifier_specifier...]
```
The locale specifier consists of the first part of the collation name (language + script + country). Locale specifiers need to be stable across ICU versions; to keep existing ids and names invariant, we introduce a golden file with the locale table, which should cause a CI failure on any silent changes.

Currently supported optional specifiers:

- `CS`/`CI` - case sensitivity, default is case-sensitive; supported by configuring ICU collation levels
- `AS`/`AI` - accent sensitivity, default is accent-sensitive; supported by configuring ICU collation levels

Users can list collation specifiers in any order, except for the locale, which is mandatory and must come first. There is a one-to-one mapping between collation ids and collation names defined in `CollationFactory`.

### Why are the changes needed?

To add languages and localization support for collations.

### Does this PR introduce _any_ user-facing change?

Yes, it adds new predefined collations.

### How was this patch tested?

Added checks to `CollationFactorySuite` and ICU locale map golden file.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#46180 from nikolamand-db/SPARK-46841.

Authored-by: Nikola Mandic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
nikolamand-db authored and cloud-fan committed May 28, 2024
1 parent a78ef73 commit 7fe1b93
Show file tree
Hide file tree
Showing 27 changed files with 1,388 additions and 236 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ package org.apache.spark.unsafe.types
import scala.collection.parallel.immutable.ParSeq
import scala.jdk.CollectionConverters.MapHasAsScala

import com.ibm.icu.util.ULocale

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.util.CollationFactory.fetchCollation
// scalastyle:off
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.must.Matchers
Expand All @@ -30,31 +33,95 @@ import org.apache.spark.unsafe.types.UTF8String.{fromString => toUTF8}

class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ignore funsuite
// Pins the numeric values of the predefined collation id constants and, for the collation
// fetched by each id, checks its name and whether it supports binary equality.
// NOTE(review): this listing appears to be a diff rendered without +/- markers, so superseded
// lines (the fetchCollation(0) .. fetchCollation(3) lookups and one assert ending in `;`) sit
// next to their replacements — confirm against the real file before editing.
test("collationId stability") {
val utf8Binary = fetchCollation(0)
assert(INDETERMINATE_COLLATION_ID == -1)

assert(UTF8_BINARY_COLLATION_ID == 0)
val utf8Binary = fetchCollation(UTF8_BINARY_COLLATION_ID)
assert(utf8Binary.collationName == "UTF8_BINARY")
assert(utf8Binary.supportsBinaryEquality)

val utf8BinaryLcase = fetchCollation(1)
assert(UTF8_BINARY_LCASE_COLLATION_ID == 1)
val utf8BinaryLcase = fetchCollation(UTF8_BINARY_LCASE_COLLATION_ID)
assert(utf8BinaryLcase.collationName == "UTF8_BINARY_LCASE")
assert(!utf8BinaryLcase.supportsBinaryEquality)

val unicode = fetchCollation(2)
// ICU collation ids are expected to have bit 29 set.
assert(UNICODE_COLLATION_ID == (1 << 29))
val unicode = fetchCollation(UNICODE_COLLATION_ID)
assert(unicode.collationName == "UNICODE")
assert(unicode.supportsBinaryEquality);
assert(unicode.supportsBinaryEquality)

val unicodeCi = fetchCollation(3)
// UNICODE_CI is expected to additionally set bit 17.
assert(UNICODE_CI_COLLATION_ID == ((1 << 29) | (1 << 17)))
val unicodeCi = fetchCollation(UNICODE_CI_COLLATION_ID)
assert(unicodeCi.collationName == "UNICODE_CI")
assert(!unicodeCi.supportsBinaryEquality)
}

test("fetch invalid collation name") {
val error = intercept[SparkException] {
fetchCollation("UTF8_BS")
test("UTF8_BINARY and ICU root locale collation names") {
  // Names that are already in normalized form must be reported back unchanged.
  val normalizedNames = Seq(
    "UTF8_BINARY",
    "UTF8_BINARY_LCASE",
    "UNICODE",
    "UNICODE_CI",
    "UNICODE_AI",
    "UNICODE_CI_AI"
  )
  for (expectedName <- normalizedNames) {
    assert(fetchCollation(expectedName).collationName == expectedName)
  }

  // Pairs of (requested name, normalized name the fetched collation must report).
  val normalizationCases = Seq(
    // ICU root locale.
    "UNICODE_CS" -> "UNICODE",
    "UNICODE_CS_AS" -> "UNICODE",
    "UNICODE_CI_AS" -> "UNICODE_CI",
    "UNICODE_AI_CS" -> "UNICODE_AI",
    "UNICODE_AI_CI" -> "UNICODE_CI_AI",
    // Randomized case collation names.
    "utf8_binary" -> "UTF8_BINARY",
    "UtF8_binARy_LcasE" -> "UTF8_BINARY_LCASE",
    "unicode" -> "UNICODE",
    "UnICoDe_cs_aI" -> "UNICODE_AI"
  )
  for ((requestedName, expectedName) <- normalizationCases) {
    assert(fetchCollation(requestedName).collationName == expectedName)
  }
}

// Every name in the list must make fetchCollation throw a SparkException whose error class is
// COLLATION_INVALID_NAME and whose message parameters carry the offending name.
// NOTE(review): this listing appears to be a diff rendered without +/- markers; the assertion
// expecting a "proposal" parameter for "UTF8_BS" looks like the superseded version of the
// assertion that follows it — confirm against the real file before editing.
test("fetch invalid UTF8_BINARY and ICU root locale collation names") {
Seq(
"UTF8_BINARY_CS",
"UTF8_BINARY_AS",
"UTF8_BINARY_CS_AS",
"UTF8_BINARY_AS_CS",
"UTF8_BINARY_CI",
"UTF8_BINARY_AI",
"UTF8_BINARY_CI_AI",
"UTF8_BINARY_AI_CI",
"UTF8_BS",
"BINARY_UTF8",
"UTF8_BINARY_A",
"UNICODE_X",
"UNICODE_CI_X",
"UNICODE_LCASE_X",
"UTF8_UNICODE",
"UTF8_BINARY_UNICODE",
"CI_UNICODE",
"LCASE_UNICODE",
"UNICODE_UNSPECIFIED",
"UNICODE_CI_UNSPECIFIED",
"UNICODE_UNSPECIFIED_CI_UNSPECIFIED",
"UNICODE_INDETERMINATE",
"UNICODE_CI_INDETERMINATE"
).foreach(collationName => {
val error = intercept[SparkException] {
fetchCollation(collationName)
}

assert(error.getErrorClass === "COLLATION_INVALID_NAME")
assert(error.getMessageParameters.asScala ===
Map("proposal" -> "UTF8_BINARY", "collationName" -> "UTF8_BS"))
assert(error.getErrorClass === "COLLATION_INVALID_NAME")
assert(error.getMessageParameters.asScala === Map("collationName" -> collationName))
})
}

/**
 * One collation expectation used by the table-driven tests in this suite.
 *
 * @param collationName name of the collation to fetch for the check
 * @param s1 first input string
 * @param s2 second input string
 * @param expectedResult value the test expects for the pair `s1`, `s2` under the collation
 * @tparam R type of the expected result
 */
case class CollationTestCase[R](
    collationName: String,
    s1: String,
    s2: String,
    expectedResult: R)
Expand Down Expand Up @@ -152,4 +219,238 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig
}
})
}

test("test collation caching") {
  val collationNames = Seq(
    "UTF8_BINARY",
    "UTF8_BINARY_LCASE",
    "UNICODE",
    "UNICODE_CI",
    "UNICODE_AI",
    "UNICODE_CI_AI",
    "UNICODE_AI_CI"
  )
  for (collationName <- collationNames) {
    val firstLookup = fetchCollation(collationName)
    val secondLookup = fetchCollation(collationName)
    // Two lookups of the same name must yield the very same object (reference equality),
    // not merely an equal one.
    assert(firstLookup eq secondLookup)
  }
}

test("collations with ICU non-root localization") {
  val localizedNames = Seq(
    // Language only.
    "en",
    "en_CS",
    "en_CI",
    "en_AS",
    "en_AI",
    // Language + 3-letter country code.
    "en_USA",
    "en_USA_CS",
    "en_USA_CI",
    "en_USA_AS",
    "en_USA_AI",
    // Language + script code.
    "sr_Cyrl",
    "sr_Cyrl_CS",
    "sr_Cyrl_CI",
    "sr_Cyrl_AS",
    "sr_Cyrl_AI",
    // Language + script code + 3-letter country code.
    "sr_Cyrl_SRB",
    "sr_Cyrl_SRB_CS",
    "sr_Cyrl_SRB_CI",
    "sr_Cyrl_SRB_AS",
    "sr_Cyrl_SRB_AI"
  )
  for (localizedName <- localizedNames) {
    // The collator behind a localized collation must not report the ICU root locale
    // as its valid locale.
    val validLocale = fetchCollation(localizedName).collator.getLocale(ULocale.VALID_LOCALE)
    assert(validLocale != ULocale.ROOT)
  }
}

test("invalid names of collations with ICU non-root localization") {
  val invalidNames = Seq(
    "en_US", // Must use 3-letter country code
    "enn",
    "en_AAA",
    "en_Something",
    "en_Something_USA",
    "en_LCASE",
    "en_UCASE",
    "en_CI_LCASE",
    "en_CI_UCASE",
    "en_CI_UNSPECIFIED",
    "en_USA_UNSPECIFIED",
    "en_USA_UNSPECIFIED_CI",
    "en_INDETERMINATE",
    "en_USA_INDETERMINATE",
    "en_Latn_USA", // Use en_USA instead.
    "en_Cyrl_USA",
    "en_USA_AAA",
    "sr_Cyrl_SRB_AAA",
    // Invalid ordering of language, script and country code.
    "USA_en",
    "sr_SRB_Cyrl",
    "SRB_sr",
    "SRB_sr_Cyrl",
    "SRB_Cyrl_sr",
    "Cyrl_sr",
    "Cyrl_sr_SRB",
    "Cyrl_SRB_sr",
    // Collation specifiers in the middle of locale.
    "CI_en",
    "USA_CI_en",
    "en_CI_USA",
    "CI_sr_Cyrl_SRB",
    "sr_CI_Cyrl_SRB",
    "sr_Cyrl_CI_SRB",
    "CI_Cyrl_sr",
    "Cyrl_CI_sr",
    "Cyrl_CI_sr_SRB",
    "Cyrl_sr_CI_SRB"
  )
  for (invalidName <- invalidNames) {
    // Each lookup must fail, and the failure must name exactly the rejected input.
    val thrown = intercept[SparkException](fetchCollation(invalidName))
    assert(thrown.getErrorClass === "COLLATION_INVALID_NAME")
    assert(thrown.getMessageParameters.asScala === Map("collationName" -> invalidName))
  }
}

test("collations name normalization for ICU non-root localization") {
  // Pairs of (requested name, normalized name the fetched collation must report).
  val normalizationCases = Seq(
    "en_USA" -> "en_USA",
    "en_CS" -> "en",
    "en_AS" -> "en",
    "en_CS_AS" -> "en",
    "en_AS_CS" -> "en",
    "en_CI" -> "en_CI",
    "en_AI" -> "en_AI",
    "en_AI_CI" -> "en_CI_AI",
    "en_CI_AI" -> "en_CI_AI",
    "en_CS_AI" -> "en_AI",
    "en_AI_CS" -> "en_AI",
    "en_CI_AS" -> "en_CI",
    "en_AS_CI" -> "en_CI",
    "en_USA_AI_CI" -> "en_USA_CI_AI",
    // Randomized case.
    "EN_USA" -> "en_USA",
    "SR_CYRL" -> "sr_Cyrl",
    "sr_cyrl_srb" -> "sr_Cyrl_SRB",
    "sR_cYRl_sRb" -> "sr_Cyrl_SRB"
  )
  for ((requestedName, expectedName) <- normalizationCases) {
    assert(fetchCollation(requestedName).collationName == expectedName)
  }
}

test("invalid collationId") {
  // The indeterminate collation id plus ids that fall in the user-defined collation range.
  val outOfRangeIds = Seq(
    INDETERMINATE_COLLATION_ID,
    1 << 30,
    (1 << 30) | 1,
    (1 << 30) | (1 << 29)
  )
  // UTF8_BINARY ids breaching a single mandatory zero bit: bits 1-20 and 23-28.
  val utf8BinaryBitBreaches = ((1 to 20) ++ (23 to 28)).map(bit => 1 << bit)
  // ICU ids (bit 29 set) breaching a single mandatory zero bit: bits 12-15 and 18-28.
  val icuBitBreaches = ((12 to 15) ++ (18 to 28)).map(bit => (1 << 29) | (1 << bit))
  // ICU id with an invalid locale id.
  val icuInvalidLocaleId = (1 << 29) | 0xFFFF

  val badCollationIds =
    (outOfRangeIds ++ utf8BinaryBitBreaches ++ icuBitBreaches) :+ icuInvalidLocaleId

  for (badId <- badCollationIds) {
    // Assumptions about collation id will break and assert statement will fail.
    intercept[AssertionError] {
      fetchCollation(badId)
    }
  }
}

test("repeated and/or incompatible specifiers in collation name") {
  val conflictingNames = Seq(
    "UTF8_BINARY_LCASE_LCASE",
    "UNICODE_CS_CS",
    "UNICODE_CI_CI",
    "UNICODE_CI_CS",
    "UNICODE_CS_CI",
    "UNICODE_AS_AS",
    "UNICODE_AI_AI",
    "UNICODE_AS_AI",
    "UNICODE_AI_AS",
    "UNICODE_AS_CS_AI",
    "UNICODE_CS_AI_CI",
    "UNICODE_CS_AS_CI_AI"
  )
  for (conflictingName <- conflictingNames) {
    // Each lookup must fail, and the failure must name exactly the rejected input.
    val thrown = intercept[SparkException](fetchCollation(conflictingName))
    assert(thrown.getErrorClass === "COLLATION_INVALID_NAME")
    assert(thrown.getMessageParameters.asScala === Map("collationName" -> conflictingName))
  }
}

test("basic ICU collator checks") {
  // Equality checks: expectedResult states whether s1 and s2 are equal under the collation.
  val equalityCases = Seq(
    CollationTestCase("UNICODE_CI", "a", "A", true),
    CollationTestCase("UNICODE_CI", "a", "å", false),
    CollationTestCase("UNICODE_CI", "a", "Å", false),
    CollationTestCase("UNICODE_AI", "a", "A", false),
    CollationTestCase("UNICODE_AI", "a", "å", true),
    CollationTestCase("UNICODE_AI", "a", "Å", false),
    CollationTestCase("UNICODE_CI_AI", "a", "A", true),
    CollationTestCase("UNICODE_CI_AI", "a", "å", true),
    CollationTestCase("UNICODE_CI_AI", "a", "Å", true)
  )
  for (equalityCase <- equalityCases) {
    val collation = fetchCollation(equalityCase.collationName)
    val areEqual = collation.equalsFunction(toUTF8(equalityCase.s1), toUTF8(equalityCase.s2))
    assert(areEqual == equalityCase.expectedResult)
  }

  // Ordering checks: expectedResult is the sign of comparing s1 with s2 under the collation.
  val orderingCases = Seq(
    CollationTestCase("en", "a", "A", -1),
    CollationTestCase("en_CI", "a", "A", 0),
    CollationTestCase("en_AI", "a", "å", 0),
    CollationTestCase("sv", "Kypper", "Köpfe", -1),
    CollationTestCase("de", "Kypper", "Köpfe", 1)
  )
  for (orderingCase <- orderingCases) {
    val collation = fetchCollation(orderingCase.collationName)
    val comparison =
      collation.comparator.compare(toUTF8(orderingCase.s1), toUTF8(orderingCase.s2))
    // Only the sign is checked, not the magnitude of the comparator result.
    assert(Integer.signum(comparison) == orderingCase.expectedResult)
  }
}
}
4 changes: 2 additions & 2 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@
},
"COLLATION_INVALID_NAME" : {
"message" : [
"The value <collationName> does not represent a correct collation name. Suggested valid collation name: [<proposal>]."
"The value <collationName> does not represent a correct collation name."
],
"sqlState" : "42704"
},
Expand Down Expand Up @@ -1921,7 +1921,7 @@
"subClass" : {
"DEFAULT_COLLATION" : {
"message" : [
"Cannot resolve the given default collation. Did you mean '<proposal>'?"
"Cannot resolve the given default collation."
]
},
"TIME_ZONE" : {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.apache.spark.sql.{functions => fn}
import org.apache.spark.sql.avro.{functions => avroFn}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.connect.client.SparkConnectClient
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
Expand Down Expand Up @@ -699,7 +700,8 @@ class PlanGenerationTestSuite
}

// Builds a single-column schema whose string column "s" carries an explicit collation id,
// creates a local relation from the schema's catalog string, and selects that column.
// NOTE(review): `schema` is declared twice, which looks like the before/after lines of a diff
// rendered without +/- markers (literal id 1 vs. the named constant) — confirm against the
// real file before editing.
test("select collated string") {
val schema = StructType(StructField("s", StringType(1)) :: Nil)
val schema = StructType(
StructField("s", StringType(CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID)) :: Nil)
createLocalRelation(schema.catalogString).select("s")
}

Expand Down
Loading

0 comments on commit 7fe1b93

Please sign in to comment.