Skip to content

Commit

Permalink
scalapb-json#160 - accept JSON and original file names in parser
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobejs committed Oct 24, 2021
1 parent c2c238c commit 8869356
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 56 deletions.
87 changes: 44 additions & 43 deletions core/shared/src/main/scala/scalapb_circe/JsonFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ import scala.util.control.NonFatal
case class Formatter[T](writer: (Printer, T) => Json, parser: (Parser, Json) => T)

case class FormatRegistry(
messageFormatters: Map[Class[_], Formatter[_]] = Map.empty,
enumFormatters: Map[EnumDescriptor, Formatter[EnumValueDescriptor]] = Map.empty,
registeredCompanions: Seq[GenericCompanion] = Seq.empty
) {
messageFormatters: Map[Class[_], Formatter[_]] = Map.empty,
enumFormatters: Map[EnumDescriptor, Formatter[EnumValueDescriptor]] = Map.empty,
registeredCompanions: Seq[GenericCompanion] = Seq.empty
) {

def registerMessageFormatter[T <: GeneratedMessage](writer: (Printer, T) => Json, parser: (Parser, Json) => T)(
implicit ct: ClassTag[T]
Expand All @@ -31,13 +31,13 @@ case class FormatRegistry(
}

def registerEnumFormatter[E <: GeneratedEnum](
writer: (Printer, EnumValueDescriptor) => Json,
parser: (Parser, Json) => EnumValueDescriptor
)(implicit cmp: GeneratedEnumCompanion[E]): FormatRegistry = {
writer: (Printer, EnumValueDescriptor) => Json,
parser: (Parser, Json) => EnumValueDescriptor
)(implicit cmp: GeneratedEnumCompanion[E]): FormatRegistry = {
copy(enumFormatters = enumFormatters + (cmp.scalaDescriptor -> Formatter(writer, parser)))
}

def registerWriter[T <: GeneratedMessage: ClassTag](writer: T => Json, parser: Json => T): FormatRegistry = {
def registerWriter[T <: GeneratedMessage : ClassTag](writer: T => Json, parser: Json => T): FormatRegistry = {
registerMessageFormatter((p: Printer, t: T) => writer(t), (p: Parser, v: Json) => parser(v))
}

Expand All @@ -59,13 +59,13 @@ case class FormatRegistry(
}

class Printer(
includingDefaultValueFields: Boolean = false,
preservingProtoFieldNames: Boolean = false,
val formattingLongAsNumber: Boolean = false,
formattingEnumsAsNumber: Boolean = false,
formatRegistry: FormatRegistry = JsonFormat.DefaultRegistry,
val typeRegistry: TypeRegistry = TypeRegistry.empty
) {
includingDefaultValueFields: Boolean = false,
preservingProtoFieldNames: Boolean = false,
val formattingLongAsNumber: Boolean = false,
formattingEnumsAsNumber: Boolean = false,
formatRegistry: FormatRegistry = JsonFormat.DefaultRegistry,
val typeRegistry: TypeRegistry = TypeRegistry.empty
) {
def print[A](m: GeneratedMessage): String = {
toJson(m).noSpaces
}
Expand All @@ -83,7 +83,7 @@ class Printer(
if (includingDefaultValueFields) {
b += ((name, if (fd.isMapField) Json.obj() else Json.arr()))
}
case xs: Iterable[GeneratedMessage] @unchecked =>
case xs: Iterable[GeneratedMessage]@unchecked =>
if (fd.isMapField) {
val mapEntryDescriptor = fd.scalaType.asInstanceOf[ScalaType.Message].descriptor
val keyDescriptor = mapEntryDescriptor.findFieldByNumber(1).get
Expand Down Expand Up @@ -136,10 +136,10 @@ class Printer(
case v =>
if (
includingDefaultValueFields ||
!fd.isOptional ||
!fd.file.isProto3 ||
(v != scalapb_json.ScalapbJsonCommon.defaultValue(fd)) ||
fd.containingOneof.isDefined
!fd.isOptional ||
!fd.file.isProto3 ||
(v != scalapb_json.ScalapbJsonCommon.defaultValue(fd)) ||
fd.containingOneof.isDefined
) {
b += JField(name, serializeSingleValue(fd, v, formattingLongAsNumber))
}
Expand Down Expand Up @@ -203,14 +203,14 @@ class Printer(
}

class Parser(
preservingProtoFieldNames: Boolean = false,
formatRegistry: FormatRegistry = JsonFormat.DefaultRegistry,
val typeRegistry: TypeRegistry = TypeRegistry.empty
) {
preservingProtoFieldNames: Boolean = false,
formatRegistry: FormatRegistry = JsonFormat.DefaultRegistry,
val typeRegistry: TypeRegistry = TypeRegistry.empty
) {

def fromJsonString[A <: GeneratedMessage](
str: String
)(implicit cmp: GeneratedMessageCompanion[A]): A = {
str: String
)(implicit cmp: GeneratedMessageCompanion[A]): A = {
fromJson(io.circe.parser.parse(str).fold(throw _, identity))
}

Expand Down Expand Up @@ -278,7 +278,7 @@ class Parser(

val valueMap: Map[FieldDescriptor, PValue] = (for {
fd <- cmp.scalaDescriptor.fields
jsValue <- values.get(serializedName(fd)) if !jsValue.isNull
jsValue <- values.get(ScalapbJsonCommon.jsonName(fd)).orElse(values.get(fd.asProto.getName)) if !jsValue.isNull
} yield (fd, parseValue(fd, jsValue))).toMap

PMessage(valueMap)
Expand Down Expand Up @@ -306,10 +306,10 @@ class Parser(
}

protected def parseSingleValue(
containerCompanion: GeneratedMessageCompanion[_],
fd: FieldDescriptor,
value: Json
): PValue =
containerCompanion: GeneratedMessageCompanion[_],
fd: FieldDescriptor,
value: Json
): PValue =
fd.scalaType match {
case ScalaType.Enum(ed) =>
PEnum(formatRegistry.getEnumParser(ed) match {
Expand All @@ -331,6 +331,7 @@ class Parser(
}

object JsonFormat {

import com.google.protobuf.wrappers
import scalapb_json.ScalapbJsonCommon._

Expand Down Expand Up @@ -405,8 +406,8 @@ object JsonFormat {
.registerMessageFormatter[com.google.protobuf.any.Any](AnyFormat.anyWriter, AnyFormat.anyParser)

def primitiveWrapperWriter[T <: GeneratedMessage](implicit
cmp: GeneratedMessageCompanion[T]
): ((Printer, T) => Json) = {
cmp: GeneratedMessageCompanion[T]
): ((Printer, T) => Json) = {
val fieldDesc = cmp.scalaDescriptor.findFieldByNumber(1).get
(printer, t) =>
printer.serializeSingleValue(
Expand All @@ -417,8 +418,8 @@ object JsonFormat {
}

def primitiveWrapperParser[T <: GeneratedMessage](implicit
cmp: GeneratedMessageCompanion[T]
): ((Parser, Json) => T) = {
cmp: GeneratedMessageCompanion[T]
): ((Parser, Json) => T) = {
val fieldDesc = cmp.scalaDescriptor.findFieldByNumber(1).get
(parser, jv) =>
cmp.messageReads.read(
Expand All @@ -442,15 +443,15 @@ object JsonFormat {

def toJson[A <: GeneratedMessage](m: A): Json = printer.toJson(m)

def fromJson[A <: GeneratedMessage: GeneratedMessageCompanion](value: Json): A = {
def fromJson[A <: GeneratedMessage : GeneratedMessageCompanion](value: Json): A = {
parser.fromJson(value)
}

def fromJsonString[A <: GeneratedMessage: GeneratedMessageCompanion](str: String): A = {
def fromJsonString[A <: GeneratedMessage : GeneratedMessageCompanion](str: String): A = {
parser.fromJsonString(str)
}

implicit def protoToDecoder[T <: GeneratedMessage: GeneratedMessageCompanion]: Decoder[T] =
implicit def protoToDecoder[T <: GeneratedMessage : GeneratedMessageCompanion]: Decoder[T] =
Decoder.instance { value =>
try {
Right(parser.fromJson(value.value))
Expand All @@ -464,11 +465,11 @@ object JsonFormat {
Encoder.instance(printer.toJson(_))

def parsePrimitive(
scalaType: ScalaType,
protoType: FieldDescriptorProto.Type,
value: Json,
onError: => PValue
): PValue = {
scalaType: ScalaType,
protoType: FieldDescriptorProto.Type,
value: Json,
onError: => PValue
): PValue = {
scalaType match {
case ScalaType.Int =>
value.fold(
Expand Down
13 changes: 0 additions & 13 deletions core/shared/src/test/scala/scalapb_circe/CodecSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,6 @@ class CodecSpec extends AnyFreeSpec with Matchers {
// Using asJson with an implicit printer includes the default value.
g.asJson mustBe Json.obj("numberOfStrings" -> Json.fromInt(0))
}

"decode using an implicit parser w/ non-standard settings" in {
implicit val parser: Parser = new Parser(preservingProtoFieldNames = true)

// Use the snake-case naming to define a Guitar Json object.
val j = Json.obj("number_of_strings" -> Json.fromInt(42))

// Using the regular JsonFormat parser decodes to the defaultInstance.
JsonFormat.fromJson[Guitar](j) mustBe Guitar.defaultInstance

// Using as[T] with an implicit parser decodes back to the original value (42).
j.as[Guitar] mustBe Right(Guitar(42))
}
}

"GeneratedEnum" - {
Expand Down
6 changes: 6 additions & 0 deletions core/shared/src/test/scala/scalapb_circe/JsonFormatSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,12 @@ class JsonFormatSpec extends AnyFlatSpec with Matchers with OptionValues {
new Parser().fromJsonString[MyTest]("""{"optEnum":2}""") must be(MyTest(optEnum = Some(MyEnum.V2)))
}


"TestProto" should "parse original field names" in {
new Parser().fromJsonString[MyTest]("""{"opt_enum":1}""") must be(MyTest(optEnum = Some(MyEnum.V1)))
new Parser().fromJsonString[MyTest]("""{"opt_enum":2}""") must be(MyTest(optEnum = Some(MyEnum.V2)))
}

"PreservedTestJson" should "be TestProto when parsed from json" in {
new Parser(preservingProtoFieldNames = true).fromJsonString[MyTest](PreservedTestJson) must be(TestProto)
}
Expand Down

0 comments on commit 8869356

Please sign in to comment.