From d5ebeca2137a4759d8262cd8faa910cb52621efa Mon Sep 17 00:00:00 2001 From: Marc Binz Date: Thu, 2 May 2024 22:02:30 +0200 Subject: [PATCH] feat: implement live count method (#234) --- core/embed/query/builder.go | 51 ++++++++++++++++++- tests/basic/gen/som/query/builder.go | 51 ++++++++++++++++++- tests/basic/som_live_test.go | 73 +++++++++++++++++++++++++++- 3 files changed, 172 insertions(+), 3 deletions(-) diff --git a/core/embed/query/builder.go b/core/embed/query/builder.go index fcc284b6..85d2f40c 100644 --- a/core/embed/query/builder.go +++ b/core/embed/query/builder.go @@ -246,11 +246,60 @@ func (b Builder[M, C]) Live(ctx context.Context) (<-chan LiveResult[*M], error) req := b.query.BuildAsLive() resChan, err := b.db.Live(ctx, req.Statement, req.Variables) if err != nil { - return nil, fmt.Errorf("could not query live records: %w", err) + return nil, fmt.Errorf("failed to query live records: %w", err) } return live(ctx, resChan, b.unmarshal, b.convTo), nil } +// LiveCount is the live version of Count. +// Whenever a record is created or deleted that matches the +// conditions of the query, the count will be updated. +func (b Builder[M, C]) LiveCount(ctx context.Context) (<-chan int, error) { + count, err := b.Count(ctx) + if err != nil { + return nil, fmt.Errorf("failed to execute initial count: %w", err) + } + + resChan, err := b.Live(ctx) + if err != nil { + return nil, fmt.Errorf("failed to execute live query: %w", err) + } + + countChan := make(chan int, 1) + + go func() { + defer close(countChan) + + for { + select { + + case <-ctx.Done(): + return + + case res, open := <-resChan: + if !open { + return + } + + switch res.(type) { + + case LiveCreate[*M]: + count++ + + case LiveDelete[*M]: + count-- + } + + countChan <- count + } + } + }() + + countChan <- count + + return countChan, nil +} + // LiveDiff behaves like Live, but instead of receiving the full result // set on every change, it only receives the actual changes. //func (b builder[M, C]) LiveDiff(ctx context.Context) (<-chan LiveResult[*M], error) { diff --git a/tests/basic/gen/som/query/builder.go b/tests/basic/gen/som/query/builder.go index 44a5f157..be0427b6 100755 --- a/tests/basic/gen/som/query/builder.go +++ b/tests/basic/gen/som/query/builder.go @@ -246,11 +246,60 @@ func (b Builder[M, C]) Live(ctx context.Context) (<-chan LiveResult[*M], error) req := b.query.BuildAsLive() resChan, err := b.db.Live(ctx, req.Statement, req.Variables) if err != nil { - return nil, fmt.Errorf("could not query live records: %w", err) + return nil, fmt.Errorf("failed to query live records: %w", err) } return live(ctx, resChan, b.unmarshal, b.convTo), nil } +// LiveCount is the live version of Count. +// Whenever a record is created or deleted that matches the +// conditions of the query, the count will be updated. +func (b Builder[M, C]) LiveCount(ctx context.Context) (<-chan int, error) { + count, err := b.Count(ctx) + if err != nil { + return nil, fmt.Errorf("failed to execute initial count: %w", err) + } + + resChan, err := b.Live(ctx) + if err != nil { + return nil, fmt.Errorf("failed to execute live query: %w", err) + } + + countChan := make(chan int, 1) + + go func() { + defer close(countChan) + + for { + select { + + case <-ctx.Done(): + return + + case res, open := <-resChan: + if !open { + return + } + + switch res.(type) { + + case LiveCreate[*M]: + count++ + + case LiveDelete[*M]: + count-- + } + + countChan <- count + } + } + }() + + countChan <- count + + return countChan, nil +} + // LiveDiff behaves like Live, but instead of receiving the full result // set on every change, it only receives the actual changes. //func (b builder[M, C]) LiveDiff(ctx context.Context) (<-chan LiveResult[*M], error) { diff --git a/tests/basic/som_live_test.go b/tests/basic/som_live_test.go index d3a69b93..e3b52284 100644 --- a/tests/basic/som_live_test.go +++ b/tests/basic/som_live_test.go @@ -7,6 +7,7 @@ import ( "github.com/go-surreal/som/tests/basic/model" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" + "math/rand" "testing" "time" ) @@ -146,7 +147,7 @@ func TestLiveQueries(t *testing.T) { select { - case _, more = <-liveChan: + case _, more := <-liveChan: if more { t.Fatal("liveChan did not close after context was canceled") } @@ -231,3 +232,73 @@ func TestLiveQueriesFilter(t *testing.T) { } } } + +func TestLiveQueryCount(t *testing.T) { + ctx := context.Background() + + client, cleanup := prepareDatabase(ctx, t) + defer cleanup() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + if err := client.ApplySchema(ctx); err != nil { + t.Fatal(err) + } + + liveCount, err := client.AllFieldTypesRepo().Query().LiveCount(ctx) + if err != nil { + t.Fatal(err) + } + + count := rand.Intn(randMax-randMin) + randMin + + var models []*model.AllFieldTypes + + for i := 0; i < count; i++ { + newModel := &model.AllFieldTypes{} + + if err := client.AllFieldTypesRepo().Create(ctx, newModel); err != nil { + t.Fatal(err) + } + + models = append(models, newModel) + } + + for i := 0; i <= count; i++ { + assert.Equal(t, i, <-liveCount) + } + + for _, delModel := range models { + if err := client.AllFieldTypesRepo().Delete(ctx, delModel); err != nil { + t.Fatal(err) + } + } + + for i := count; i > 0; i-- { + assert.Equal(t, i-1, <-liveCount) + } + + select { + + case <-liveCount: + t.Fatal("liveCount should not receive any more messages") + + case <-time.After(1 * time.Second): + } + + // Test the automatic closing of the live channel when the context is canceled: + + cancel() + + select { + + case _, more := <-liveCount: + if more { + t.Fatal("liveCount did not close after context was canceled") + } + + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for live channel to close after context was canceled") + } +}