Skip to content

Commit

Permalink
optimise TypeMatcher
Browse files: browse the repository at this point in the history
  • Loading branch information
mkljakubowski committed Feb 13, 2025
1 parent 3581d79 commit 376d672
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 140 deletions.
21 changes: 4 additions & 17 deletions avrohugger-core/src/main/scala/input/NestedSchemaExtractor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,10 @@ object NestedSchemaExtractor {
fieldSchema.getType match {
case ARRAY => flattenSchema(fieldSchema.getElementType)
case MAP => flattenSchema(fieldSchema.getValueType)
case RECORD => {
// if the field schema is one that has already been stored, use that one
if (schemaStore.schemas.contains(fieldSchema.getFullName)) List()
// if we've already seen this schema (recursive schemas) don't traverse further
else extract(fieldSchema):+fieldSchema

}
case RECORD => extract(fieldSchema) :+ fieldSchema
case UNION => fieldSchema.getTypes().asScala.toList.flatMap(x => flattenSchema(x))
case ENUM => {
if (schemaStore.schemas.contains(fieldSchema.getFullName)) List()
else List(fieldSchema)
}
case FIXED => {
// if the field schema is one that has already been stored, use that one
if (schemaStore.schemas.contains(fieldSchema.getFullName)) List()
else List(fieldSchema)
}
case ENUM => List(fieldSchema)
case FIXED => List(fieldSchema)
case _ => List(fieldSchema)
}
}
Expand All @@ -70,7 +57,7 @@ object NestedSchemaExtractor {
}
}
// most-nested schemas should be compiled first
extract(schema):+schema
extract(schema) :+ schema
}
}

184 changes: 65 additions & 119 deletions avrohugger-core/src/main/scala/matchers/TypeMatcher.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package avrohugger
package matchers

