uquic/integrationtests/self/resumption_test.go
2022-10-11 16:38:44 +04:00

125 lines
3.4 KiB
Go

package self_test
import (
"context"
"crypto/tls"
"fmt"
"net"
"sync"
"github.com/lucas-clemente/quic-go"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// clientSessionCache is an in-memory TLS client session store used by the
// resumption tests. Besides storing sessions, its Get and Put methods send
// the session key they were called with on the gets / puts channels, which
// lets a test observe when the TLS stack touches the cache.
type clientSessionCache struct {
	// gets and puts receive the session key of each Get / Put call.
	gets, puts chan<- string

	// mutex guards cache.
	mutex sync.Mutex
	cache map[string]*tls.ClientSessionState
}
// newClientSessionCache returns an empty clientSessionCache whose Get and Put
// calls are reported on the gets and puts channels respectively.
//
// Both methods perform a plain (blocking) channel send, so the caller must
// either buffer the channels or keep draining them.
func newClientSessionCache(gets, puts chan<- string) *clientSessionCache {
	c := new(clientSessionCache)
	c.cache = map[string]*tls.ClientSessionState{}
	c.gets, c.puts = gets, puts
	return c
}
// Compile-time check that *clientSessionCache implements tls.ClientSessionCache.
var _ tls.ClientSessionCache = (*clientSessionCache)(nil)
// Get reports sessionKey on the gets channel and then looks the key up in the
// cache. The boolean result is the map's presence flag for that key.
//
// The channel send happens before the mutex is taken, so a send that blocks
// does not hold the lock.
func (c *clientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
	c.gets <- sessionKey

	c.mutex.Lock()
	defer c.mutex.Unlock()
	state, found := c.cache[sessionKey]
	return state, found
}
// Put reports sessionKey on the puts channel and then stores cs under that
// key. A nil cs removes the entry instead of storing a nil value, as the
// tls.ClientSessionCache contract asks ("If called with a nil
// *ClientSessionState, it should remove the cache entry").
//
// The channel send happens before the mutex is taken, so a send that blocks
// does not hold the lock.
func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
	c.puts <- sessionKey

	c.mutex.Lock()
	defer c.mutex.Unlock()
	if cs == nil {
		// Drop the key so a later lookup reports "not found" rather than
		// returning a present-but-nil session.
		delete(c.cache, sessionKey)
		return
	}
	c.cache[sessionKey] = cs
}
// Specs for TLS session resumption. Each spec starts a listener on an
// ephemeral localhost port, dials it twice with the same client TLS config
// (which carries an instrumented session cache), and checks the DidResume
// flag reported by both ends of each connection.
var _ = Describe("TLS session resumption", func() {
	It("uses session resumption", func() {
		server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
		Expect(err).ToNot(HaveOccurred())
		defer server.Close()

		// Cache accesses are reported on these buffered channels.
		getCh := make(chan string, 100)
		putCh := make(chan string, 100)
		clientConf := getTLSClientConfig()
		clientConf.ClientSessionCache = newClientSessionCache(getCh, putCh)
		addr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)

		// First connection: a session must eventually be Put into the cache,
		// and neither side may report a resumed handshake.
		conn, err := quic.DialAddr(addr, clientConf, nil)
		Expect(err).ToNot(HaveOccurred())
		var sessionKey string
		Eventually(putCh).Should(Receive(&sessionKey))
		Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())

		serverConn, err := server.Accept(context.Background())
		Expect(err).ToNot(HaveOccurred())
		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())

		// Second connection: the cache must have been queried with the key
		// that was stored above, and both sides must report resumption.
		conn, err = quic.DialAddr(addr, clientConf, nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(getCh).To(Receive(Equal(sessionKey)))
		Expect(conn.ConnectionState().TLS.DidResume).To(BeTrue())

		serverConn, err = server.Accept(context.Background())
		Expect(err).ToNot(HaveOccurred())
		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
	})

	It("doesn't use session resumption, if the config disables it", func() {
		serverConf := getTLSConfig()
		serverConf.SessionTicketsDisabled = true
		server, err := quic.ListenAddr("localhost:0", serverConf, nil)
		Expect(err).ToNot(HaveOccurred())
		defer server.Close()

		getCh := make(chan string, 100)
		putCh := make(chan string, 100)
		clientConf := getTLSClientConfig()
		clientConf.ClientSessionCache = newClientSessionCache(getCh, putCh)
		addr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)

		// First connection: with tickets disabled on the server, no Put may
		// be observed on the cache.
		conn, err := quic.DialAddr(addr, clientConf, nil)
		Expect(err).ToNot(HaveOccurred())
		Consistently(putCh).ShouldNot(Receive())
		Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())

		serverConn, err := server.Accept(context.Background())
		Expect(err).ToNot(HaveOccurred())
		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())

		// Second connection: still a non-resumed handshake on both sides.
		conn, err = quic.DialAddr(addr, clientConf, nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())

		serverConn, err = server.Accept(context.Background())
		Expect(err).ToNot(HaveOccurred())
		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
	})
})