From 93e4d89e972cb24eff53ad4ea61cbc710b890741 Mon Sep 17 00:00:00 2001 From: Evgeniy Fedotov Date: Wed, 15 Nov 2023 17:36:29 +0300 Subject: [PATCH] fix: ResetVT oneof --- features/pool/pool.go | 12 +- testproto/pool/pool_with_oneof.pb.go | 44 ++++-- testproto/pool/pool_with_oneof.proto | 1 + testproto/pool/pool_with_oneof_vtproto.pb.go | 140 +++++++++++++++++++ 4 files changed, 178 insertions(+), 19 deletions(-) diff --git a/features/pool/pool.go b/features/pool/pool.go index 337dc61..64d4ba5 100644 --- a/features/pool/pool.go +++ b/features/pool/pool.go @@ -45,7 +45,7 @@ func (p *pool) message(message *protogen.Message) { p.P(`var vtprotoPool_`, ccTypeName, ` = `, p.Ident("sync", "Pool"), `{`) p.P(`New: func() interface{} {`) - p.P(`return &`, message.GoIdent, `{}`) + p.P(`return &`, ccTypeName, `{}`) p.P(`},`) p.P(`}`) @@ -68,10 +68,12 @@ func (p *pool) message(message *protogen.Message) { } p.P(fmt.Sprintf("f%d", len(saved)), ` := m.`, fieldName, `[:0]`) saved = append(saved, field) - } else if field.Oneof != nil && p.ShouldPool(message) && p.ShouldPool(field.Message) { - p.P(`if oneof, ok := m.`, field.Oneof.GoName, `.(*`, field.GoIdent, `); ok {`) - p.P(`oneof.`, fieldName, `.ReturnToVTPool()`) - p.P(`}`) + } else if field.Oneof != nil { + if p.ShouldPool(field.Message) { + p.P(`if oneof, ok := m.`, field.Oneof.GoName, `.(*`, field.GoIdent, `); ok {`) + p.P(`oneof.`, fieldName, `.ReturnToVTPool()`) + p.P(`}`) + } } else { switch field.Desc.Kind() { case protoreflect.MessageKind, protoreflect.GroupKind: diff --git a/testproto/pool/pool_with_oneof.pb.go b/testproto/pool/pool_with_oneof.pb.go index 7f88c9e..615be67 100644 --- a/testproto/pool/pool_with_oneof.pb.go +++ b/testproto/pool/pool_with_oneof.pb.go @@ -31,6 +31,7 @@ type OneofTest struct { // *OneofTest_Test1_ // *OneofTest_Test2_ // *OneofTest_Test3_ + // *OneofTest_Test4 Test isOneofTest_Test `protobuf_oneof:"test"` } @@ -94,6 +95,13 @@ func (x *OneofTest) GetTest3() *OneofTest_Test3 { return nil } +func (x *OneofTest) GetTest4() []byte { + if x, ok := x.GetTest().(*OneofTest_Test4); ok { + return x.Test4 + } + return nil +} + type isOneofTest_Test interface { isOneofTest_Test() } @@ -110,12 +118,18 @@ type OneofTest_Test3_ struct { Test3 *OneofTest_Test3 `protobuf:"bytes,3,opt,name=test3,proto3,oneof"` } +type OneofTest_Test4 struct { + Test4 []byte `protobuf:"bytes,4,opt,name=test4,proto3,oneof"` +} + func (*OneofTest_Test1_) isOneofTest_Test() {} func (*OneofTest_Test2_) isOneofTest_Test() {} func (*OneofTest_Test3_) isOneofTest_Test() {} +func (*OneofTest_Test4) isOneofTest_Test() {} + type OneofTest_Test1 struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -312,7 +326,7 @@ var file_pool_pool_with_oneof_proto_rawDesc = []byte{ 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x65, 0x74, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x76, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x65, 0x78, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x22, 0xa9, 0x02, 0x0a, 0x09, 0x4f, 0x6e, 0x65, 0x6f, 0x66, 0x54, 0x65, 0x73, 0x74, 0x12, + 0x6f, 0x22, 0xc1, 0x02, 0x0a, 0x09, 0x4f, 0x6e, 0x65, 0x6f, 0x66, 0x54, 0x65, 0x73, 0x74, 0x12, 0x28, 0x0a, 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x4f, 0x6e, 0x65, 0x6f, 0x66, 0x54, 0x65, 0x73, 0x74, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x31, 0x48, 0x00, 0x52, 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x12, 0x28, 0x0a, 0x05, 0x74, 0x65, 0x73, @@ -320,19 +334,20 @@ var file_pool_pool_with_oneof_proto_rawDesc = []byte{ 0x54, 0x65, 0x73, 0x74, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x32, 0x48, 0x00, 0x52, 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x12, 0x28, 0x0a, 0x05, 0x74, 0x65, 0x73, 0x74, 0x33, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x4f, 0x6e, 0x65, 0x6f, 0x66, 0x54, 0x65, 0x73, 0x74, 0x2e, 0x54, - 0x65, 0x73, 0x74, 0x33, 0x48, 0x00, 0x52, 0x05, 0x74, 0x65, 0x73, 0x74, 0x33, 0x1a, 0x1b, 0x0a, - 0x05, 0x54, 0x65, 0x73, 0x74, 0x31, 0x12, 0x0c, 0x0a, 0x01, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x01, 0x61, 0x3a, 0x04, 0xa8, 0xa6, 0x1f, 0x01, 0x1a, 0x1b, 0x0a, 0x05, 0x54, 0x65, - 0x73, 0x74, 0x32, 0x12, 0x0c, 0x0a, 0x01, 0x62, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x01, - 0x62, 0x3a, 0x04, 0xa8, 0xa6, 0x1f, 0x01, 0x1a, 0x56, 0x0a, 0x05, 0x54, 0x65, 0x73, 0x74, 0x33, - 0x12, 0x27, 0x0a, 0x01, 0x63, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x4f, 0x6e, - 0x65, 0x6f, 0x66, 0x54, 0x65, 0x73, 0x74, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x33, 0x2e, 0x45, 0x6c, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x32, 0x52, 0x01, 0x63, 0x1a, 0x1e, 0x0a, 0x08, 0x45, 0x6c, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x32, 0x12, 0x0c, 0x0a, 0x01, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x01, 0x64, 0x3a, 0x04, 0xa8, 0xa6, 0x1f, 0x01, 0x3a, 0x04, 0xa8, 0xa6, 0x1f, 0x01, 0x3a, - 0x04, 0xa8, 0xa6, 0x1f, 0x01, 0x42, 0x06, 0x0a, 0x04, 0x74, 0x65, 0x73, 0x74, 0x42, 0x10, 0x5a, - 0x0e, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6f, 0x6f, 0x6c, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x73, 0x74, 0x33, 0x48, 0x00, 0x52, 0x05, 0x74, 0x65, 0x73, 0x74, 0x33, 0x12, 0x16, 0x0a, + 0x05, 0x74, 0x65, 0x73, 0x74, 0x34, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x05, + 0x74, 0x65, 0x73, 0x74, 0x34, 0x1a, 0x1b, 0x0a, 0x05, 0x54, 0x65, 0x73, 0x74, 0x31, 0x12, 0x0c, + 0x0a, 0x01, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x01, 0x61, 0x3a, 0x04, 0xa8, 0xa6, + 0x1f, 0x01, 0x1a, 0x1b, 0x0a, 0x05, 0x54, 0x65, 0x73, 0x74, 0x32, 0x12, 0x0c, 0x0a, 0x01, 0x62, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x01, 0x62, 0x3a, 0x04, 0xa8, 0xa6, 0x1f, 0x01, 0x1a, + 0x56, 0x0a, 0x05, 0x54, 0x65, 0x73, 0x74, 0x33, 0x12, 0x27, 0x0a, 0x01, 0x63, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x4f, 0x6e, 0x65, 0x6f, 0x66, 0x54, 0x65, 0x73, 0x74, 0x2e, + 0x54, 0x65, 0x73, 0x74, 0x33, 0x2e, 0x45, 0x6c, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x32, 0x52, 0x01, + 0x63, 0x1a, 0x1e, 0x0a, 0x08, 0x45, 0x6c, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x32, 0x12, 0x0c, 0x0a, + 0x01, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x01, 0x64, 0x3a, 0x04, 0xa8, 0xa6, 0x1f, + 0x01, 0x3a, 0x04, 0xa8, 0xa6, 0x1f, 0x01, 0x3a, 0x04, 0xa8, 0xa6, 0x1f, 0x01, 0x42, 0x06, 0x0a, + 0x04, 0x74, 0x65, 0x73, 0x74, 0x42, 0x10, 0x5a, 0x0e, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x2f, 0x70, 0x6f, 0x6f, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -438,6 +453,7 @@ func file_pool_pool_with_oneof_proto_init() { (*OneofTest_Test1_)(nil), (*OneofTest_Test2_)(nil), (*OneofTest_Test3_)(nil), + (*OneofTest_Test4)(nil), } type x struct{} out := protoimpl.TypeBuilder{ diff --git a/testproto/pool/pool_with_oneof.proto b/testproto/pool/pool_with_oneof.proto index 0fe83ae..178fc28 100644 --- a/testproto/pool/pool_with_oneof.proto +++ b/testproto/pool/pool_with_oneof.proto @@ -31,5 +31,6 @@ message OneofTest { Test1 test1 = 1; Test2 test2 = 2; Test3 test3 = 3; + bytes test4 = 4; } } diff --git a/testproto/pool/pool_with_oneof_vtproto.pb.go b/testproto/pool/pool_with_oneof_vtproto.pb.go index dc5a1db..049dc2a 100644 --- a/testproto/pool/pool_with_oneof_vtproto.pb.go +++ b/testproto/pool/pool_with_oneof_vtproto.pb.go @@ -138,6 +138,19 @@ func (m *OneofTest_Test3_) CloneVT() isOneofTest_Test { return r } +func (m *OneofTest_Test4) CloneVT() isOneofTest_Test { + if m == nil { + return (*OneofTest_Test4)(nil) + } + r := new(OneofTest_Test4) + if rhs := m.Test4; rhs != nil { + tmpBytes := make([]byte, len(rhs)) + copy(tmpBytes, rhs) + r.Test4 = tmpBytes + } + return r +} + func (this *OneofTest_Test1) EqualVT(that *OneofTest_Test1) bool { if this == that { return true @@ -321,6 +334,23 @@ func (this *OneofTest_Test3_) EqualVT(thatIface isOneofTest_Test) bool { return true } +func (this *OneofTest_Test4) EqualVT(thatIface isOneofTest_Test) bool { + that, ok := thatIface.(*OneofTest_Test4) + if !ok { + return false + } + if this == that { + return true + } + if this == nil && that != nil || this != nil && that == nil { + return false + } + if string(this.Test4) != string(that.Test4) { + return false + } + return true +} + func (m *OneofTest_Test1) MarshalVT() (dAtA []byte, err error) { if m == nil { return nil, nil @@ -581,6 +611,20 @@ func (m *OneofTest_Test3_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { } return len(dAtA) - i, nil } +func (m *OneofTest_Test4) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *OneofTest_Test4) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + i := len(dAtA) + i -= len(m.Test4) + copy(dAtA[i:], m.Test4) + i = encodeVarint(dAtA, i, uint64(len(m.Test4))) + i-- + dAtA[i] = 0x22 + return len(dAtA) - i, nil +} func (m *OneofTest_Test1) MarshalVTStrict() (dAtA []byte, err error) { if m == nil { return nil, nil @@ -772,6 +816,13 @@ func (m *OneofTest) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error) { i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } + if msg, ok := m.Test.(*OneofTest_Test4); ok { + size, err := msg.MarshalToSizedBufferVTStrict(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + } if msg, ok := m.Test.(*OneofTest_Test3_); ok { size, err := msg.MarshalToSizedBufferVTStrict(dAtA[:i]) if err != nil { @@ -853,6 +904,20 @@ func (m *OneofTest_Test3_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error } return len(dAtA) - i, nil } +func (m *OneofTest_Test4) MarshalToVTStrict(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVTStrict(dAtA[:size]) +} + +func (m *OneofTest_Test4) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error) { + i := len(dAtA) + i -= len(m.Test4) + copy(dAtA[i:], m.Test4) + i = encodeVarint(dAtA, i, uint64(len(m.Test4))) + i-- + dAtA[i] = 0x22 + return len(dAtA) - i, nil +} var vtprotoPool_OneofTest_Test1 = sync.Pool{ New: func() interface{} { @@ -1075,6 +1140,16 @@ func (m *OneofTest_Test3_) SizeVT() (n int) { } return n } +func (m *OneofTest_Test4) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Test4) + n += 1 + l + sov(uint64(l)) + return n +} func (m *OneofTest_Test1) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 @@ -1537,6 +1612,39 @@ func (m *OneofTest) UnmarshalVT(dAtA []byte) error { m.Test = &OneofTest_Test3_{Test3: v} } iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Test4", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + v := make([]byte, postIndex-iNdEx) + copy(v, dAtA[iNdEx:postIndex]) + m.Test = &OneofTest_Test4{Test4: v} + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skip(dAtA[iNdEx:]) @@ -2021,6 +2129,38 @@ func (m *OneofTest) UnmarshalVTUnsafe(dAtA []byte) error { m.Test = &OneofTest_Test3_{Test3: v} } iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Test4", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + v := dAtA[iNdEx:postIndex] + m.Test = &OneofTest_Test4{Test4: v} + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skip(dAtA[iNdEx:])