Skip to content

Commit

Permalink
update test cases; update error message.
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Jan 31, 2025
1 parent 0e63c74 commit 7a0b7f3
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 55 deletions.
94 changes: 77 additions & 17 deletions bson/bson_binary_vector_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type bsonBinaryVectorTestCase struct {
CanonicalBson string `json:"canonical_bson"`
}

func Test_BsonBinaryVector(t *testing.T) {
func TestBsonBinaryVector(t *testing.T) {
t.Parallel()

jsonFiles, err := findJSONFilesInDir(bsonBinaryVectorDir)
Expand Down Expand Up @@ -70,13 +70,13 @@ func Test_BsonBinaryVector(t *testing.T) {
val := Binary{Subtype: TypeBinaryVector}

for _, tc := range [][]byte{
{byte(Float32Vector), 0, 42},
{byte(Float32Vector), 0, 42, 42},
{byte(Float32Vector), 0, 42, 42, 42},
{Float32Vector, 0, 42},
{Float32Vector, 0, 42, 42},
{Float32Vector, 0, 42, 42, 42},

{byte(Float32Vector), 0, 42, 42, 42, 42, 42},
{byte(Float32Vector), 0, 42, 42, 42, 42, 42, 42},
{byte(Float32Vector), 0, 42, 42, 42, 42, 42, 42, 42},
{Float32Vector, 0, 42, 42, 42, 42, 42},
{Float32Vector, 0, 42, 42, 42, 42, 42, 42},
{Float32Vector, 0, 42, 42, 42, 42, 42, 42, 42},
} {
t.Run(fmt.Sprintf("marshaling %d bytes", len(tc)-2), func(t *testing.T) {
val.Data = tc
Expand All @@ -91,6 +91,36 @@ func Test_BsonBinaryVector(t *testing.T) {
}
})

t.Run("FLOAT32 with padding", func(t *testing.T) {
t.Parallel()

t.Run("Unmarshaling", func(t *testing.T) {
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{Float32Vector, 3}}}}
b, err := Marshal(val)
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, errNonZeroVectorPadding.Error())
})
})

t.Run("INT8 with padding", func(t *testing.T) {
t.Parallel()

t.Run("Unmarshaling", func(t *testing.T) {
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{Int8Vector, 3}}}}
b, err := Marshal(val)
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, errNonZeroVectorPadding.Error())
})
})

