Skip to content

Commit

Permalink
nit
Browse files Browse the repository at this point in the history
Signed-off-by: Joshua Kim <[email protected]>
  • Loading branch information
joshua-kim committed Jan 10, 2025
1 parent effb777 commit ff62e92
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 67 deletions.
43 changes: 22 additions & 21 deletions chain/bond.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
package chain

import (
"context"
"errors"
"fmt"

"github.com/ava-labs/avalanchego/database"
"github.com/ava-labs/avalanchego/utils/wrappers"

"github.com/ava-labs/hypersdk/codec"
"github.com/ava-labs/hypersdk/state"
"github.com/ava-labs/hypersdk/x/fdsmr"
)

Expand All @@ -27,41 +29,40 @@ type BondBalance struct {
Max uint32 `serialize:"true"`
}

func NewBonder(db database.Database) *Bonder {
return &Bonder{db: db}
}

// Bonder maintains state of account bond balances to limit the amount of
// pending transactions per account
type Bonder struct {
db database.Database
}
type Bonder struct{}

// this needs to be thread-safe if it's called from the api
func (b *Bonder) SetMaxBalance(address codec.Address, maxBalance uint32) error {
func (b *Bonder) SetMaxBalance(

Check failure on line 37 in chain/bond.go

View workflow job for this annotation

GitHub Actions / hypersdk-lint

unused-receiver: method receiver 'b' is not referenced in method's body, consider removing or renaming it as _ (revive)
ctx context.Context,
mutable state.Mutable,
address codec.Address,
maxBalance uint32,
) error {
addressBytes := address[:]
if maxBalance == 0 {
return b.db.Delete(addressBytes)
return mutable.Remove(ctx, addressBytes)
}

balance, err := b.getBalance(addressBytes)
balance, err := getBalance(ctx, mutable, addressBytes)
if err != nil {
return err
}

balance.Max = maxBalance
if err := b.putBalance(addressBytes, balance); err != nil {
if err := putBalance(ctx, mutable, addressBytes, balance); err != nil {
return err
}

return nil
}

func (b *Bonder) Bond(tx *Transaction) (bool, error) {
func (b *Bonder) Bond(ctx context.Context, mutable state.Mutable, tx *Transaction) (bool, error) {

Check failure on line 61 in chain/bond.go

View workflow job for this annotation

GitHub Actions / hypersdk-lint

unused-receiver: method receiver 'b' is not referenced in method's body, consider removing or renaming it as _ (revive)
address := tx.GetSponsor()
addressBytes := address[:]

balance, err := b.getBalance(addressBytes)
balance, err := getBalance(ctx, mutable, addressBytes)
if err != nil {
return false, err
}
Expand All @@ -71,18 +72,18 @@ func (b *Bonder) Bond(tx *Transaction) (bool, error) {
}

balance.Pending++
if err := b.putBalance(addressBytes, balance); err != nil {
if err := putBalance(ctx, mutable, addressBytes, balance); err != nil {
return false, err
}

return true, nil
}

func (b *Bonder) Unbond(tx *Transaction) error {
func (b *Bonder) Unbond(ctx context.Context, mutable state.Mutable, tx *Transaction) error {

Check failure on line 82 in chain/bond.go

View workflow job for this annotation

GitHub Actions / hypersdk-lint

unused-receiver: method receiver 'b' is not referenced in method's body, consider removing or renaming it as _ (revive)
address := tx.GetSponsor()
addressBytes := address[:]

balance, err := b.getBalance(addressBytes)
balance, err := getBalance(ctx, mutable, addressBytes)
if err != nil {
return err
}
Expand All @@ -92,15 +93,15 @@ func (b *Bonder) Unbond(tx *Transaction) error {
}

balance.Pending--
if err := b.putBalance(addressBytes, balance); err != nil {
if err := putBalance(ctx, mutable, addressBytes, balance); err != nil {
return err
}

return nil
}

func (b *Bonder) getBalance(address []byte) (BondBalance, error) {
currentBytes, err := b.db.Get(address)
func getBalance(ctx context.Context, mutable state.Mutable, address []byte) (BondBalance, error) {
currentBytes, err := mutable.GetValue(ctx, address)
if err != nil && !errors.Is(err, database.ErrNotFound) {
return BondBalance{}, fmt.Errorf("failed to get bond balance")

Check failure on line 106 in chain/bond.go

View workflow job for this annotation

GitHub Actions / hypersdk-lint

string-format: no format directive, use errors.New instead (revive)

Check failure on line 106 in chain/bond.go

View workflow job for this annotation

GitHub Actions / hypersdk-lint

fmt.Errorf can be replaced with errors.New (perfsprint)
}
Expand All @@ -120,13 +121,13 @@ func (b *Bonder) getBalance(address []byte) (BondBalance, error) {
return balance, nil
}

func (b *Bonder) putBalance(address []byte, balance BondBalance) error {
func putBalance(ctx context.Context, mutable state.Mutable, address []byte, balance BondBalance) error {
p := &wrappers.Packer{Bytes: make([]byte, bondAllocSize)}
if err := codec.LinearCodec.MarshalInto(balance, p); err != nil {
return fmt.Errorf("failed to marshal bond balance: %w", err)
}

if err := b.db.Put(address, p.Bytes); err != nil {
if err := mutable.Insert(ctx, address, p.Bytes); err != nil {
return fmt.Errorf("failed to update bond balance: %w", err)
}

Expand Down
70 changes: 41 additions & 29 deletions chain/bond_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"context"
"testing"

"github.com/ava-labs/avalanchego/database/memdb"
"github.com/ava-labs/hypersdk/state"
"github.com/stretchr/testify/require"

"github.com/ava-labs/hypersdk/codec"
Expand All @@ -30,26 +30,27 @@ func TestSetMaxBalance(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := require.New(t)
b := NewBonder(memdb.New())
b := Bonder{}

mutable := &state.SimpleMutable{}
address := codec.Address{1, 2, 3}
r.NoError(b.SetMaxBalance(address, tt.maxBalance))
r.NoError(b.SetMaxBalance(context.Background(), mutable, address, tt.maxBalance))

for i := 0; i < int(tt.maxBalance); i++ {
ok, err := b.Bond(&Transaction{
Auth: TestAuth{
SponsorF: address,
},
})
ok, err := b.Bond(
context.Background(),
mutable,
&Transaction{Auth: TestAuth{SponsorF: address}},
)
r.NoError(err)
r.True(ok)
}

ok, err := b.Bond(&Transaction{
Auth: TestAuth{
SponsorF: address,
},
})
ok, err := b.Bond(
context.Background(),
mutable,
&Transaction{Auth: TestAuth{SponsorF: address}},
)
r.NoError(err)
r.False(ok)
})
Expand Down Expand Up @@ -139,24 +140,34 @@ func TestBond(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := require.New(t)
b := NewBonder(memdb.New())
b := Bonder{}

mutable := &state.SimpleMutable{}
address := codec.Address{1, 2, 3}
r.NoError(b.SetMaxBalance(address, tt.max))
r.NoError(b.SetMaxBalance(
context.Background(),
mutable,
address,
tt.max,
))

for _, wantOk := range tt.wantBond {
ok, err := b.Bond(&Transaction{
Auth: TestAuth{
SponsorF: address,
},
})
ok, err := b.Bond(
context.Background(),
mutable,
&Transaction{Auth: TestAuth{SponsorF: address}},
)
r.NoError(err)
r.Equal(wantOk, ok)
}

for _, wantErr := range tt.wantUnbond {
r.ErrorIs(
b.Unbond(&Transaction{Auth: TestAuth{SponsorF: address}}),
b.Unbond(
context.Background(),
mutable,
&Transaction{Auth: TestAuth{SponsorF: address}},
),
wantErr,
)
}
Expand All @@ -166,10 +177,11 @@ func TestBond(t *testing.T) {

func TestSetMaxBalanceDuringBond(t *testing.T) {
r := require.New(t)
b := NewBonder(memdb.New())
b := Bonder{}

mutable := &state.SimpleMutable{}
address := codec.Address{1, 2, 3}
r.NoError(b.SetMaxBalance(address, 3))
r.NoError(b.SetMaxBalance(context.Background(), mutable, address, 3))

tx1 := &Transaction{
Auth: TestAuth{
Expand All @@ -189,22 +201,22 @@ func TestSetMaxBalanceDuringBond(t *testing.T) {
},
}

ok, err := b.Bond(tx1)
ok, err := b.Bond(context.Background(), mutable, tx1)
r.NoError(err)
r.True(ok)

ok, err = b.Bond(tx2)
ok, err = b.Bond(context.Background(), mutable, tx2)
r.NoError(err)
r.True(ok)

r.NoError(b.SetMaxBalance(address, 0))
r.NoError(b.SetMaxBalance(context.Background(), mutable, address, 0))

ok, err = b.Bond(tx3)
ok, err = b.Bond(context.Background(), mutable, tx3)
r.NoError(err)
r.False(ok)

r.NoError(b.Unbond(tx1))
r.NoError(b.Unbond(tx2))
r.NoError(b.Unbond(context.Background(), mutable, tx1))
r.NoError(b.Unbond(context.Background(), mutable, tx2))
}

type TestAuth struct {
Expand Down
15 changes: 8 additions & 7 deletions x/fdsmr/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/ava-labs/hypersdk/codec"
"github.com/ava-labs/hypersdk/internal/eheap"
"github.com/ava-labs/hypersdk/state"
"github.com/ava-labs/hypersdk/x/dsmr"
)

Expand All @@ -19,11 +20,11 @@ type DSMR[T dsmr.Tx] interface {
type Bonder[T dsmr.Tx] interface {
// Bond returns if a transaction can be built into a chunk by this node.
// If this returns true, Unbond is guaranteed to be called.
Bond(tx T) (bool, error)
Bond(ctx context.Context, mutable state.Mutable, tx T) (bool, error)
// Unbond is called when a tx from an account either expires or is accepted.
// If Unbond is called, Bond is guaranteed to have been called previously on
// tx.
Unbond(tx T) error
Unbond(ctx context.Context, mutable state.Mutable, tx T) error
}

// New returns a fortified instance of DSMR
Expand All @@ -42,10 +43,10 @@ type Node[T DSMR[U], U dsmr.Tx] struct {
pending *eheap.ExpiryHeap[U]
}

func (n *Node[T, U]) BuildChunk(ctx context.Context, txs []U, expiry int64, beneficiary codec.Address) error {
func (n *Node[T, U]) BuildChunk(ctx context.Context, mutable state.Mutable, txs []U, expiry int64, beneficiary codec.Address) error {
bonded := make([]U, 0, len(txs))
for _, tx := range txs {
ok, err := n.bonder.Bond(tx)
ok, err := n.bonder.Bond(ctx, mutable, tx)
if err != nil {
return err
}
Expand All @@ -60,7 +61,7 @@ func (n *Node[T, U]) BuildChunk(ctx context.Context, txs []U, expiry int64, bene
return n.DSMR.BuildChunk(ctx, bonded, expiry, beneficiary)
}

func (n *Node[T, U]) Accept(ctx context.Context, block dsmr.Block) (dsmr.ExecutedBlock[U], error) {
func (n *Node[T, U]) Accept(ctx context.Context, mutable state.Mutable, block dsmr.Block) (dsmr.ExecutedBlock[U], error) {
executedBlock, err := n.DSMR.Accept(ctx, block)
if err != nil {
return dsmr.ExecutedBlock[U]{}, err
Expand All @@ -69,7 +70,7 @@ func (n *Node[T, U]) Accept(ctx context.Context, block dsmr.Block) (dsmr.Execute
// Un-bond any txs that expired at this block
expired := n.pending.SetMin(block.Timestamp)
for _, tx := range expired {
if err := n.bonder.Unbond(tx); err != nil {
if err := n.bonder.Unbond(ctx, mutable, tx); err != nil {
return dsmr.ExecutedBlock[U]{}, err
}
}
Expand All @@ -81,7 +82,7 @@ func (n *Node[T, U]) Accept(ctx context.Context, block dsmr.Block) (dsmr.Execute
continue
}

if err := n.bonder.Unbond(tx); err != nil {
if err := n.bonder.Unbond(ctx, mutable, tx); err != nil {
return dsmr.ExecutedBlock[U]{}, err
}
}
Expand Down
Loading

0 comments on commit ff62e92

Please sign in to comment.