Skip to content

Commit

Permalink
Remove redundant schema visitors
Browse files Browse the repository at this point in the history
  • Loading branch information
dhpiggott committed Mar 11, 2024
1 parent 2695543 commit 3942bb2
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 168 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -308,48 +308,6 @@ class DocumentDecoderSchemaVisitor(
}
}

trait UnknownFieldsDecoder[A] { self =>
def apply(
history: List[PayloadPath.Segment],
unknownFields: Map[String, Document]
): A
}

object UnknownFieldsDecoder
extends SchemaVisitor.Default[UnknownFieldsDecoder] { self =>

override def default[A]: UnknownFieldsDecoder[A] = (history, _) =>
throw PayloadError(
PayloadPath(history.reverse),
"Json document",
"Expected Json Shape: Object"
)

override def primitive[P](
shapeId: ShapeId,
hints: Hints,
tag: Primitive[P]
): UnknownFieldsDecoder[P] = (history, unknownFields) =>
tag match {
case PDocument => Document.DObject(unknownFields)
case _ =>
throw PayloadError(
PayloadPath(history.reverse),
"Json document",
"Expected Json Shape: Object"
)
}

override def option[A](
schema: Schema[A]
): UnknownFieldsDecoder[Option[A]] = {
val decoder = schema.compile(self)
(history, unknownFields) =>
if (unknownFields.isEmpty) None
else Some(decoder(history, unknownFields))
}
}

