Skip to content

Commit

Permalink
stree: switch comparison function signature
Browse files Browse the repository at this point in the history
Follow the lead of newer standard library packages and use a full comparison
function (reporting int) rather than a less-than comparison (reporting bool).
Apart from this change of type signature, the API semantics are identical.
  • Loading branch information
creachadair committed Dec 21, 2023
1 parent 263e211 commit 0c7425f
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 58 deletions.
8 changes: 4 additions & 4 deletions stree/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ const benchSeed = 1471808909908695897
// Trial values of β for load-testing tree operations.
var balances = []int{0, 50, 100, 150, 200, 250, 300, 500, 800, 1000}

func intLess(a, b int) bool { return a < b }
func intCompare(a, b int) int { return a - b }

func randomTree(b *testing.B, β int) (*stree.Tree[int], []int) {
rng := rand.New(rand.NewSource(benchSeed))
values := make([]int, b.N)
for i := range values {
values[i] = rng.Intn(math.MaxInt32)
}
return stree.New(β, intLess, values...), values
return stree.New(β, intCompare, values...), values
}

func BenchmarkNew(b *testing.B) {
Expand All @@ -39,7 +39,7 @@ func BenchmarkAddRandom(b *testing.B) {
b.Run(fmt.Sprintf("β=%d", β), func(b *testing.B) {
_, values := randomTree(b, β)
b.ResetTimer()
tree := stree.New[int](β, intLess)
tree := stree.New[int](β, intCompare)
for _, v := range values {
tree.Add(v)
}
Expand All @@ -50,7 +50,7 @@ func BenchmarkAddRandom(b *testing.B) {
func BenchmarkAddOrdered(b *testing.B) {
for _, β := range balances {
b.Run(fmt.Sprintf("β=%d", β), func(b *testing.B) {
tree := stree.New[int](β, intLess)
tree := stree.New[int](β, intCompare)
for i := 1; i <= b.N; i++ {
tree.Add(i)
}
Expand Down
15 changes: 7 additions & 8 deletions stree/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,20 @@ package stree_test

import (
"fmt"
"strings"

"github.com/creachadair/mds/stree"
)

func stringLess(a, b string) bool { return a < b }

type Pair struct {
X string
V int
}

func (p Pair) Less(q Pair) bool { return p.X < q.X }
func (p Pair) Compare(q Pair) int { return strings.Compare(p.X, q.X) }

func ExampleTree_Add() {
tree := stree.New(200, stringLess)
tree := stree.New(200, strings.Compare)

fmt.Println("inserted:", tree.Add("never"))
fmt.Println("inserted:", tree.Add("say"))
Expand All @@ -31,7 +30,7 @@ func ExampleTree_Add() {

func ExampleTree_Remove() {
const key = "Aloysius"
tree := stree.New(1, stringLess)
tree := stree.New(1, strings.Compare)

fmt.Println("inserted:", tree.Add(key))
fmt.Println("removed:", tree.Remove(key))
Expand All @@ -43,7 +42,7 @@ func ExampleTree_Remove() {
}

func ExampleTree_Get() {
tree := stree.New(1, Pair.Less,
tree := stree.New(1, Pair.Compare,
Pair{X: "angel", V: 5},
Pair{X: "devil", V: 7},
Pair{X: "human", V: 13},
Expand All @@ -60,7 +59,7 @@ func ExampleTree_Get() {
}

func ExampleTree_Inorder() {
tree := stree.New(15, stringLess, "eat", "those", "bloody", "vegetables")
tree := stree.New(15, strings.Compare, "eat", "those", "bloody", "vegetables")
tree.Inorder(func(key string) bool {
fmt.Println(key)
return true
Expand All @@ -73,7 +72,7 @@ func ExampleTree_Inorder() {
}

func ExampleTree_Min() {
tree := stree.New(50, intLess, 1814, 1956, 955, 1066, 2016)
tree := stree.New(50, intCompare, 1814, 1956, 955, 1066, 2016)

fmt.Println("len:", tree.Len())
fmt.Println("min:", tree.Min())
Expand Down
4 changes: 1 addition & 3 deletions stree/internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ var (
sortWords = flag.Bool("sort", false, "Sort input words before insertion")
)

func lessString(a, b string) bool { return a < b }

func sortedUnique(ws []string) []string {
out := mapset.New[string](ws...).Slice()
sort.Strings(out)
Expand All @@ -32,7 +30,7 @@ func sortedUnique(ws []string) []string {
// Construct a tree with the words from input, returning the finished tree and
// the original words as split by strings.Fields.
func makeTree(β int, input string) (*Tree[string], []string) {
tree := New(β, lessString)
tree := New(β, strings.Compare)
words := strings.Fields(input)
if *sortWords {
sort.Strings(words)
Expand Down
13 changes: 7 additions & 6 deletions stree/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,15 @@ func (n *node[T]) inorder(f func(T) bool) bool {

// pathTo returns the sequence of nodes beginning at n leading to key, if key
// is present. If key was found, its node is the last element of the path.
func (n *node[T]) pathTo(key T, lessThan func(a, b T) bool) []*node[T] {
func (n *node[T]) pathTo(key T, compare func(a, b T) int) []*node[T] {
var path []*node[T]
cur := n
for cur != nil {
path = append(path, cur)
if lessThan(key, cur.X) {
cmp := compare(key, cur.X)
if cmp < 0 {
cur = cur.left
} else if lessThan(cur.X, key) {
} else if cmp > 0 {
cur = cur.right
} else {
break
Expand All @@ -109,13 +110,13 @@ func (n *node[T]) pathTo(key T, lessThan func(a, b T) bool) []*node[T] {

// inorderAfter visits the elements of the subtree under n not less than key
// inorder, calling f for each until f returns false.
func (n *node[T]) inorderAfter(key T, lessThan func(a, b T) bool, f func(T) bool) bool {
func (n *node[T]) inorderAfter(key T, compare func(a, b T) int, f func(T) bool) bool {
// Find the path from the root to key. Any nodes greater than or equal to
// key must be on or to the right of this path.
path := n.pathTo(key, lessThan)
path := n.pathTo(key, compare)
for i := len(path) - 1; i >= 0; i-- {
cur := path[i]
if lessThan(cur.X, key) {
if compare(cur.X, key) < 0 {
continue
} else if ok := f(cur.X); !ok {
return false
Expand Down
64 changes: 35 additions & 29 deletions stree/stree.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package stree
import (
"fmt"
"math"
"sort"
"slices"
)

const (
Expand All @@ -29,11 +29,12 @@ const (
)

// New returns a new tree with the given balancing factor 0 ≤ β ≤ 1000. The
// order of elements stored in the tree is provided by the comparison function.
// order of elements stored in the tree is provided by the comparison function,
// where compare(a, b) must be <0 if a < b, =0 if a == b, and >0 if a > b.
//
// If any keys are given, the tree is initialized to contain them, otherwise an
// empty tree is created. For keys that are known in advance it is more
// efficient to allocate storage for them at construction time, versus adding
// empty tree is created. When the initial set of keys is known in advance it
// is more efficient to add them during tree construction, versus versus adding
// them separately later.
//
// The balancing factor controls how unbalanced the tree is permitted to be,
Expand All @@ -42,24 +43,24 @@ const (
// rebalancing, but provides faster lookups. A good default is 250.
//
// New panics if β < 0 or β > 1000.
func New[T any](β int, lessThan func(a, b T) bool, keys ...T) *Tree[T] {
func New[T any](β int, compare func(a, b T) int, keys ...T) *Tree[T] {
if β < 0 || β > maxBalance {
panic("β out of range")
}
tree := &Tree[T]{
β: β,
lessThan: lessThan,
limit: limitFunc(β),
size: len(keys),
max: len(keys),
β: β,
compare: compare,
limit: limitFunc(β),
size: len(keys),
max: len(keys),
}
if len(keys) != 0 {
nodes := make([]*node[T], len(keys))
for i, key := range keys {
nodes[i] = &node[T]{X: key}
}
sort.Slice(nodes, func(i, j int) bool {
return lessThan(nodes[i].X, nodes[j].X)
slices.SortFunc(nodes, func(a, b *node[T]) int {
return compare(a.X, b.X)
})
tree.root = extract(nodes)
}
Expand All @@ -77,11 +78,11 @@ type Tree[T any] struct {
// requires one floating-point operation per insertion to recompute the
// depth limit.

β int // balancing factor
lessThan func(a, b T) bool // key comparison
limit func(n int) int // depth limit for size n
size int // cache of root.size()
max int // max of size since last rebuild of root
β int // balancing factor
compare func(a, b T) int // key comparison
limit func(n int) int // depth limit for size n
size int // cache of root.size()
max int // max of size since last rebuild of root
}

func toFraction(β int) float64 { return (float64(β) + maxBalance) / fracLimit }
Expand Down Expand Up @@ -149,12 +150,14 @@ func (t *Tree[T]) insert(key T, replace bool, root *node[T], limit int) (ins *no
size = 1
}
return &node[T]{X: key}, true, size, 0
} else if t.lessThan(key, root.X) {
}
cmp := t.compare(key, root.X)
if cmp < 0 {
ins, added, size, height = t.insert(key, replace, root.left, limit-1)
root.left = ins
sib = root.right
height++
} else if t.lessThan(root.X, key) {
} else if cmp > 0 {
ins, added, size, height = t.insert(key, replace, root.right, limit-1)
root.right = ins
sib = root.left
Expand Down Expand Up @@ -193,7 +196,7 @@ func (t *Tree[T]) insert(key T, replace bool, root *node[T], limit int) (ins *no

// Remove key from the tree and report whether it was present.
func (t *Tree[T]) Remove(key T) bool {
del, ok := t.root.remove(key, t.lessThan)
del, ok := t.root.remove(key, t.compare)
t.root = del
if ok {
t.size--
Expand All @@ -207,14 +210,16 @@ func (t *Tree[T]) Remove(key T) bool {

// remove key from the subtree under n, returning the modified tree reporting
// whether the mass of the tree was decreased.
func (n *node[T]) remove(key T, lessThan func(a, b T) bool) (_ *node[T], ok bool) {
func (n *node[T]) remove(key T, compare func(a, b T) int) (_ *node[T], ok bool) {
if n == nil {
return nil, false // nothing to do
} else if lessThan(key, n.X) {
n.left, ok = n.left.remove(key, lessThan)
}
cmp := compare(key, n.X)
if cmp < 0 {
n.left, ok = n.left.remove(key, compare)
return n, ok
} else if lessThan(n.X, key) {
n.right, ok = n.right.remove(key, lessThan)
} else if cmp > 0 {
n.right, ok = n.right.remove(key, compare)
return n, ok
} else if n.left == nil {
return n.right, true
Expand Down Expand Up @@ -248,9 +253,10 @@ func (t *Tree[T]) Clear() { t.size = 0; t.max = 0; t.root = nil }
func (t *Tree[T]) Get(key T) (_ T, ok bool) {
cur := t.root
for cur != nil {
if t.lessThan(key, cur.X) {
cmp := t.compare(key, cur.X)
if cmp < 0 {
cur = cur.left
} else if t.lessThan(cur.X, key) {
} else if cmp > 0 {
cur = cur.right
} else {
return cur.X, true
Expand All @@ -268,13 +274,13 @@ func (t *Tree[T]) Inorder(f func(key T) bool) bool { return t.root.inorder(f) }
// if f returns false, InorderAfter stops and returns fales. Otherwise, it
// returns true after visiting all eligible elements of t.
func (t *Tree[T]) InorderAfter(key T, f func(key T) bool) bool {
return t.root.inorderAfter(key, t.lessThan, f)
return t.root.inorderAfter(key, t.compare, f)
}

// Cursor constructs a cursor to the specified key, or nil if key is not
// present in the tree.
func (t *Tree[T]) Cursor(key T) *Cursor[T] {
path := t.root.pathTo(key, t.lessThan)
path := t.root.pathTo(key, t.compare)
if len(path) == 0 {
return nil
}
Expand Down
16 changes: 8 additions & 8 deletions stree/stree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func sortedUnique(ws []string, drop mapset.Set[string]) []string {

func TestNew(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
tree := stree.New(100, stringLess)
tree := stree.New(100, strings.Compare)
if n := tree.Len(); n != 0 {
t.Errorf("Len of empty tree: got %v, want 0", n)
}
Expand All @@ -38,7 +38,7 @@ func TestNew(t *testing.T) {
}
})
t.Run("NonEmpty", func(t *testing.T) {
tree := stree.New(200, stringLess, "please", "fetch", "your", "slippers")
tree := stree.New(200, strings.Compare, "please", "fetch", "your", "slippers")
got := allWords(tree)
want := []string{"fetch", "please", "slippers", "your"}
if diff := cmp.Diff(got, want); diff != "" {
Expand All @@ -53,7 +53,7 @@ func TestNew(t *testing.T) {

func TestRemoval(t *testing.T) {
words := strings.Fields(`a foolish consistency is the hobgoblin of little minds`)
tree := stree.New[string](0, stringLess, words...)
tree := stree.New[string](0, strings.Compare, words...)

got := allWords(tree)
want := sortedUnique(words, nil)
Expand All @@ -80,8 +80,8 @@ func TestInsertion(t *testing.T) {
val int
}

tree := stree.New[kv](300, func(a, b kv) bool {
return a.key < b.key
tree := stree.New[kv](300, func(a, b kv) int {
return strings.Compare(a.key, b.key)
})
checkInsert := func(f func(kv) bool, key string, val int, ok bool) {
t.Helper()
Expand Down Expand Up @@ -113,7 +113,7 @@ func TestInsertion(t *testing.T) {

func TestInorderAfter(t *testing.T) {
keys := []string{"8", "6", "7", "5", "3", "0", "9"}
tree := stree.New(0, stringLess, keys...)
tree := stree.New(0, strings.Compare, keys...)
tests := []struct {
key string
want string
Expand Down Expand Up @@ -146,7 +146,7 @@ func TestInorderAfter(t *testing.T) {

func TestCursor(t *testing.T) {
t.Run("EmptyTree", func(t *testing.T) {
tree := stree.New(250, stringLess)
tree := stree.New(250, strings.Compare)

// An empty tree reports nil cursors.
if got := tree.Cursor("whatever"); got.Valid() {
Expand All @@ -165,7 +165,7 @@ func TestCursor(t *testing.T) {
}
})

tree := stree.New(250, stringLess, "a", "b", "c", "d", "e", "f", "g")
tree := stree.New(250, strings.Compare, "a", "b", "c", "d", "e", "f", "g")
t.Run("Iterate", func(t *testing.T) {
type tcursor = *stree.Cursor[string]

Expand Down

0 comments on commit 0c7425f

Please sign in to comment.