diff --git a/lease/lessor.go b/lease/lessor.go index 4306518ec..3c9333268 100644 --- a/lease/lessor.go +++ b/lease/lessor.go @@ -17,6 +17,7 @@ package lease import ( "encoding/binary" "fmt" + "math" "sync" "time" @@ -74,14 +75,7 @@ func NewLessor(lessorID uint8, b backend.Backend, dr DeleteableRange) *lessor { dr: dr, idgen: idutil.NewGenerator(lessorID, time.Now()), } - - tx := l.b.BatchTx() - tx.Lock() - tx.UnsafeCreateBucket(leaseBucketName) - tx.Unlock() - l.b.ForceCommit() - - // TODO: recover from previous state in backend. + l.initAndRecover() return l } @@ -194,6 +188,32 @@ func (le *lessor) get(id LeaseID) *lease { return le.leaseMap[id] } +func (le *lessor) initAndRecover() { + tx := le.b.BatchTx() + tx.Lock() + defer tx.Unlock() + + tx.UnsafeCreateBucket(leaseBucketName) + _, vs := tx.UnsafeRange(leaseBucketName, int64ToBytes(0), int64ToBytes(math.MaxInt64), 0) + // TODO: copy vs and do decoding outside tx lock if lock contention becomes an issue. + for i := range vs { + var lpb leasepb.Lease + err := lpb.Unmarshal(vs[i]) + if err != nil { + panic("failed to unmarshal lease proto item") + } + id := LeaseID(lpb.ID) + le.leaseMap[id] = &lease{ + id: id, + ttl: lpb.TTL, + + // itemSet will be filled in when recover key-value pairs + expiry: minExpiry(time.Now(), time.Now().Add(time.Second*time.Duration(lpb.TTL))), + } + } + le.b.ForceCommit() +} + type lease struct { id LeaseID ttl int64 // time to live in seconds diff --git a/lease/lessor_test.go b/lease/lessor_test.go index eebd72923..9eae6b310 100644 --- a/lease/lessor_test.go +++ b/lease/lessor_test.go @@ -127,6 +127,30 @@ func TestLessorRenew(t *testing.T) { } } +// TestLessorRecover ensures Lessor recovers leases from +// persist backend. +func TestLessorRecover(t *testing.T) { + dir, be := NewTestBackend(t) + defer os.RemoveAll(dir) + defer be.Close() + + le := NewLessor(1, be, &fakeDeleteable{}) + l1 := le.Grant(10) + l2 := le.Grant(20) + + // Create a new lessor with the same backend + nle := NewLessor(1, be, &fakeDeleteable{}) + nl1 := nle.get(l1.id) + if nl1 == nil || nl1.ttl != l1.ttl { + t.Errorf("nl1 = %v, want nl1.TTL= %d", l1.ttl) + } + + nl2 := nle.get(l2.id) + if nl2 == nil || nl2.ttl != l2.ttl { + t.Errorf("nl2 = %v, want nl2.TTL= %d", l2.ttl) + } +} + type fakeDeleteable struct { deleted []string }