diff --git a/gonsensus.go b/gonsensus.go index f1fb977..19dcf81 100644 --- a/gonsensus.go +++ b/gonsensus.go @@ -61,6 +61,8 @@ type Manager struct { callbackMu sync.RWMutex onElected func(context.Context) error onDemoted func(context.Context) + + lease *Lease } func NewManager(client S3Client, bucket string, cfg Config) (*Manager, error) { @@ -102,6 +104,7 @@ func NewManager(client S3Client, bucket string, cfg Config) (*Manager, error) { lockKey: lockPrefix + "leader", ttl: cfg.TTL, pollInterval: cfg.PollInterval, + lease: NewLease(), }, nil } @@ -139,13 +142,14 @@ func (m *Manager) acquireLock(ctx context.Context) error { // Lock doesn't exist or is expired, try to acquire it newTerm := m.incrementTerm() // Increment term for new leadership attempt + newVersion := fmt.Sprintf("%d-%s-%d", now.UnixNano(), m.nodeID, newTerm) lockInfo := LockInfo{ Node: m.nodeID, Timestamp: now, Expiry: now.Add(m.ttl), Term: newTerm, - Version: fmt.Sprintf("%d-%s-%d", now.UnixNano(), m.nodeID, newTerm), + Version: newVersion, } lockData, err := json.Marshal(lockInfo) @@ -218,6 +222,9 @@ func (m *Manager) acquireLock(ctx context.Context) error { return fmt.Errorf("failed to acquire lock: %w", err) } + // Update lease information + m.lease.UpdateLease(&lockInfo) + // Clean up our attempt _, _ = m.s3Client.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(m.bucket), @@ -249,20 +256,36 @@ func (m *Manager) renewLock(ctx context.Context) error { return fmt.Errorf("failed to decode lock info: %w", err) } - // Verify we still own the lock - if currentLock.Node != m.nodeID || currentLock.Term != m.getCurrentTerm() { - return ErrLockModified + // Get current lease info + currentLease := m.lease.GetLeaseInfo() + + // If we have a current lease, verify everything matches + if currentLease != nil { + if currentLock.Node != m.nodeID || + currentLock.Term != currentLease.Term || + currentLock.Version != currentLease.Version { + return ErrLockModified + } + } else { + // If we don't have a lease but the lock exists and belongs to us, + // adopt it (this handles the initial renewal case) + if currentLock.Node == m.nodeID && currentLock.Term == m.getCurrentTerm() { + m.lease.UpdateLease(¤tLock) + } else { + return ErrLockModified + } } // Create new lock info with updated timestamp and version now := time.Now() - curTerm := m.getCurrentTerm() + currentTerm := m.getCurrentTerm() + newVersion := fmt.Sprintf("%d-%s-%d", now.UnixNano(), m.nodeID, currentTerm) newLock := LockInfo{ Node: m.nodeID, Timestamp: now, Expiry: now.Add(m.ttl), - Term: curTerm, - Version: fmt.Sprintf("%d-%s-%d", now.UnixNano(), m.nodeID, curTerm), + Term: currentTerm, + Version: newVersion, } lockData, err := json.Marshal(newLock) @@ -271,7 +294,7 @@ func (m *Manager) renewLock(ctx context.Context) error { } // Create a new key for the update - updateKey := fmt.Sprintf("%s.%s", m.lockKey, newLock.Version) + updateKey := fmt.Sprintf("%s.%s", m.lockKey, newVersion) // Attempt to create new version input := &s3.PutObjectInput{ @@ -306,6 +329,9 @@ func (m *Manager) renewLock(ctx context.Context) error { return fmt.Errorf("failed to update main lock: %w", err) } + // Update lease information + m.lease.UpdateLease(&newLock) + // Clean up temporary key _, _ = m.s3Client.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(m.bucket), @@ -365,6 +391,47 @@ func (m *Manager) GetLockInfo(ctx context.Context) (*LockInfo, error) { return &lockInfo, nil } +// Lease represents the current leadership lease. +type Lease struct { + mu sync.RWMutex + info *LockInfo + version atomic.Value // stores string +} + +// NewLease creates a new lease instance. +func NewLease() *Lease { + l := &Lease{} + l.version.Store("") + + return l +} + +// UpdateLease updates the lease information atomically. +func (l *Lease) UpdateLease(info *LockInfo) { + l.mu.Lock() + defer l.mu.Unlock() + l.info = info + l.version.Store(info.Version) +} + +// GetCurrentVersion returns the current lease version. +func (l *Lease) GetCurrentVersion() string { + s, ok := l.version.Load().(string) + if !ok { + panic("forcetypeassert") + } + + return s +} + +// GetLeaseInfo returns the current lease information. +func (l *Lease) GetLeaseInfo() *LockInfo { + l.mu.RLock() + defer l.mu.RUnlock() + + return l.info +} + type LockInfo struct { Node string `json:"node"` Timestamp time.Time `json:"timestamp"` diff --git a/gonsensus_test.go b/gonsensus_test.go index f7b6929..2183508 100644 --- a/gonsensus_test.go +++ b/gonsensus_test.go @@ -262,6 +262,10 @@ func TestRenewLock(t *testing.T) { Term: mgr.getCurrentTerm(), Version: "1", } + + // Initialize the lease with the current lock info + mgr.lease.UpdateLease(&lock) + data, err := json.Marshal(lock) if err != nil { log.Panic("mock setup fail") @@ -280,14 +284,26 @@ func TestRenewLock(t *testing.T) { { name: "Lock modified by other node", setupMock: func(mockClient *MockS3Client, mgr *Manager) { - lock := LockInfo{ + originalLock := LockInfo{ + Node: "other-node", + Timestamp: time.Now(), + Expiry: time.Now().Add(30 * time.Second), + Term: mgr.incrementTerm(), + Version: "1", + } + + mgr.lease.UpdateLease(&originalLock) + + // Then simulate modification by another node + modifiedLock := LockInfo{ Node: "other-node", Timestamp: time.Now(), Expiry: time.Now().Add(30 * time.Second), Term: mgr.incrementTerm(), Version: "2", } - data, err := json.Marshal(lock) + + data, err := json.Marshal(modifiedLock) if err != nil { log.Panic("mock setup fail") }