override def struct[S](
shapeId: ShapeId,
hints: Hints,
Expand All @@ -368,14 +326,25 @@ class DocumentDecoderSchemaVisitor(
Map[String, Document]
) => Unit =
if (isForUnknownFieldRetention(field)) {
val unknownFieldsDecoder = UnknownFieldsDecoder(field.schema)
// TODO: Lift out.
val unknownFieldsDecoder = Document.Decoder.fromSchema(field.schema)
(
pp: List[PayloadPath.Segment],
buffer: Any => Unit,
fields: Map[String, Document]
) => {
val unknownFields = fields -- knownFieldLabels
buffer(unknownFieldsDecoder(pp, unknownFields))
buffer(
unknownFieldsDecoder
.decode(Document.DObject(unknownFields))
.getOrElse(
throw new PayloadError(
PayloadPath(pp.reverse),
"Json document",
"Expected Json Shape: Object"
)
)
)
}
} else {
val jsonLabel = jsonLabelOrLabel(field)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,42 +182,6 @@ class DocumentEncoderSchemaVisitor(
from(e => DString(total(e).stringValue))
}

trait UnknownFieldsEncoder[A] {
def apply(a: A): Map[String, Document]
}

object UnknownFieldsEncoder
extends SchemaVisitor.Default[UnknownFieldsEncoder] { self =>

override def default[A]: UnknownFieldsEncoder[A] = _ => Map.empty

override def primitive[P](
shapeId: ShapeId,
hints: Hints,
tag: Primitive[P]
): UnknownFieldsEncoder[P] = document =>
tag match {
case PDocument =>
document match {
case Document.DObject(values) => values
case _ => Map.empty
}

case _ =>
Map.empty
}

override def option[A](
schema: Schema[A]
): UnknownFieldsEncoder[Option[A]] = {
val encoder = self(schema)
locally {
case Some(a) => encoder.apply(a)
case None => Map.empty
}
}
}

override def struct[S](
shapeId: ShapeId,
hints: Hints,
Expand All @@ -237,8 +201,13 @@ class DocumentEncoderSchemaVisitor(
field.hints
.has(UnknownDocumentFieldRetention)
) {
val unknownFieldsEncoder = UnknownFieldsEncoder(field.schema)
(s, builder) => builder ++= unknownFieldsEncoder(field.get(s))
// TODO: Lift out.
val unknownFieldsEncoder = Document.Encoder.fromSchema(field.schema)
(s, builder) =>
unknownFieldsEncoder.encode(field.get(s)) match {
case Document.DObject(values) => builder ++= values
case _ => ()
}
} else {
val encoder = apply(field.schema)
val jsonLabel = field.hints
Expand Down
123 changes: 37 additions & 86 deletions modules/json/src/smithy4s/json/internals/SchemaVisitorJCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ private[smithy4s] class SchemaVisitorJCodec(
}

private def maxArityError(cursor: Cursor): Nothing =
throw cursor.payloadError(
cursor.payloadError(
this,
s"Input $expecting exceeded max arity of $maxArity"
)
Expand Down Expand Up @@ -578,7 +578,7 @@ private[smithy4s] class SchemaVisitorJCodec(
out.encodeError("Cannot use vectors as keys")

private[this] def maxArityError(cursor: Cursor): Nothing =
throw cursor.payloadError(
cursor.payloadError(
this,
s"Input $expecting exceeded max arity of $maxArity"
)
Expand Down Expand Up @@ -626,7 +626,7 @@ private[smithy4s] class SchemaVisitorJCodec(
out.encodeError("Cannot use vectors as keys")

private[this] def maxArityError(cursor: Cursor): Nothing =
throw cursor.payloadError(
cursor.payloadError(
this,
s"Input $expecting exceeded max arity of $maxArity"
)
Expand Down Expand Up @@ -687,7 +687,7 @@ private[smithy4s] class SchemaVisitorJCodec(
out.encodeError("Cannot use vectors as keys")

private[this] def maxArityError(cursor: Cursor): Nothing =
throw cursor.payloadError(
cursor.payloadError(
this,
s"Input $expecting exceeded max arity of $maxArity"
)
Expand Down Expand Up @@ -734,7 +734,7 @@ private[smithy4s] class SchemaVisitorJCodec(
out.encodeError("Cannot use vectors as keys")

private[this] def maxArityError(cursor: Cursor): Nothing =
throw cursor.payloadError(
cursor.payloadError(
this,
s"Input $expecting exceeded max arity of $maxArity"
)
Expand Down Expand Up @@ -793,7 +793,7 @@ private[smithy4s] class SchemaVisitorJCodec(
out.encodeError("Cannot use maps as keys")

private[this] def maxArityError(cursor: Cursor): Nothing =
throw cursor.payloadError(
cursor.payloadError(
this,
s"Input $expecting exceeded max arity of $maxArity"
)
Expand Down Expand Up @@ -864,7 +864,7 @@ private[smithy4s] class SchemaVisitorJCodec(
out.encodeError("Cannot use maps as keys")

private def maxArityError(cursor: Cursor): Nothing =
throw cursor.payloadError(
cursor.payloadError(
this,
s"Input $expecting exceeded max arity of $maxArity"
)
Expand Down Expand Up @@ -1285,39 +1285,6 @@ private[smithy4s] class SchemaVisitorJCodec(
case Some(x) => x.value
}

trait UnknownFieldsDecoder[A] { self =>
def apply(
in: JsonReader,
unknownFields: util.HashMap[String, Document]
): A
}

object UnknownFieldsDecoder
extends SchemaVisitor.Default[UnknownFieldsDecoder] { self =>

override def default[A]: UnknownFieldsDecoder[A] = (in, _) =>
in.decodeError("Expected JSON document")

override def primitive[P](
shapeId: ShapeId,
hints: Hints,
tag: Primitive[P]
): UnknownFieldsDecoder[P] = (in, unknownFields) =>
tag match {
case PDocument => Document.DObject(unknownFields.asScala.toMap)
case _ => in.decodeError("Expected JSON document")
}

override def option[A](
schema: Schema[A]
): UnknownFieldsDecoder[Option[A]] = {
val decoder = schema.compile(self)
(in, unknownFields) =>
if (unknownFields.isEmpty) None
else Some(decoder(in, unknownFields))
}
}

private type Handler = (Cursor, JsonReader, util.HashMap[String, Any]) => Unit

private def fieldHandler[Z, A](
Expand All @@ -1336,54 +1303,22 @@ private[smithy4s] class SchemaVisitorJCodec(
)
}

trait UnknownFieldsEncoder[A] {
def apply(a: A): JsonWriter => Unit
}

object UnknownFieldsEncoder
extends SchemaVisitor.Default[UnknownFieldsEncoder] { self =>

override def default[A]: UnknownFieldsEncoder[A] = _ => _ => ()

override def primitive[P](
shapeId: ShapeId,
hints: Hints,
tag: Primitive[P]
): UnknownFieldsEncoder[P] = document =>
tag match {
case PDocument =>
document match {
case Document.DObject(values) =>
jsonWriter =>
values.foreach { case (key, value) =>
jsonWriter.writeKey(key)
documentJCodec.encodeValue(value, jsonWriter)
}

case _ =>
_ => ()
}

case _ => _ => ()
}

override def option[A](
schema: Schema[A]
): UnknownFieldsEncoder[Option[A]] = {
val encoder = self(schema)
locally {
case Some(a) => encoder.apply(a)
case None => _ => ()
}
}
}

private def fieldEncoder[Z, A](
field: Field[Z, A]
): (Z, JsonWriter) => Unit =
if (isForUnknownFieldRetention(field)) {
val unknownFieldsEncoder = UnknownFieldsEncoder(field.schema)
(z: Z, out: JsonWriter) => unknownFieldsEncoder(field.get(z))(out)
val unknownFieldsEncoder = Document.Encoder.fromSchema(field.schema)
(z: Z, out: JsonWriter) =>
unknownFieldsEncoder.encode(field.get(z)) match {
case Document.DObject(values) =>
values.foreach { case (key, value) =>
out.writeKey(key)
documentJCodec.encodeValue(value, out)
}

case _ =>
()
}
} else {
val codec = apply(field.schema)
val jsonLabel = jsonLabelOrLabel(field)
Expand Down Expand Up @@ -1426,6 +1361,13 @@ private[smithy4s] class SchemaVisitorJCodec(
case (field, _, _) => isForUnknownFieldRetention(field)
}

private[this] val handlers =
new util.HashMap[String, Handler](fields.length << 1, 0.5f) {
fields.foreach { case (field, jsonLabel, _) =>
put(jsonLabel, fieldHandler(field))
}
}

private[this] val handlers =
new util.HashMap[String, Handler](fields.length << 1, 0.5f) {
fields.foreach { case (field, jsonLabel, _) =>
Expand Down Expand Up @@ -1468,8 +1410,17 @@ private[smithy4s] class SchemaVisitorJCodec(
fields.foreach { case (field, jsonLabel, default) =>
values += {
if (isForUnknownFieldRetention(field)) {
val unknownFieldsDecoder = UnknownFieldsDecoder(field.schema)
unknownFieldsDecoder(in, unknownFields)
// TODO: Lift out.
val unknownFieldsDecoder =
Document.Decoder.fromSchema(field.schema)
unknownFieldsDecoder
.decode(Document.DObject(unknownFields.asScala.toMap))
.getOrElse(
cursor.payloadError(
this,
"Expected JSON document"
)
)
} else {
val value = knownFields.get(field.label)
if (value == null) {
Expand Down

0 comments on commit 3942bb2

Please sign in to comment.