From 4ed8517060040a4beeb0cf71cdd8696a212664ed Mon Sep 17 00:00:00 2001 From: Martin Sucha Date: Wed, 3 Jul 2024 17:30:54 +0200 Subject: [PATCH] Expose the murmur3 partitioning function directly I need to compute it in external code. --- session.go | 37 +++++++++++++++++ session_test.go | 105 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+) diff --git a/session.go b/session.go index 6159c4441..6961599c7 100644 --- a/session.go +++ b/session.go @@ -19,6 +19,7 @@ import ( "unicode" "github.com/gocql/gocql/internal/lru" + "github.com/gocql/gocql/internal/murmur" ) // Session is the interface used by users to interact with the database. @@ -2122,6 +2123,42 @@ func createRoutingKey(routingKeyInfo *routingKeyInfo, values []interface{}) ([]b return routingKey, nil } +var identityIndexes = []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + +// Murmur3Token computes a murmur3 token for the given values. +func Murmur3Token(types []TypeInfo, values []interface{}) (int64, error) { + if len(values) == 0 { + return 0, fmt.Errorf("gocql: no values provided") + } + + if len(types) != len(values) { + return 0, fmt.Errorf("gocql: types and values length mismatch") + } + + var indexes []int + if n := len(types); n <= len(identityIndexes) { + indexes = identityIndexes[:n] + } else { + indexes = make([]int, n) + for i := range indexes { + indexes[i] = i + } + } + + rki := &routingKeyInfo{ + types: types, + indexes: indexes, + names: make([]string, len(types)), // used in error messages + } + + routingKey, err := createRoutingKey(rki, values) + if err != nil { + return 0, err + } + + return murmur.Murmur3H1(routingKey), nil +} + func (b *Batch) borrowForExecution() { // empty, because Batch has no equivalent of Query.Release() // that would race with speculative executions. diff --git a/session_test.go b/session_test.go index d5a59dac4..0574e3cc1 100644 --- a/session_test.go +++ b/session_test.go @@ -333,3 +333,108 @@ func TestIsUseStatement(t *testing.T) { } } } + +func TestMurmur3Token(t *testing.T) { + tests := map[string]struct { + types []TypeInfo + values []interface{} + expected int64 + err string + }{ + "8351-2882-581-20036": { + types: []TypeInfo{ + NewNativeType(protoVersion4, TypeSmallInt, ""), + NewNativeType(protoVersion4, TypeSmallInt, ""), + NewNativeType(protoVersion4, TypeSmallInt, ""), + NewNativeType(protoVersion4, TypeSmallInt, ""), + }, + values: []interface{}{ + int16(8351), + int16(2882), + int16(581), + int16(20036), + }, + expected: -9223371846324820981, + }, + "3852-744-522-20116": { + types: []TypeInfo{ + NewNativeType(protoVersion4, TypeSmallInt, ""), + NewNativeType(protoVersion4, TypeSmallInt, ""), + NewNativeType(protoVersion4, TypeSmallInt, ""), + NewNativeType(protoVersion4, TypeSmallInt, ""), + }, + values: []interface{}{ + int16(3852), + int16(744), + int16(522), + int16(20116), + }, + expected: -9223370757649630452, + }, + "5212-813-19933-0": { + types: []TypeInfo{ + NewNativeType(protoVersion4, TypeSmallInt, ""), + NewNativeType(protoVersion4, TypeSmallInt, ""), + NewNativeType(protoVersion4, TypeSmallInt, ""), + NewNativeType(protoVersion4, TypeSmallInt, ""), + }, + values: []interface{}{ + int16(5212), + int16(813), + int16(19933), + int16(0), + }, + expected: 5363655924167674765, + }, + "empty": { + types: []TypeInfo{ + NewNativeType(protoVersion4, TypeSmallInt, ""), + }, + values: nil, + err: "gocql: no values provided", + }, + "mismatched length": { + types: []TypeInfo{ + NewNativeType(protoVersion4, TypeSmallInt, ""), + }, + values: []interface{}{ + int16(5212), + int16(813), + int16(19933), + int16(0), + }, + err: "gocql: types and values length mismatch", + }, + "invalid value": { + types: []TypeInfo{ + NewNativeType(protoVersion4, TypeSmallInt, ""), + NewNativeType(protoVersion4, TypeSmallInt, ""), + }, + values: []interface{}{ + int16(5212), + "hello there", + }, + err: "can not marshal string into smallint: strconv.ParseInt: parsing \"hello there\": invalid syntax", + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + got, err := Murmur3Token(test.types, test.values) + if test.err != "" { + if err == nil { + t.Fatalf("expected error %q, got nil", test.err) + } + if gotErr := err.Error(); gotErr != test.err { + t.Fatalf("expected error %q, got %q", test.err, gotErr) + } + } else { + if err != nil { + t.Fatalf("unexpected error %q", err) + } + if got != test.expected { + t.Fatalf("expected %v, got %v", test.expected, got) + } + } + }) + } +}