From 2356a006cc38d7d77c865917fb2037350e05897e Mon Sep 17 00:00:00 2001 From: Oliver Muir Date: Wed, 15 Nov 2023 10:51:07 +0000 Subject: [PATCH] Fix JSONSchema unmarshalling in TableView --- pulsar/table_view.go | 6 +- pulsar/table_view_impl.go | 11 +- pulsar/table_view_test.go | 219 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 227 insertions(+), 9 deletions(-) diff --git a/pulsar/table_view.go b/pulsar/table_view.go index e566bf0bec..58a664aed7 100644 --- a/pulsar/table_view.go +++ b/pulsar/table_view.go @@ -65,12 +65,12 @@ type TableView interface { // Keys returns a slice of the keys contained in this TableView. Keys() []string - // ForEach performs the give action for each entry in this map until all entries have been processed or the action + // ForEach performs the given action for each entry in this map until all entries have been processed or the action // returns an error. ForEach(func(string, interface{}) error) error - // ForEachAndListen performs the give action for each entry in this map until all entries have been processed or - // the action returns an error. + // ForEachAndListen performs the given action for each entry in this map until all entries have been processed or + // the action returns an error. The given action will then be performed on each new entry in this map. ForEachAndListen(func(string, interface{}) error) error // Close closes the table view and releases resources allocated. diff --git a/pulsar/table_view_impl.go b/pulsar/table_view_impl.go index 47f8c6c01f..17e0b90f3b 100644 --- a/pulsar/table_view_impl.go +++ b/pulsar/table_view_impl.go @@ -245,19 +245,18 @@ func (tv *TableViewImpl) handleMessage(msg Message) { tv.dataMu.Lock() defer tv.dataMu.Unlock() - var payload interface{} + payload := reflect.New(tv.options.SchemaValueType) if len(msg.Payload()) == 0 { delete(tv.data, msg.Key()) } else { - payload = reflect.Indirect(reflect.New(tv.options.SchemaValueType)).Interface() - if err := msg.GetSchemaValue(&payload); err != nil { - tv.logger.Errorf("msg.GetSchemaValue() failed with %w; msg is %v", err, msg) + if err := msg.GetSchemaValue(payload.Interface()); err != nil { + tv.logger.Errorf("msg.GetSchemaValue() failed with %v; msg is %v", err, msg) } - tv.data[msg.Key()] = payload + tv.data[msg.Key()] = reflect.Indirect(payload).Interface() } for _, listener := range tv.listeners { - if err := listener(msg.Key(), payload); err != nil { + if err := listener(msg.Key(), reflect.Indirect(payload).Interface()); err != nil { tv.logger.Errorf("table view listener failed for %v: %w", msg, err) } } diff --git a/pulsar/table_view_test.go b/pulsar/table_view_test.go index d29b24d298..fd4decae64 100644 --- a/pulsar/table_view_test.go +++ b/pulsar/table_view_test.go @@ -24,6 +24,7 @@ import ( "testing" "time" + pb "github.com/apache/pulsar-client-go/integration-tests/pb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -80,6 +81,165 @@ func TestTableView(t *testing.T) { } } +func TestTableViewSchemas(t *testing.T) { + var tests = []struct { + name string + schema Schema + schemaType interface{} + producerValue interface{} + expValueOut interface{} + valueCheck func(t *testing.T, got any) // Overrides expValueOut for more complex checks + }{ + { + name: "StringSchema", + schema: NewStringSchema(nil), + schemaType: pointer("hello pulsar"), + producerValue: "hello pulsar", + expValueOut: pointer("hello pulsar"), + }, + { + name: "JSONSchema", + schema: NewJSONSchema(exampleSchemaDef, nil), + schemaType: testJSON{}, + producerValue: testJSON{ID: 1, Name: "Pulsar"}, + expValueOut: testJSON{ID: 1, Name: "Pulsar"}, + }, + { + name: "JSONSchema pointer type", + schema: NewJSONSchema(exampleSchemaDef, nil), + schemaType: pointer(testJSON{ID: 1, Name: "Pulsar"}), + producerValue: testJSON{ID: 1, Name: "Pulsar"}, + expValueOut: pointer(testJSON{ID: 1, Name: "Pulsar"}), + }, + { + name: "AvroSchema", + schema: NewAvroSchema(exampleSchemaDef, nil), + schemaType: testAvro{ID: 1, Name: "Pulsar"}, + producerValue: testAvro{ID: 1, Name: "Pulsar"}, + expValueOut: testAvro{ID: 1, Name: "Pulsar"}, + }, + { + name: "Int8Schema", + schema: NewInt8Schema(nil), + schemaType: int8(0), + producerValue: int8(1), + expValueOut: int8(1), + }, + { + name: "Int16Schema", + schema: NewInt16Schema(nil), + schemaType: int16(0), + producerValue: int16(1), + expValueOut: int16(1), + }, + { + name: "Int32Schema", + schema: NewInt32Schema(nil), + schemaType: int32(0), + producerValue: int32(1), + expValueOut: int32(1), + }, + { + name: "Int64Schema", + schema: NewInt64Schema(nil), + schemaType: int64(0), + producerValue: int64(1), + expValueOut: int64(1), + }, + { + name: "FloatSchema", + schema: NewFloatSchema(nil), + schemaType: float32(0), + producerValue: float32(1), + expValueOut: float32(1), + }, + { + name: "DoubleSchema", + schema: NewDoubleSchema(nil), + schemaType: float64(0), + producerValue: float64(1), + expValueOut: float64(1), + }, + { + name: "ProtoSchema", + schema: NewProtoSchema(protoSchemaDef, nil), + schemaType: pb.Test{}, + producerValue: &pb.Test{Num: 1, Msf: "Pulsar"}, + valueCheck: func(t *testing.T, got any) { + assert.IsType(t, pb.Test{}, got) + + pbt, ok := got.(pb.Test) + if assert.Truef(t, ok, "expected type pb.Test got %T", got) { + assert.Equal(t, int32(1), pbt.Num) + assert.Equal(t, "Pulsar", pbt.Msf) + } + }, + }, + { + name: "ProtoNativeSchema", + schema: NewProtoNativeSchemaWithMessage(&pb.Test{}, nil), + schemaType: pb.Test{}, + producerValue: &pb.Test{Num: 1, Msf: "Pulsar"}, + valueCheck: func(t *testing.T, got any) { + assert.IsType(t, pb.Test{}, got) + + pbt, ok := got.(pb.Test) + if assert.Truef(t, ok, "expected type pb.Test got %T", got) { + assert.Equal(t, int32(1), pbt.Num) + assert.Equal(t, "Pulsar", pbt.Msf) + } + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + + assert.NoError(t, err) + defer client.Close() + + topic := newTopicName() + + // create producer + producer, err := client.CreateProducer(ProducerOptions{ + Topic: topic, + Schema: test.schema, + }) + assert.NoError(t, err) + defer producer.Close() + + _, err = producer.Send(context.Background(), &ProducerMessage{ + Key: "testKey", + Value: test.producerValue, + }) + assert.NoError(t, err) + + // create table view + tv, err := client.CreateTableView(TableViewOptions{ + Topic: topic, + Schema: test.schema, + SchemaValueType: reflect.TypeOf(test.schemaType), + }) + assert.NoError(t, err) + defer tv.Close() + + value := tv.Get("testKey") + if test.valueCheck != nil { + test.valueCheck(t, value) + } else { + assert.IsType(t, test.expValueOut, value) + assert.Equal(t, test.expValueOut, value) + } + }) + } +} + +func pointer[T any](v T) *T { + return &v +} + func TestPublishNilValue(t *testing.T) { client, err := NewClient(ClientOptions{ URL: lookupURL, @@ -143,3 +303,62 @@ func TestPublishNilValue(t *testing.T) { assert.Equal(t, *(tv.Get("key-2").(*string)), "value-2") } + +func TestForEachAndListenJSONSchema(t *testing.T) { + client, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + + assert.NoError(t, err) + defer client.Close() + + topic := newTopicName() + schema := NewJSONSchema(exampleSchemaDef, nil) + + // create table view + tv, err := client.CreateTableView(TableViewOptions{ + Topic: topic, + Schema: schema, + SchemaValueType: reflect.TypeOf(testJSON{}), + }) + assert.NoError(t, err) + defer tv.Close() + + // create listener + valuePrefix := "hello pulsar: " + tv.ForEachAndListen(func(key string, value interface{}) error { + t.Log("foreach" + key) + s, ok := value.(testJSON) + assert.Truef(t, ok, "expected value to be testJSON type got %T", value) + assert.Equal(t, fmt.Sprintf(valuePrefix+key), s.Name) + return nil + }) + + // create producer + producer, err := client.CreateProducer(ProducerOptions{ + Topic: topic, + Schema: schema, + }) + assert.NoError(t, err) + defer producer.Close() + + numMsg := 10 + for i := 0; i < numMsg; i++ { + key := fmt.Sprintf("%d", i) + t.Log("producing" + key) + _, err = producer.Send(context.Background(), &ProducerMessage{ + Key: key, + Value: testJSON{ + ID: i, + Name: fmt.Sprintf(valuePrefix + key), + }, + }) + assert.NoError(t, err) + } + + // Wait until tv receives all messages + for tv.Size() < 10 { + time.Sleep(time.Second * 1) + t.Logf("TableView number of elements: %d", tv.Size()) + } +}