From 7b9414ba1ffaed23c85db9972b68f0d0c66245c9 Mon Sep 17 00:00:00 2001 From: tamayika Date: Wed, 7 Aug 2024 17:49:16 +0900 Subject: [PATCH] support comparison of message embed structs --- testing/protocmp/util_test.go | 51 +++++++++++++++++++++++++++++++++++ testing/protocmp/xform.go | 26 ++++++++++++++++-- 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/testing/protocmp/util_test.go b/testing/protocmp/util_test.go index fcc032633..426548dc6 100644 --- a/testing/protocmp/util_test.go +++ b/testing/protocmp/util_test.go @@ -30,6 +30,14 @@ func TestEqual(t *testing.T) { allTypesDesc := (*testpb.TestAllTypes)(nil).ProtoReflect().Descriptor() + type embedMessage struct { + *testpb.TestAllTypes + Extra int + } + type hasEmbedMessage struct { + Embed *embedMessage + } + // Test nil and empty messages of differing types. tests = append(tests, []test{{ x: (*testpb.TestAllTypes)(nil), @@ -146,6 +154,49 @@ func TestEqual(t *testing.T) { y: struct{ M proto.Message }{dynamicpb.NewMessage(allTypesDesc)}, opts: cmp.Options{Transform()}, want: true, + }, { + x: &hasEmbedMessage{}, + y: &hasEmbedMessage{}, + opts: cmp.Options{Transform()}, + want: true, + }, { + x: &hasEmbedMessage{ + Embed: &embedMessage{ + TestAllTypes: &testpb.TestAllTypes{ + DefaultInt64: proto.Int64(10), + }, + Extra: 10, + }, + }, + y: &hasEmbedMessage{ + Embed: &embedMessage{ + TestAllTypes: &testpb.TestAllTypes{ + DefaultInt64: proto.Int64(20), + }, + Extra: 10, + }, + }, + opts: cmp.Options{Transform()}, + want: false, + }, { + x: &hasEmbedMessage{ + Embed: &embedMessage{ + TestAllTypes: &testpb.TestAllTypes{ + DefaultInt64: proto.Int64(10), + }, + Extra: 10, + }, + }, + y: &hasEmbedMessage{ + Embed: &embedMessage{ + TestAllTypes: &testpb.TestAllTypes{ + DefaultInt64: proto.Int64(10), + }, + Extra: 20, + }, + }, + opts: cmp.Options{Transform()}, + want: false, }}...) // Test message values. diff --git a/testing/protocmp/xform.go b/testing/protocmp/xform.go index de29d9732..dc2ceceb8 100644 --- a/testing/protocmp/xform.go +++ b/testing/protocmp/xform.go @@ -212,7 +212,8 @@ func Transform(opts ...option) cmp.Option { return cmp.FilterPath(func(p cmp.Path) bool { ps := p.Last() if isMessageType(addrType(ps.Type())) { - return true + // Check if message is embed to support message embed struct + return !isMessageEmbed(ps.Type()) } // Check whether the concrete values of an interface both satisfy @@ -222,7 +223,9 @@ func Transform(opts ...option) cmp.Option { if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() { return false } - return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type())) + return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type())) && + // Check if message is embed to support message embed struct + !isMessageEmbed(addrType(vx.Elem().Type())) && !isMessageEmbed(addrType(vy.Elem().Type())) } return false @@ -247,6 +250,25 @@ func Transform(opts ...option) cmp.Option { })) } +func isMessageEmbed(t reflect.Type) bool { + if t.Kind() == reflect.Pointer { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return false + } + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.Anonymous { + continue + } + if isMessageType(t.Field(i).Type) { + return true + } + } + return false +} + func isMessageType(t reflect.Type) bool { // Avoid transforming the Message itself. if t == reflect.TypeOf(Message(nil)) || t == reflect.TypeOf((*Message)(nil)) {