Skip session ID 0 in grumble/sessionpool.

This commit is contained in:
Mikkel Krautz 2011-05-14 00:09:51 +02:00
parent c5418d0464
commit cd726560d8
2 changed files with 21 additions and 20 deletions

View file

@ -16,7 +16,7 @@ type SessionPool struct {
mutex sync.Mutex mutex sync.Mutex
used map[uint32]bool used map[uint32]bool
unused []uint32 unused []uint32
next uint32 cur uint32
} }
// Create a new SessionPool container. // Create a new SessionPool container.
@ -34,7 +34,7 @@ func New() (pool *SessionPool) {
// the program will panic. // the program will panic.
// panic. // panic.
func (pool *SessionPool) EnableUseTracking() { func (pool *SessionPool) EnableUseTracking() {
if len(pool.unused) != 0 || pool.next != 0 { if len(pool.unused) != 0 || pool.cur != 0 {
panic("Attempt to enable use tracking on an existing SessionPool.") panic("Attempt to enable use tracking on an existing SessionPool.")
} }
pool.used = make(map[uint32]bool) pool.used = make(map[uint32]bool)
@ -61,17 +61,18 @@ func (pool *SessionPool) Get() (id uint32) {
return return
} }
// Check for session pool depletion. Note that this depletion // Check for depletion. If cur is MaxUint32,
// check makes MaxUint32 an invalid next value, and thus limits // there aren't any session IDs left, since the
// the session pool to 2**32-2 distinct sessions. // increment below would overflow us back to 0.
if pool.next == math.MaxUint32 { if pool.cur == math.MaxUint32 {
panic("SessionPool depleted") panic("SessionPool depleted")
} }
// Return the current 'next' value and increment it // Increment the next session id and return it.
// for next time we're here. // Note: By incrementing and *then* returning, we skip 0.
id = pool.next // This is deliberate, as 0 is an invalid session ID in Mumble.
pool.next += 1 pool.cur += 1
id = pool.cur
return return
} }

View file

@ -8,20 +8,20 @@ import (
func TestReclaim(t *testing.T) { func TestReclaim(t *testing.T) {
pool := New() pool := New()
id := pool.Get() id := pool.Get()
if id != 0 { if id != 1 {
t.Errorf("Got %v, expected 0 (first time)", id) t.Errorf("Got %v, expected 1 (first time)", id)
} }
pool.Reclaim(0) pool.Reclaim(1)
id = pool.Get()
if id != 0 {
t.Errorf("Got %v, expected 0 (second time)", id)
}
id = pool.Get() id = pool.Get()
if id != 1 { if id != 1 {
t.Errorf("Got %v, expected 1", id) t.Errorf("Got %v, expected 1 (second time)", id)
}
id = pool.Get()
if id != 2 {
t.Errorf("Got %v, expected 2", id)
} }
} }
@ -33,7 +33,7 @@ func TestDepletion(t *testing.T) {
} }
}() }()
pool := New() pool := New()
pool.next = math.MaxUint32 pool.cur = math.MaxUint32
pool.Get() pool.Get()
} }