diff --git a/auth/simple_token.go b/auth/simple_token.go index 5b608af92..ff48c5140 100644 --- a/auth/simple_token.go +++ b/auth/simple_token.go @@ -32,27 +32,26 @@ import ( const ( letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" defaultSimpleTokenLength = 16 +) + +// var for testing purposes +var ( simpleTokenTTL = 5 * time.Minute simpleTokenTTLResolution = 1 * time.Second ) type simpleTokenTTLKeeper struct { - tokens map[string]time.Time - addSimpleTokenCh chan string - resetSimpleTokenCh chan string - deleteSimpleTokenCh chan string - stopCh chan chan struct{} - deleteTokenFunc func(string) + tokensMu sync.Mutex + tokens map[string]time.Time + stopCh chan chan struct{} + deleteTokenFunc func(string) } func NewSimpleTokenTTLKeeper(deletefunc func(string)) *simpleTokenTTLKeeper { stk := &simpleTokenTTLKeeper{ - tokens: make(map[string]time.Time), - addSimpleTokenCh: make(chan string, 1), - resetSimpleTokenCh: make(chan string, 1), - deleteSimpleTokenCh: make(chan string, 1), - stopCh: make(chan chan struct{}), - deleteTokenFunc: deletefunc, + tokens: make(map[string]time.Time), + stopCh: make(chan chan struct{}), + deleteTokenFunc: deletefunc, } go stk.run() return stk @@ -66,37 +65,34 @@ func (tm *simpleTokenTTLKeeper) stop() { } func (tm *simpleTokenTTLKeeper) addSimpleToken(token string) { - tm.addSimpleTokenCh <- token + tm.tokens[token] = time.Now().Add(simpleTokenTTL) } func (tm *simpleTokenTTLKeeper) resetSimpleToken(token string) { - tm.resetSimpleTokenCh <- token + if _, ok := tm.tokens[token]; ok { + tm.tokens[token] = time.Now().Add(simpleTokenTTL) + } } func (tm *simpleTokenTTLKeeper) deleteSimpleToken(token string) { - tm.deleteSimpleTokenCh <- token + delete(tm.tokens, token) } + func (tm *simpleTokenTTLKeeper) run() { tokenTicker := time.NewTicker(simpleTokenTTLResolution) defer tokenTicker.Stop() for { select { - case t := <-tm.addSimpleTokenCh: - tm.tokens[t] = time.Now().Add(simpleTokenTTL) - case t := <-tm.resetSimpleTokenCh: - if _, ok := tm.tokens[t]; ok { - tm.tokens[t] = time.Now().Add(simpleTokenTTL) - } - case t := <-tm.deleteSimpleTokenCh: - delete(tm.tokens, t) case <-tokenTicker.C: nowtime := time.Now() + tm.tokensMu.Lock() for t, tokenendtime := range tm.tokens { if nowtime.After(tokenendtime) { tm.deleteTokenFunc(t) delete(tm.tokens, t) } } + tm.tokensMu.Unlock() case waitCh := <-tm.stopCh: tm.tokens = make(map[string]time.Time) waitCh <- struct{}{} @@ -108,7 +104,7 @@ func (tm *simpleTokenTTLKeeper) run() { type tokenSimple struct { indexWaiter func(uint64) <-chan struct{} simpleTokenKeeper *simpleTokenTTLKeeper - simpleTokensMu sync.RWMutex + simpleTokensMu sync.Mutex simpleTokens map[string]string // token -> username } @@ -128,6 +124,7 @@ func (t *tokenSimple) genTokenPrefix() (string, error) { } func (t *tokenSimple) assignSimpleTokenToUser(username, token string) { + t.simpleTokenKeeper.tokensMu.Lock() t.simpleTokensMu.Lock() _, ok := t.simpleTokens[token] @@ -138,18 +135,23 @@ func (t *tokenSimple) assignSimpleTokenToUser(username, token string) { t.simpleTokens[token] = username t.simpleTokenKeeper.addSimpleToken(token) t.simpleTokensMu.Unlock() + t.simpleTokenKeeper.tokensMu.Unlock() } func (t *tokenSimple) invalidateUser(username string) { + if t.simpleTokenKeeper == nil { + return + } + t.simpleTokenKeeper.tokensMu.Lock() t.simpleTokensMu.Lock() - defer t.simpleTokensMu.Unlock() - for token, name := range t.simpleTokens { if strings.Compare(name, username) == 0 { delete(t.simpleTokens, token) t.simpleTokenKeeper.deleteSimpleToken(token) } } + t.simpleTokensMu.Unlock() + t.simpleTokenKeeper.tokensMu.Unlock() } func newDeleterFunc(t *tokenSimple) func(string) { @@ -172,7 +174,6 @@ func (t *tokenSimple) disable() { t.simpleTokenKeeper.stop() t.simpleTokenKeeper = nil } - t.simpleTokensMu.Lock() t.simpleTokens = make(map[string]string) // invalidate all tokens t.simpleTokensMu.Unlock() @@ -182,14 +183,14 @@ func (t *tokenSimple) info(ctx context.Context, token string, revision uint64) ( if !t.isValidSimpleToken(ctx, token) { return nil, false } - - t.simpleTokensMu.RLock() - defer t.simpleTokensMu.RUnlock() + t.simpleTokenKeeper.tokensMu.Lock() + t.simpleTokensMu.Lock() username, ok := t.simpleTokens[token] if ok { t.simpleTokenKeeper.resetSimpleToken(token) } - + t.simpleTokensMu.Unlock() + t.simpleTokenKeeper.tokensMu.Unlock() return &AuthInfo{Username: username, Revision: revision}, ok }