t.Run("Padding specified with no vector data PACKED_BIT", func(t *testing.T) {
t.Parallel()

Expand All @@ -99,7 +129,7 @@ func Test_BsonBinaryVector(t *testing.T) {
require.EqualError(t, err, errNonZeroVectorPadding.Error())
})
t.Run("Unmarshaling", func(t *testing.T) {
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{byte(PackedBitVector), 1}}}}
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 1}}}}
b, err := Marshal(val)
require.NoError(t, err, "marshaling test BSON")
var got struct {
Expand All @@ -118,7 +148,7 @@ func Test_BsonBinaryVector(t *testing.T) {
require.EqualError(t, err, errVectorPaddingTooLarge.Error())
})
t.Run("Unmarshaling", func(t *testing.T) {
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{byte(PackedBitVector), 8}}}}
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 8}}}}
b, err := Marshal(val)
require.NoError(t, err, "marshaling test BSON")
var got struct {
Expand All @@ -134,13 +164,13 @@ func convertSlice[T int8 | float32 | byte](s []interface{}) []T {
v := make([]T, len(s))
for i, e := range s {
f := math.NaN()
switch v := e.(type) {
switch val := e.(type) {
case float64:
f = v
f = val
case string:
if v == "inf" {
if val == "inf" {
f = math.Inf(0)
} else if v == "-inf" {
} else if val == "-inf" {
f = math.Inf(-1)
}
}
Expand All @@ -150,10 +180,6 @@ func convertSlice[T int8 | float32 | byte](s []interface{}) []T {
}

func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVectorTestCase) {
if !test.Valid {
t.Skipf("skip invalid case %s", test.Description)
}

testVector := make(map[string]Vector)
switch alias := test.DtypeHex; alias {
case "0x03":
Expand All @@ -180,6 +206,23 @@ func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVector
require.NoError(t, err, "decoding canonical BSON")

t.Run("Unmarshaling", func(t *testing.T) {
skipCases := map[string]string{
"FLOAT32 with padding": "run in alternative case",
"Overflow Vector INT8": "compile-time restriction",
"Underflow Vector INT8": "compile-time restriction",
"INT8 with padding": "run in alternative case",
"INT8 with float inputs": "compile-time restriction",
"Overflow Vector PACKED_BIT": "compile-time restriction",
"Underflow Vector PACKED_BIT": "compile-time restriction",
"Vector with float values PACKED_BIT": "compile-time restriction",
"Padding specified with no vector data PACKED_BIT": "run in alternative case",
"Exceeding maximum padding PACKED_BIT": "run in alternative case",
"Negative padding PACKED_BIT": "compile-time restriction",
}
if reason, ok := skipCases[test.Description]; ok {
t.Skipf("skip test case %s: %s", test.Description, reason)
}

t.Parallel()

var got map[string]Vector
Expand All @@ -189,6 +232,23 @@ func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVector
})

t.Run("Marshaling", func(t *testing.T) {
skipCases := map[string]string{
"FLOAT32 with padding": "private padding field",
"Overflow Vector INT8": "compile-time restriction",
"Underflow Vector INT8": "compile-time restriction",
"INT8 with padding": "private padding field",
"INT8 with float inputs": "compile-time restriction",
"Overflow Vector PACKED_BIT": "compile-time restriction",
"Underflow Vector PACKED_BIT": "compile-time restriction",
"Vector with float values PACKED_BIT": "compile-time restriction",
"Padding specified with no vector data PACKED_BIT": "run in alternative case",
"Exceeding maximum padding PACKED_BIT": "run in alternative case",
"Negative padding PACKED_BIT": "compile-time restriction",
}
if reason, ok := skipCases[test.Description]; ok {
t.Skipf("skip test case %s: %s", test.Description, reason)
}

t.Parallel()

got, err := Marshal(testVector)
Expand Down
71 changes: 33 additions & 38 deletions bson/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,58 +13,50 @@ import (
"math"
)

// VectorDType represents the Vector data type.
type VectorDType byte

// These constants are vector data types.
const (
Int8Vector VectorDType = 0x03
Float32Vector VectorDType = 0x27
PackedBitVector VectorDType = 0x10
Int8Vector byte = 0x03
Float32Vector byte = 0x27
PackedBitVector byte = 0x10
)

// Stringer of VectorDType
func (vt VectorDType) String() string {
switch vt {
case Int8Vector:
return "int8"
case Float32Vector:
return "float32"
case PackedBitVector:
return "packed bit"
default:
return "invalid"
}
}

// These are vector conversion errors.
var (
errInsufficientVectorData = errors.New("insufficient data")
errNonZeroVectorPadding = errors.New("padding must be 0")
errVectorPaddingTooLarge = errors.New("padding larger than 7")
errVectorPaddingTooLarge = errors.New("padding cannot be larger than 7")
)

type vectorTypeError struct {
Method string
Type VectorDType
Type byte
}

// Error implements the error interface.
func (vte vectorTypeError) Error() string {
return "Call of " + vte.Method + " on " + vte.Type.String() + " vector"
t := "invalid"
switch vte.Type {
case Int8Vector:
t = "int8"
case Float32Vector:
t = "float32"
case PackedBitVector:
t = "packed bit"
}
return fmt.Sprintf("cannot call %s, on a type %s vector", vte.Method, t)
}

// Vector represents a densely packed array of numbers / bits.
type Vector struct {
dType VectorDType
dType byte
int8Data []int8
float32Data []float32
bitData []byte
bitPadding uint8
}

// Type returns the vector type.
func (v Vector) Type() VectorDType {
func (v Vector) Type() byte {
return v.dType
}

Expand Down Expand Up @@ -123,7 +115,7 @@ func (v Vector) PackedBitOK() ([]byte, uint8, bool) {
return v.bitData, v.bitPadding, true
}

// Binary returns the BSON Binary of the Vector.
// Binary returns the BSON Binary representation of the Vector.
func (v Vector) Binary() Binary {
switch v.Type() {
case Int8Vector:
Expand All @@ -133,15 +125,17 @@ func (v Vector) Binary() Binary {
case PackedBitVector:
return binaryFromBitVector(v.PackedBit())
default:
panic("invalid Vector type")
panic(fmt.Sprintf("invalid Vector data type: %d", v.dType))
}
}

func binaryFromInt8Vector(v []int8) Binary {
data := make([]byte, 2, len(v)+2)
copy(data, []byte{byte(Int8Vector), 0})
for _, e := range v {
data = append(data, byte(e))
data := make([]byte, len(v)+2)
data[0] = Int8Vector
data[1] = 0

for i, e := range v {
data[i+2] = byte(e)
}

return Binary{
Expand All @@ -152,7 +146,8 @@ func binaryFromInt8Vector(v []int8) Binary {

func binaryFromFloat32Vector(v []float32) Binary {
data := make([]byte, 2, len(v)*4+2)
copy(data, []byte{byte(Float32Vector), 0})
data[0] = Float32Vector
data[1] = 0
var a [4]byte
for _, e := range v {
binary.LittleEndian.PutUint32(a[:], math.Float32bits(e))
Expand All @@ -166,7 +161,7 @@ func binaryFromFloat32Vector(v []float32) Binary {
}

func binaryFromBitVector(bits []byte, padding uint8) Binary {
data := []byte{byte(PackedBitVector), padding}
data := []byte{PackedBitVector, padding}
data = append(data, bits...)
return Binary{
Subtype: TypeBinaryVector,
Expand All @@ -180,12 +175,12 @@ func NewVector[T int8 | float32](data []T) Vector {
switch a := any(data).(type) {
case []int8:
v.dType = Int8Vector
v.int8Data = []int8{}
v.int8Data = append(v.int8Data, a...)
v.int8Data = make([]int8, len(data))
copy(v.int8Data, a)
case []float32:
v.dType = Float32Vector
v.float32Data = []float32{}
v.float32Data = append(v.float32Data, a...)
v.float32Data = make([]float32, len(data))
copy(v.float32Data, a)
default:
panic(fmt.Errorf("unsupported type %T", data))
}
Expand Down Expand Up @@ -217,7 +212,7 @@ func NewVectorFromBinary(b Binary) (Vector, error) {
if len(b.Data) < 2 {
return v, errInsufficientVectorData
}
switch t := b.Data[0]; VectorDType(t) {
switch t := b.Data[0]; t {
case Int8Vector:
return newInt8Vector(b.Data[1:])
case Float32Vector:
Expand Down

0 comments on commit 7a0b7f3

Please sign in to comment.