Skip to content

Commit

Permalink
Fix deserialization of missing fields in gRPC
Browse files Browse the repository at this point in the history
Mockingbird wrong handled message if sender didn't fill a field in proto3
syntax. The right behavior is using default value for missing fields, but
mockingbird returned error.
  • Loading branch information
ashashev authored and danslapman committed Feb 5, 2024
1 parent 299a5d6 commit bd75f54
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,21 @@ object GrpcExractor {
def buildMessageDefinition(gm: GrpcMessageSchema): MessageDefinition = {
val builder = MessageDefinition.newBuilder(gm.name)

gm.fields.foreach {
case f if f.label == GrpcLabel.Optional =>
val oneOfBuilder = builder.addOneof(s"_${f.name}")
oneOfBuilder.addField(f.typeName, f.name, f.order)
case f =>
builder.addField(f.label.entryName, f.typeName, f.name, f.order)
}

gm.oneofs.getOrElse(List.empty).foreach { oneof =>
val oneOfBuilder = builder.addOneof(oneof.name)
oneof.options.foreach { of =>
oneOfBuilder.addField(of.typeName, of.name, of.order)
}
}

gm.fields.foreach {
case f if f.isProto3Optional.getOrElse(false) =>
val oneOfBuilder = builder.addOneof(s"_${f.name}")
oneOfBuilder.addField(f.typeName, f.name, f.order)
case f =>
builder.addField(f.label.entryName, f.typeName, f.name, f.order)
}

gm.nested
.getOrElse(List.empty)
.foreach(
Expand All @@ -85,6 +85,8 @@ object GrpcExractor {
}
.build()

private val jsonPrinter = JsonFormat.printer().preservingProtoFieldNames().includingDefaultValueFields()

implicit class FromGrpcProtoDefinition(private val definition: GrpcProtoDefinition) extends AnyVal {
def toDynamicSchema: DynamicSchema = {
val registryBuilder: DynamicSchema.Builder = DynamicSchema.newBuilder()
Expand All @@ -111,7 +113,7 @@ object GrpcExractor {
def convertMessageToJson(bytes: Array[Byte], className: String): Task[Json] =
for {
message <- ZIO.attempt(parseFrom(bytes, className))
jsonString <- ZIO.attempt(JsonFormat.printer().preservingProtoFieldNames().print(message))
jsonString <- ZIO.attempt(jsonPrinter.print(message))
js <- ZIO.fromEither(parse(jsonString))
} yield js
}
Expand Down Expand Up @@ -140,8 +142,10 @@ object GrpcExractor {
)

private def message2messageSchema(message: DescriptorProtos.DescriptorProto): GrpcMessageSchema = {
val oneOfFields = message.getOneofDeclList.asScala.map(_.getName).toSet

val (fields, oneofs) = message.getFieldList.asScala.toList
.partition(f => !f.hasOneofIndex || isProto3OptionalField(f, message.getOneofDeclList.asScala.map(_.getName).toSet))
.partition(f => !f.hasOneofIndex || isProto3OptionalField(f, oneOfFields))

val nestedEnums = message.getEnumTypeList().asScala.toList
val nested = message.getNestedTypeList.asScala.toList
Expand All @@ -150,20 +154,12 @@ object GrpcExractor {
message.getName,
fields
.map { field =>
val label = GrpcLabel.withValue(field.getLabel.toString.split("_").last.toLowerCase).pipe { label =>
if (
label == GrpcLabel.Optional && (!isProto3OptionalField(
field,
message.getOneofDeclList.asScala.map(_.getName).toSet
))
) GrpcLabel.Required
else label
}
getGrpcField(field, label)
val label = GrpcLabel.withValue(field.getLabel.toString.split("_").last.toLowerCase)
getGrpcField(field, label, isProto3OptionalField(field, oneOfFields))
},
oneofs
.groupMap(_.getOneofIndex) { field =>
getGrpcField(field, GrpcLabel.Optional)
getGrpcField(field, GrpcLabel.Optional, false)
}
.map { case (index, fields) =>
GrpcOneOfSchema(
Expand Down Expand Up @@ -191,14 +187,19 @@ object GrpcExractor {
GrpcLabel.withValue(field.getLabel.toString.split("_").last.toLowerCase) == GrpcLabel.Optional &&
oneOfFields(s"_${field.getName}")

private def getGrpcField(field: DescriptorProtos.FieldDescriptorProto, label: GrpcLabel): GrpcField = {
private def getGrpcField(
field: DescriptorProtos.FieldDescriptorProto,
label: GrpcLabel,
isProto3Optional: Boolean
): GrpcField = {
val grpcType = getGrpcType(field)
GrpcField(
grpcType,
label,
getFieldType(field, grpcType == GrpcType.Custom),
field.getName,
field.getNumber
field.getNumber,
isProto3Optional.some,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ final case class GrpcField(
label: GrpcLabel,
typeName: String,
name: String,
order: Int
order: Int,
isProto3Optional: Option[Boolean],
)

@derive(
Expand Down
11 changes: 11 additions & 0 deletions backend/mockingbird/src/test/resources/not_optional_proto2.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
syntax = "proto2";

enum Bar {
BAR_ZERO = 0;
BAR_ONE = 1;
}

message Foo {
required string field1 = 1;
required Bar field2 = 3;
}
11 changes: 11 additions & 0 deletions backend/mockingbird/src/test/resources/not_optional_proto3.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
syntax = "proto3";

enum Bar {
BAR_ZERO = 0;
BAR_ONE = 1;
}

message Foo {
string field1 = 1;
Bar field2 = 3;
}
11 changes: 11 additions & 0 deletions backend/mockingbird/src/test/resources/optional_proto2.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
syntax = "proto2";

enum Bar {
BAR_ZERO = 0;
BAR_ONE = 1;
}

message Foo {
required string field1 = 1;
optional Bar field2 = 3;
}
11 changes: 11 additions & 0 deletions backend/mockingbird/src/test/resources/optional_proto3.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
syntax = "proto3";

enum Bar {
BAR_ZERO = 0;
BAR_ONE = 1;
}

message Foo {
string field1 = 1;
optional Bar field2 = 3;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package ru.tinkoff.tcb.protobuf

import com.github.os72.protobuf.dynamic.DynamicSchema
import com.google.protobuf.DynamicMessage
import com.google.protobuf.InvalidProtocolBufferException
import com.google.protobuf.util.JsonFormat
import io.circe.*
import zio.test.*
import zio.test.Assertion.*

import ru.tinkoff.tcb.mockingbird.grpc.GrpcExractor.FromDynamicSchema
import ru.tinkoff.tcb.mockingbird.grpc.GrpcExractor.FromGrpcProtoDefinition

object SerializationOptionalFieldsSpec extends ZIOSpecDefault {
val msgOptionalFieldAbsent = Array[Byte](0x0a, 0x04, 0x31, 0x71, 0x77, 0x65)
val msgOptionalFieldHasDefaultValue = Array[Byte](0x0a, 0x04, 0x31, 0x71, 0x77, 0x65, 0x18, 0x00)
val msgOptionalFieldHasAnotherValue = Array[Byte](0x0a, 0x04, 0x31, 0x71, 0x77, 0x65, 0x18, 0x01)
val typeName = "Foo"
val printer = JsonFormat.printer().includingDefaultValueFields().preservingProtoFieldNames().sortingMapKeys()

val optionalSyntax2 = "optional_proto2.proto"
val notOptionalSyntax2 = "not_optional_proto2.proto"
val optionalSyntax3 = "optional_proto3.proto"
val notOptionalSyntax3 = "not_optional_proto3.proto"

val field1Val = "1qwe"
val field2DefaultVal = "BAR_ZERO"
val field2AnotherVal = "BAR_ONE"

override def spec: Spec[TestEnvironment & Scope, Any] =
suite("Serialization of optional fields suite")(
test("An optional field in proto2 syntax: the field is absent") {
for {
msg <- parseWithProtoFromResource(optionalSyntax2, msgOptionalFieldAbsent)
obtain <- parseJson(printer.print(msg))
expected <- parseJson(s"""{
| "field1": "$field1Val",
| "field2": "$field2DefaultVal"
|}""".stripMargin)
} yield assertTrue(obtain == expected)
},
test("An optional field in proto2 syntax: the field has default value") {
for {
msg <- parseWithProtoFromResource(optionalSyntax2, msgOptionalFieldHasDefaultValue)
obtain <- parseJson(printer.print(msg))
expected <- parseJson(s"""{
| "field1": "$field1Val",
| "field2": "$field2DefaultVal"
|}""".stripMargin)
} yield assertTrue(obtain == expected)
},
test("An optional field in proto2 syntax: the field has another value") {
for {
msg <- parseWithProtoFromResource(optionalSyntax2, msgOptionalFieldHasAnotherValue)
obtain <- parseJson(printer.print(msg))
expected <- parseJson(s"""{
| "field1": "$field1Val",
| "field2": "$field2AnotherVal"
|}""".stripMargin)
} yield assertTrue(obtain == expected)
},
test("An required field in proto2 syntax: the field is absent") {
for {
result <- parseWithProtoFromResource(notOptionalSyntax2, msgOptionalFieldAbsent).exit
} yield assert(result)(failsWithA[InvalidProtocolBufferException])
},
test("An required field in proto2 syntax: the field has default value") {
for {
msg <- parseWithProtoFromResource(notOptionalSyntax2, msgOptionalFieldHasDefaultValue)
obtain <- parseJson(printer.print(msg))
expected <- parseJson(s"""{
| "field1": "$field1Val",
| "field2": "$field2DefaultVal"
|}""".stripMargin)
} yield assertTrue(obtain == expected)
},
test("An required field in proto2 syntax: the field has another value") {
for {
msg <- parseWithProtoFromResource(notOptionalSyntax2, msgOptionalFieldHasAnotherValue)
obtain <- parseJson(printer.print(msg))
expected <- parseJson(s"""{
| "field1": "$field1Val",
| "field2": "$field2AnotherVal"
|}""".stripMargin)
} yield assertTrue(obtain == expected)
},
test("An optional field in proto3 syntax: the field is absent") {
for {
msg <- parseWithProtoFromResource(optionalSyntax3, msgOptionalFieldAbsent)
obtain <- parseJson(printer.print(msg))
expected <- parseJson(s"""{
| "field1": "$field1Val"
|}""".stripMargin)
} yield assertTrue(obtain == expected)
},
test("An optional field in proto3 syntax: the field has default value") {
for {
msg <- parseWithProtoFromResource(optionalSyntax3, msgOptionalFieldHasDefaultValue)
obtain <- parseJson(printer.print(msg))
expected <- parseJson(s"""{
| "field1": "$field1Val",
| "field2": "$field2DefaultVal"
|}""".stripMargin)
} yield assertTrue(obtain == expected)
},
test("An optional field in proto3 syntax: the field has another value") {
for {
msg <- parseWithProtoFromResource(optionalSyntax3, msgOptionalFieldHasAnotherValue)
obtain <- parseJson(printer.print(msg))
expected <- parseJson(s"""{
| "field1": "$field1Val",
| "field2": "$field2AnotherVal"
|}""".stripMargin)
} yield assertTrue(obtain == expected)
},
test("An regular field in proto3 syntax: the field is absent") {
for {
msg <- parseWithProtoFromResource(notOptionalSyntax3, msgOptionalFieldAbsent)
obtain <- parseJson(printer.print(msg))
expected <- parseJson(s"""{
| "field1": "$field1Val",
| "field2": "$field2DefaultVal"
|}""".stripMargin)
} yield assertTrue(obtain == expected)
},
test("An regular field in proto3 syntax: the field has default value") {
for {
msg <- parseWithProtoFromResource(notOptionalSyntax3, msgOptionalFieldHasDefaultValue)
obtain <- parseJson(printer.print(msg))
expected <- parseJson(s"""{
| "field1": "$field1Val",
| "field2": "$field2DefaultVal"
|}""".stripMargin)
} yield assertTrue(obtain == expected)
},
test("An regular field in proto3 syntax: the field has another value") {
for {
msg <- parseWithProtoFromResource(notOptionalSyntax3, msgOptionalFieldHasAnotherValue)
obtain <- parseJson(printer.print(msg))
expected <- parseJson(s"""{
| "field1": "$field1Val",
| "field2": "$field2AnotherVal"
|}""".stripMargin)
} yield assertTrue(obtain == expected)
},
)

def parseWithProtoFromResource(protoName: String, rawData: Array[Byte]) =
for {
schema <- getSchemaFromResource(protoName)
desc = schema.getMessageDescriptor(typeName)
result <- ZIO.attempt(DynamicMessage.parseFrom(desc, rawData))
} yield result

def getSchemaFromResource(name: String) =
for {
bytes <- Utils.getProtoDescriptionFromResource(name)
// We are checking reconstructed schema
schema <- ZIO.attempt(DynamicSchema.parseFrom(bytes).toGrpcProtoDefinition.toDynamicSchema)
} yield schema

def parseJson(js: String): Task[Json] =
ZIO.fromEither(parser.parse(js))
}

0 comments on commit bd75f54

Please sign in to comment.