import avrohugger.matchers.custom.{CustomNamespaceMatcher, CustomTypeMatcher}
import avrohugger.matchers.custom.{ CustomNamespaceMatcher, CustomTypeMatcher }
import avrohugger.stores.ClassStore
import avrohugger.types._
import treehugger.forest._
Expand All @@ -24,71 +24,74 @@ class TypeMatcher(
schema: Schema,
useFullName: Boolean = false
): Type = {
var typeMap = Map[String, Type]()

// May contain nested schemas that will use the same namespace as the
// top-level schema. Thus, when a field is parsed, the namespace is passed.
def matchType(schema: Schema): Type = {

schema.getType match {
case Schema.Type.ARRAY => {
val avroElement = schema.getElementType
val scalaElementType = toScalaType(classStore, namespace, avroElement)
val collectionType = CustomTypeMatcher.checkCustomArrayType(avroScalaTypes.array)
collectionType(scalaElementType)
}
case Schema.Type.MAP => {
val keyType = StringClass
val avroValueType = schema.getValueType
val scalaValueType = toScalaType(classStore, namespace, avroValueType)
TYPE_MAP(keyType, scalaValueType)
}
case Schema.Type.BOOLEAN => BooleanClass
case Schema.Type.DOUBLE => CustomTypeMatcher.checkCustomNumberType(avroScalaTypes.double)
case Schema.Type.FLOAT => CustomTypeMatcher.checkCustomNumberType(avroScalaTypes.float)
case Schema.Type.LONG =>
LogicalType.foldLogicalTypes(
schema = schema,
default = CustomTypeMatcher.checkCustomNumberType(avroScalaTypes.long)) {
case TimestampMillis => CustomTypeMatcher.checkCustomTimestampMillisType(avroScalaTypes.timestampMillis)
case TimestampMicros => CustomTypeMatcher.checkCustomTimestampMicrosType(avroScalaTypes.timestampMicros)
case LocalTimestampMicros => CustomTypeMatcher.checkCustomLocalTimestampMicrosType(avroScalaTypes.localTimestampMicros)
case LocalTimestampMillis => CustomTypeMatcher.checkCustomLocalTimestampMillisType(avroScalaTypes.localTimestampMillis)
case TimeMicros => CustomTypeMatcher.checkCustomTimeMicrosType(avroScalaTypes.timeMicros)
}
case Schema.Type.INT =>
LogicalType.foldLogicalTypes(
schema = schema,
default = CustomTypeMatcher.checkCustomNumberType(avroScalaTypes.int)) {
case Date => CustomTypeMatcher.checkCustomDateType(avroScalaTypes.date)
case TimeMillis => CustomTypeMatcher.checkCustomTimeMillisType(avroScalaTypes.timeMillis)
}
case Schema.Type.NULL => NullClass
case Schema.Type.STRING =>
LogicalType.foldLogicalTypes(
schema = schema,
default = StringClass) {
case UUID => RootClass.newClass(nme.createNameType("java.util.UUID"))
}
case Schema.Type.FIXED =>
RootClass.newClass(s"${schema.getNamespace()}.${classStore.generatedClasses(schema)}")
case Schema.Type.BYTES => CustomTypeMatcher.checkCustomDecimalType(avroScalaTypes.decimal, schema)
case Schema.Type.RECORD =>
{
val maybeNamespace = CustomNamespaceMatcher.checkCustomNamespace(
Option(schema.getNamespace()),
this,
maybeDefaultNamespace = Option(schema.getNamespace())
)
maybeNamespace match {
if (typeMap.contains(schema.getFullName)) {
typeMap(schema.getFullName)
}
else {
val tp: Type = schema.getType match {
case Schema.Type.ARRAY =>
val avroElement = schema.getElementType
val scalaElementType = toScalaType(classStore, namespace, avroElement)
val collectionType = CustomTypeMatcher.checkCustomArrayType(avroScalaTypes.array)
collectionType(scalaElementType)
case Schema.Type.MAP =>
val keyType = StringClass
val avroValueType = schema.getValueType
val scalaValueType = toScalaType(classStore, namespace, avroValueType)
TYPE_MAP(keyType, scalaValueType)
case Schema.Type.BOOLEAN => BooleanClass
case Schema.Type.DOUBLE => CustomTypeMatcher.checkCustomNumberType(avroScalaTypes.double)
case Schema.Type.FLOAT => CustomTypeMatcher.checkCustomNumberType(avroScalaTypes.float)
case Schema.Type.LONG =>
LogicalType.foldLogicalTypes(
schema = schema,
default = CustomTypeMatcher.checkCustomNumberType(avroScalaTypes.long)) {
case TimestampMillis => CustomTypeMatcher.checkCustomTimestampMillisType(avroScalaTypes.timestampMillis)
case TimestampMicros => CustomTypeMatcher.checkCustomTimestampMicrosType(avroScalaTypes.timestampMicros)
case LocalTimestampMicros => CustomTypeMatcher.checkCustomLocalTimestampMicrosType(avroScalaTypes.localTimestampMicros)
case LocalTimestampMillis => CustomTypeMatcher.checkCustomLocalTimestampMillisType(avroScalaTypes.localTimestampMillis)
case TimeMicros => CustomTypeMatcher.checkCustomTimeMicrosType(avroScalaTypes.timeMicros)
}
case Schema.Type.INT =>
LogicalType.foldLogicalTypes(
schema = schema,
default = CustomTypeMatcher.checkCustomNumberType(avroScalaTypes.int)) {
case Date => CustomTypeMatcher.checkCustomDateType(avroScalaTypes.date)
case TimeMillis => CustomTypeMatcher.checkCustomTimeMillisType(avroScalaTypes.timeMillis)
}
case Schema.Type.NULL => NullClass
case Schema.Type.STRING =>
LogicalType.foldLogicalTypes(
schema = schema,
default = StringClass) {
case UUID => RootClass.newClass(nme.createNameType("java.util.UUID"))
}
case Schema.Type.FIXED =>
RootClass.newClass(s"${schema.getNamespace()}.${classStore.generatedClasses(schema)}")
case Schema.Type.BYTES => CustomTypeMatcher.checkCustomDecimalType(avroScalaTypes.decimal, schema)
case Schema.Type.RECORD =>
val maybeNamespace = CustomNamespaceMatcher.checkCustomNamespace(
Option(schema.getNamespace()),
this,
maybeDefaultNamespace = Option(schema.getNamespace())
)
maybeNamespace match {
case Some(ns) => s"${ns}.${schema.getName()}"
case None => schema.getName()
}
}
case Schema.Type.ENUM => CustomTypeMatcher.checkCustomEnumType(avroScalaTypes.`enum`, classStore, schema, useFullName)
case Schema.Type.UNION => {
//unions are represented as shapeless.Coproduct
val unionSchemas = schema.getTypes().asScala.toList
unionTypeImpl(unionSchemas, matchType)
}
case Schema.Type.ENUM => CustomTypeMatcher.checkCustomEnumType(avroScalaTypes.`enum`, classStore, schema, useFullName)
case Schema.Type.UNION =>
//unions are represented as shapeless.Coproduct
val unionSchemas = schema.getTypes().asScala.toList
unionTypeImpl(unionSchemas, matchType)
}
typeMap += (schema.getFullName -> tp)
tp
}
}

Expand All @@ -115,7 +118,7 @@ class TypeMatcher(
* value must match the first element of the union. Thus, for unions containing "null", the "null" is usually listed
* first, since the default value of such unions is typically null.)
*/
private[this] def unionTypeImpl(unionSchemas: List[Schema], typeMatcher: (Schema) => Type) : Type = {
private[this] def unionTypeImpl(unionSchemas: List[Schema], typeMatcher: (Schema) => Type): Type = {

def shapelessCoproductType(tp: Type*): forest.Type = {
val copTypes = tp.toList :+ typeRef(RootClass.newClass(newTypeName("CNil")))
Expand Down Expand Up @@ -172,61 +175,4 @@ class TypeMatcher(
if (includesNull) optionType(matchedType) else matchedType
}


// Scavro requires Java types be generated for mapping Java classes to Scala

// Type reference to "CharSequence"; toJavaType uses it for Avro STRING schemas and for MAP keys.
val avroStringType = TYPE_REF("CharSequence")

/**
 * Maps an Avro schema to the Java type reference used for it.
 *
 * Primitive Avro types map to their `java.lang` wrappers, STRING to
 * `avroStringType`, BYTES to `java.nio.ByteBuffer`, ARRAY and MAP to
 * `java.util.List` / `java.util.Map`, and RECORD / ENUM to the generated
 * class name prefixed with "J".
 *
 * @param classStore store consulted (via `generatedClasses`) for the name of
 *                   RECORD and ENUM schemas
 * @param namespace  only forwarded to the recursive calls made for ARRAY
 *                   elements and MAP values
 * @param schema     the Avro schema to translate
 * @return the Java type for `schema`
 * @throws java.lang.RuntimeException (raised through `sys.error`) for FIXED
 *         schemas and for any union that is not exactly two members with one
 *         of them NULL
 */
def toJavaType(
  classStore: ClassStore,
  namespace: Option[String],
  schema: Schema): Type = {

  // Generated class name with a "J" prefix, used for RECORD and ENUM.
  def prefixedName(named: Schema) = {
    "J" + classStore.generatedClasses(named)
  }

  // Nullable unions re-enter this function directly with their non-null
  // member; containers go back through toJavaType instead.
  def resolve(avroSchema: Schema): Type = avroSchema.getType match {
    // primitives
    case Schema.Type.BOOLEAN => TYPE_REF("java.lang.Boolean")
    case Schema.Type.INT     => TYPE_REF("java.lang.Integer")
    case Schema.Type.LONG    => TYPE_REF("java.lang.Long")
    case Schema.Type.FLOAT   => TYPE_REF("java.lang.Float")
    case Schema.Type.DOUBLE  => TYPE_REF("java.lang.Double")
    case Schema.Type.NULL    => TYPE_REF("java.lang.Void")
    case Schema.Type.STRING  => avroStringType
    case Schema.Type.BYTES   => TYPE_REF("java.nio.ByteBuffer")

    // named types
    case Schema.Type.RECORD  => TYPE_REF(prefixedName(avroSchema))
    case Schema.Type.ENUM    => TYPE_REF(prefixedName(avroSchema))
    case Schema.Type.FIXED   => sys.error("FIXED datatype not supported")

    // containers
    case Schema.Type.ARRAY =>
      val itemType = toJavaType(classStore, namespace, avroSchema.getElementType)
      TYPE_REF(REF("java.util.List") APPLYTYPE(itemType))
    case Schema.Type.MAP =>
      val valueType = toJavaType(classStore, namespace, avroSchema.getValueType)
      TYPE_REF(REF("java.util.Map") APPLYTYPE(avroStringType, valueType))

    // Only two-member unions that include NULL are accepted; the result is
    // the type of the first non-null member.
    case Schema.Type.UNION =>
      val members = avroSchema.getTypes().asScala.toList
      val (nullMembers, otherMembers) = members.partition(_.getType == Schema.Type.NULL)
      if (members.length != 2 || nullMembers.isEmpty)
        sys.error("unions not yet supported beyond nullable fields")
      else otherMembers.headOption match {
        case Some(member) => resolve(member)
        case None         => sys.error("no avro type found in this union")
      }
  }

  resolve(schema)
}
}
}
5 changes: 1 addition & 4 deletions avrohugger-core/src/main/scala/stores/SchemaStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@ import org.apache.avro.Schema
import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

// this isn't used
/** Holds Avro schemas keyed by their full name. */
class SchemaStore {

// Scala concurrent-map view over a java.util.concurrent.ConcurrentHashMap: full name -> schema.
val schemas: scala.collection.concurrent.Map[String, Schema] =
new ConcurrentHashMap[String, Schema]().asScala

// Stores `schema` under `schema.getFullName`; an existing entry with the same key is replaced.
def accept(schema: Schema) =
schemas += (schema.getFullName -> schema)


}

0 comments on commit 376d672

Please sign in to comment.