diff --git a/interface.go b/interface.go index f35b86dc..337d3043 100644 --- a/interface.go +++ b/interface.go @@ -27,6 +27,23 @@ type Token struct { SentTime time.Time } +// A ClientToken is a token received by the client. +// It can be used to skip address validation on future connection attempts. +type ClientToken struct { + data []byte +} + +type TokenStore interface { + // Pop searches for a ClientToken associated with the given key. + // Since tokens are not supposed to be reused, it must remove the token from the cache. + // It returns nil when no token is found. + Pop(key string) (token *ClientToken) + + // Put adds a token to the cache with the given key. It might get called + // multiple times in a connection. + Put(key string, token *ClientToken) +} + // An ErrorCode is an application-defined error code. // Valid values range between 0 and MAX_UINT62. type ErrorCode = protocol.ApplicationErrorCode diff --git a/token_store.go b/token_store.go new file mode 100644 index 00000000..9641dc5a --- /dev/null +++ b/token_store.go @@ -0,0 +1,117 @@ +package quic + +import ( + "container/list" + "sync" + + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type singleOriginTokenStore struct { + tokens []*ClientToken + len int + p int +} + +func newSingleOriginTokenStore(size int) *singleOriginTokenStore { + return &singleOriginTokenStore{tokens: make([]*ClientToken, size)} +} + +func (s *singleOriginTokenStore) Add(token *ClientToken) { + s.tokens[s.p] = token + s.p = s.index(s.p + 1) + s.len = utils.Min(s.len+1, len(s.tokens)) +} + +func (s *singleOriginTokenStore) Pop() *ClientToken { + s.p = s.index(s.p - 1) + token := s.tokens[s.p] + s.tokens[s.p] = nil + s.len = utils.Max(s.len-1, 0) + return token +} + +func (s *singleOriginTokenStore) Len() int { + return s.len +} + +func (s *singleOriginTokenStore) index(i int) int { + mod := len(s.tokens) + return (i + mod) % mod +} + +type lruTokenStoreEntry struct { + key string + cache *singleOriginTokenStore +} + +type lruTokenStore struct { + mutex sync.Mutex + + m map[string]*list.Element + q *list.List + capacity int + singleOriginSize int +} + +var _ TokenStore = &lruTokenStore{} + +// NewLRUTokenStore creates a new LRU cache for tokens received by the client. +// maxOrigins specifies how many origins this cache is saving tokens for. +// tokensPerOrigin specifies the maximum number of tokens per origin. +func NewLRUTokenStore(maxOrigins, tokensPerOrigin int) TokenStore { + return &lruTokenStore{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: maxOrigins, + singleOriginSize: tokensPerOrigin, + } +} + +func (s *lruTokenStore) Put(key string, token *ClientToken) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if el, ok := s.m[key]; ok { + entry := el.Value.(*lruTokenStoreEntry) + entry.cache.Add(token) + s.q.MoveToFront(el) + return + } + + if s.q.Len() < s.capacity { + entry := &lruTokenStoreEntry{ + key: key, + cache: newSingleOriginTokenStore(s.singleOriginSize), + } + entry.cache.Add(token) + s.m[key] = s.q.PushFront(entry) + return + } + + elem := s.q.Back() + entry := elem.Value.(*lruTokenStoreEntry) + delete(s.m, entry.key) + entry.key = key + entry.cache = newSingleOriginTokenStore(s.singleOriginSize) + entry.cache.Add(token) + s.q.MoveToFront(elem) + s.m[key] = elem +} + +func (s *lruTokenStore) Pop(key string) *ClientToken { + s.mutex.Lock() + defer s.mutex.Unlock() + + var token *ClientToken + if el, ok := s.m[key]; ok { + s.q.MoveToFront(el) + cache := el.Value.(*lruTokenStoreEntry).cache + token = cache.Pop() + if cache.Len() == 0 { + s.q.Remove(el) + delete(s.m, key) + } + } + return token +} diff --git a/token_store_test.go b/token_store_test.go new file mode 100644 index 00000000..01107821 --- /dev/null +++ b/token_store_test.go @@ -0,0 +1,108 @@ +package quic + +import ( + "fmt" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Token Cache", func() { + var s TokenStore + + BeforeEach(func() { + s = NewLRUTokenStore(3, 4) + }) + + mockToken := func(num int) *ClientToken { + return &ClientToken{data: []byte(fmt.Sprintf("%d", num))} + } + + Context("for a single origin", func() { + const origin = "localhost" + + It("adds and gets tokens", func() { + s.Put(origin, mockToken(1)) + s.Put(origin, mockToken(2)) + Expect(s.Pop(origin)).To(Equal(mockToken(2))) + Expect(s.Pop(origin)).To(Equal(mockToken(1))) + Expect(s.Pop(origin)).To(BeNil()) + }) + + It("overwrites old tokens", func() { + s.Put(origin, mockToken(1)) + s.Put(origin, mockToken(2)) + s.Put(origin, mockToken(3)) + s.Put(origin, mockToken(4)) + s.Put(origin, mockToken(5)) + Expect(s.Pop(origin)).To(Equal(mockToken(5))) + Expect(s.Pop(origin)).To(Equal(mockToken(4))) + Expect(s.Pop(origin)).To(Equal(mockToken(3))) + Expect(s.Pop(origin)).To(Equal(mockToken(2))) + Expect(s.Pop(origin)).To(BeNil()) + }) + + It("continues after getting a token", func() { + s.Put(origin, mockToken(1)) + s.Put(origin, mockToken(2)) + s.Put(origin, mockToken(3)) + Expect(s.Pop(origin)).To(Equal(mockToken(3))) + s.Put(origin, mockToken(4)) + s.Put(origin, mockToken(5)) + Expect(s.Pop(origin)).To(Equal(mockToken(5))) + Expect(s.Pop(origin)).To(Equal(mockToken(4))) + Expect(s.Pop(origin)).To(Equal(mockToken(2))) + Expect(s.Pop(origin)).To(Equal(mockToken(1))) + Expect(s.Pop(origin)).To(BeNil()) + }) + }) + + Context("for multiple origins", func() { + It("adds and gets tokens", func() { + s.Put("host1", mockToken(1)) + s.Put("host2", mockToken(2)) + Expect(s.Pop("host1")).To(Equal(mockToken(1))) + Expect(s.Pop("host1")).To(BeNil()) + Expect(s.Pop("host2")).To(Equal(mockToken(2))) + Expect(s.Pop("host2")).To(BeNil()) + }) + + It("evicts old entries", func() { + s.Put("host1", mockToken(1)) + s.Put("host2", mockToken(2)) + s.Put("host3", mockToken(3)) + s.Put("host4", mockToken(4)) + Expect(s.Pop("host1")).To(BeNil()) + Expect(s.Pop("host2")).To(Equal(mockToken(2))) + Expect(s.Pop("host3")).To(Equal(mockToken(3))) + Expect(s.Pop("host4")).To(Equal(mockToken(4))) + }) + + It("moves old entries to the front, when new tokens are added", func() { + s.Put("host1", mockToken(1)) + s.Put("host2", mockToken(2)) + s.Put("host3", mockToken(3)) + s.Put("host1", mockToken(11)) + // make sure one is evicted + s.Put("host4", mockToken(4)) + Expect(s.Pop("host2")).To(BeNil()) + Expect(s.Pop("host1")).To(Equal(mockToken(11))) + Expect(s.Pop("host1")).To(Equal(mockToken(1))) + Expect(s.Pop("host3")).To(Equal(mockToken(3))) + Expect(s.Pop("host4")).To(Equal(mockToken(4))) + }) + + It("deletes hosts that are empty", func() { + s.Put("host1", mockToken(1)) + s.Put("host2", mockToken(2)) + s.Put("host3", mockToken(3)) + Expect(s.Pop("host2")).To(Equal(mockToken(2))) + Expect(s.Pop("host2")).To(BeNil()) + // host2 is now empty and should have been deleted, making space for host4 + s.Put("host4", mockToken(4)) + Expect(s.Pop("host1")).To(Equal(mockToken(1))) + Expect(s.Pop("host3")).To(Equal(mockToken(3))) + Expect(s.Pop("host4")).To(Equal(mockToken(4))) + }) + }) +})