uquic/server_test.go
Marten Seemann 9db23eceed
reverse the order of the SupportedVersions slice
For the client, the first entry of the Versions option in the quic.Config
is the preferred version. If the option is not set, it should default to the
highest supported version.
2017-05-05 18:06:14 +08:00

422 lines
14 KiB
Go

package quic
import (
"bytes"
"crypto/tls"
"errors"
"net"
"time"
"github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// mockSession is a test double for a session. It counts the packets handed to
// it and records whether, and with which error, it was closed.
type mockSession struct {
// connectionID is the connection ID the mock was created for (set by newMockSession).
connectionID protocol.ConnectionID
// packetCount is incremented by every handlePacket call.
packetCount int
// closed is set to true by Close.
closed bool
// closeReason is the error passed to Close; run returns it.
closeReason error
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
}
// handlePacket only records that a packet was delivered; the packet itself is ignored.
func (s *mockSession) handlePacket(_ *receivedPacket) {
	s.packetCount += 1
}
// run blocks until a value can be received from stopRunLoop and then returns
// the error most recently stored by Close.
func (s *mockSession) run() error {
	stop := s.stopRunLoop
	<-stop
	return s.closeReason
}
// Close marks the session as closed and remembers e as the close reason.
// It always returns nil.
func (s *mockSession) Close(e error) error {
	s.closed, s.closeReason = true, e
	return nil
}
// AcceptStream is not supported by the mock and always panics.
func (*mockSession) AcceptStream() (Stream, error) {
	panic("not implemented")
}
// OpenStream returns a new stream with the fixed stream ID 1337 and a nil error.
func (*mockSession) OpenStream() (Stream, error) {
	str := &stream{streamID: 1337}
	return str, nil
}
// OpenStreamSync is not supported by the mock and always panics.
func (*mockSession) OpenStreamSync() (Stream, error) {
	panic("not implemented")
}
// LocalAddr is not supported by the mock and always panics.
func (*mockSession) LocalAddr() net.Addr {
	panic("not implemented")
}
// RemoteAddr is not supported by the mock and always panics.
func (*mockSession) RemoteAddr() net.Addr {
	panic("not implemented")
}
// Compile-time check that *mockSession satisfies the Session interface.
var _ Session = (*mockSession)(nil)
// newMockSession is a session constructor for tests. It ignores every argument
// except the connection ID and returns a mockSession with a ready-to-use
// stopRunLoop channel. The returned error is always nil.
func newMockSession(_ connection, _ protocol.VersionNumber, connectionID protocol.ConnectionID, _ *handshake.ServerConfig, _ cryptoChangeCallback, _ []protocol.VersionNumber) (packetHandler, error) {
	sess := &mockSession{connectionID: connectionID}
	sess.stopRunLoop = make(chan struct{})
	return sess, nil
}
var _ = Describe("Server", func() {
// Fixtures shared by the specs in this Describe.
var (
// conn is the mock packet connection handed to the server under test.
conn *mockPacketConn
// config is the configuration handed to the server under test.
config *Config
// udpAddr is the remote address used for packets fed to the server.
udpAddr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
)
// Every spec starts from a fresh mock connection and a config that offers
// all supported versions.
BeforeEach(func() {
	conn = new(mockPacketConn)
	config = &Config{
		Versions:  protocol.SupportedVersions,
		TLSConfig: new(tls.Config),
	}
})
Context("with mock session", func() {
// Fixtures for the specs that run the server with mock sessions.
var (
// serv is the server under test; it is rebuilt before each spec.
serv *server
firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID)
// connID is the connection ID encoded in firstPacket.
connID = protocol.ConnectionID(0x4cfa9f9b668619f6)
)
BeforeEach(func() {
serv = &server{
sessions: make(map[protocol.ConnectionID]packetHandler),
newSession: newMockSession,
conn: conn,
config: config,
}
b := &bytes.Buffer{}
utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.SupportedVersions[0]))
firstPacket = []byte{0x09, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}
firstPacket = append(append(firstPacket, b.Bytes()...), 0x01)
})
It("returns the address", func() {
	// Addr() is expected to report the address set on the mock connection.
	localAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1234}
	conn.addr = localAddr
	Expect(serv.Addr().String()).To(Equal("192.168.13.37:1234"))
})
It("composes version negotiation packets", func() {
	// nine header bytes followed by the tag bytes "Q099" for version 99
	expected := []byte{0x01 | 0x08, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}
	expected = append(expected, "Q099"...)

	supported := []protocol.VersionNumber{99}
	Expect(composeVersionNegotiation(1, supported)).To(Equal(expected))
})
It("creates new sessions", func() {
	var (
		connStateStatus  ConnState
		connStateSession Session
	)
	// Closed by the callback once it has stored its arguments. The spec polls
	// for the callback, so it may run on another goroutine; closing a channel
	// (instead of flipping a shared bool) makes the writes above visible to
	// the spec goroutine without a data race.
	connStateCalled := make(chan struct{})
	config.ConnState = func(s Session, state ConnState) {
		connStateStatus = state
		connStateSession = s
		close(connStateCalled)
	}

	err := serv.handlePacket(nil, nil, firstPacket)
	Expect(err).ToNot(HaveOccurred())
	Expect(serv.sessions).To(HaveLen(1))
	sess := serv.sessions[connID].(*mockSession)
	Expect(sess.connectionID).To(Equal(connID))
	Expect(sess.packetCount).To(Equal(1))

	Eventually(connStateCalled).Should(BeClosed())
	Expect(connStateSession).To(Equal(sess))
	Expect(connStateStatus).To(Equal(ConnStateVersionNegotiated))
})
It("calls the ConnState callback when the connection is secure", func() {
	var (
		connStateStatus  ConnState
		connStateSession Session
	)
	// Closed by the callback once it has stored its arguments. The spec polls
	// for the callback, so it may run on another goroutine; the channel close
	// synchronizes the writes with the reads below.
	connStateCalled := make(chan struct{})
	config.ConnState = func(s Session, state ConnState) {
		connStateStatus = state
		connStateSession = s
		close(connStateCalled)
	}

	sess := &mockSession{}
	// false: the connection is not reported as forward-secure
	serv.cryptoChangeCallback(sess, false)

	Eventually(connStateCalled).Should(BeClosed())
	Expect(connStateSession).To(Equal(sess))
	Expect(connStateStatus).To(Equal(ConnStateSecure))
})
It("calls the ConnState callback when the connection is forward-secure", func() {
	var (
		connStateStatus  ConnState
		connStateSession Session
	)
	// Closed by the callback once it has stored its arguments. The spec polls
	// for the callback, so it may run on another goroutine; the channel close
	// synchronizes the writes with the reads below.
	connStateCalled := make(chan struct{})
	config.ConnState = func(s Session, state ConnState) {
		connStateStatus = state
		connStateSession = s
		close(connStateCalled)
	}

	sess := &mockSession{}
	// true: the connection is reported as forward-secure
	serv.cryptoChangeCallback(sess, true)

	Eventually(connStateCalled).Should(BeClosed())
	Expect(connStateStatus).To(Equal(ConnStateForwardSecure))
	Expect(connStateSession).To(Equal(sess))
})
It("assigns packets to existing sessions", func() {
	// a second packet carrying the same connection ID as firstPacket
	followUp := []byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01}

	Expect(serv.handlePacket(nil, nil, firstPacket)).NotTo(HaveOccurred())
	Expect(serv.handlePacket(nil, nil, followUp)).NotTo(HaveOccurred())

	// both packets must have ended up in one and the same session
	Expect(serv.sessions).To(HaveLen(1))
	sess := serv.sessions[connID].(*mockSession)
	Expect(sess.connectionID).To(Equal(connID))
	Expect(sess.packetCount).To(Equal(2))
})
It("closes and deletes sessions", func() {
	serv.deleteClosedSessionsAfter = time.Second // make sure that the nil value for the closed session doesn't get deleted in this test
	nullAEAD := crypto.NewNullAEAD(protocol.PerspectiveServer, protocol.VersionWhatever)
	err := serv.handlePacket(nil, nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
	Expect(err).ToNot(HaveOccurred())
	Expect(serv.sessions).To(HaveLen(1))
	Expect(serv.sessions[connID]).ToNot(BeNil())
	// make session.run() return
	serv.sessions[connID].(*mockSession).stopRunLoop <- struct{}{}
	// The server should now have closed the session, leaving a nil value in the sessions map.
	// From here on the server may modify the map concurrently, so it is only
	// read while holding sessionsMutex (as the neighbouring spec does).
	Consistently(func() int {
		serv.sessionsMutex.Lock()
		defer serv.sessionsMutex.Unlock()
		return len(serv.sessions)
	}).Should(Equal(1))
	serv.sessionsMutex.Lock()
	sess := serv.sessions[connID]
	serv.sessionsMutex.Unlock()
	Expect(sess).To(BeNil())
})
It("deletes nil session entries after a wait time", func() {
	// short delay so the entry disappears within Eventually's polling window
	serv.deleteClosedSessionsAfter = 25 * time.Millisecond

	aead := crypto.NewNullAEAD(protocol.PerspectiveServer, protocol.VersionWhatever)
	packet := append(firstPacket, aead.Seal(nil, nil, 0, firstPacket)...)
	Expect(serv.handlePacket(nil, nil, packet)).NotTo(HaveOccurred())
	Expect(serv.sessions).To(HaveLen(1))
	Expect(serv.sessions).To(HaveKey(connID))

	// make session.run() return
	serv.sessions[connID].(*mockSession).stopRunLoop <- struct{}{}

	// reports whether the map still has an entry for connID
	hasEntry := func() bool {
		serv.sessionsMutex.Lock()
		defer serv.sessionsMutex.Unlock()
		_, present := serv.sessions[connID]
		return present
	}
	Eventually(hasEntry).Should(BeFalse())
})
It("closes sessions and the connection when Close is called", func() {
	// register a session directly in the server's session map
	sess := new(mockSession)
	serv.sessions[1] = sess

	Expect(serv.Close()).NotTo(HaveOccurred())
	Expect(sess.closed).To(BeTrue())
	Expect(conn.closed).To(BeTrue())
})
It("ignores packets for closed sessions", func() {
	// a nil entry stands for a session that was already closed
	serv.sessions[connID] = nil

	packet := []byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01}
	Expect(serv.handlePacket(nil, nil, packet)).NotTo(HaveOccurred())

	// the map must be unchanged: still exactly one entry, and it is still nil
	Expect(serv.sessions).To(HaveLen(1))
	Expect(serv.sessions[connID]).To(BeNil())
})
It("closes properly", func() {
ln, err := ListenAddr("127.0.0.1:0", config)
Expect(err).ToNot(HaveOccurred())
var returned bool
go func() {
defer GinkgoRecover()
err := ln.Serve()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
returned = true
}()
ln.Close()
Eventually(func() bool { return returned }).Should(BeTrue())
})
It("errors when encountering a connection error", func() {
	// the error set on the mock connection must be returned by Serve
	readErr := errors.New("connection error")
	conn.readErr = readErr
	Expect(serv.Serve()).To(MatchError(readErr))
})
It("closes all sessions when encountering a connection error", func() {
	Expect(serv.handlePacket(nil, nil, firstPacket)).NotTo(HaveOccurred())
	Expect(serv.sessions).To(HaveKey(connID))
	sessBefore := serv.sessions[connID].(*mockSession)
	Expect(sessBefore.closed).To(BeFalse())

	conn.readErr = errors.New("connection error")
	// the returned error is irrelevant here; only the effect on the session is checked
	_ = serv.Serve()

	// look the session up again, exactly as stored in the map after Serve returned
	Expect(serv.sessions[connID].(*mockSession).closed).To(BeTrue())
})
It("ignores delayed packets with mismatching versions", func() {
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1))
b := &bytes.Buffer{}
// add an unsupported version
utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.SupportedVersions[0]+1))
data := []byte{0x09, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}
data = append(append(data, b.Bytes()...), 0x01)
err = serv.handlePacket(nil, nil, data)
Expect(err).ToNot(HaveOccurred())
// if we didn't ignore the packet, the server would try to send a version negotation packet, which would make the test panic because it doesn't have a udpConn
Expect(conn.dataWritten.Bytes()).To(BeEmpty())
// make sure the packet was *not* passed to session.handlePacket()
Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1))
})
It("errors on invalid public header", func() {
	// a nil packet is passed as the invalid input
	err := serv.handlePacket(nil, nil, nil)
	quicErr := err.(*qerr.QuicError)
	Expect(quicErr.ErrorCode).To(Equal(qerr.InvalidPacketHeader))
})
It("ignores public resets for unknown connections", func() {
	// connection ID 999 has no session on this server
	publicReset := writePublicReset(999, 1, 1337)
	Expect(serv.handlePacket(nil, nil, publicReset)).NotTo(HaveOccurred())
	Expect(serv.sessions).To(BeEmpty())
})
It("ignores public resets for known connections", func() {
	// establish a session first; a failure here must not go unnoticed
	err := serv.handlePacket(nil, nil, firstPacket)
	Expect(err).ToNot(HaveOccurred())
	Expect(serv.sessions).To(HaveLen(1))
	Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1))

	// the public reset must neither remove the session nor be passed on to it
	err = serv.handlePacket(nil, nil, writePublicReset(connID, 1, 1337))
	Expect(err).ToNot(HaveOccurred())
	Expect(serv.sessions).To(HaveLen(1))
	Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1))
})
It("ignores invalid public resets for known connections", func() {
	// establish a session first; a failure here must not go unnoticed
	err := serv.handlePacket(nil, nil, firstPacket)
	Expect(err).ToNot(HaveOccurred())
	Expect(serv.sessions).To(HaveLen(1))
	Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1))

	// cut off the last two bytes to make the public reset invalid
	data := writePublicReset(connID, 1, 1337)
	err = serv.handlePacket(nil, nil, data[:len(data)-2])
	Expect(err).ToNot(HaveOccurred())
	Expect(serv.sessions).To(HaveLen(1))
	Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1))
})
It("doesn't respond with a version negotiation packet if the first packet is too small", func() {
	hdr := PublicHeader{
		VersionFlag:     true,
		ConnectionID:    0x1337,
		PacketNumber:    1,
		PacketNumberLen: protocol.PacketNumberLen2,
	}
	packet := &bytes.Buffer{}
	hdr.Write(packet, 13 /* not a valid QUIC version */, protocol.PerspectiveClient)
	// zero padding that makes this packet 1 byte too small
	packet.Write(make([]byte, protocol.ClientHelloMinimumSize-1))

	err := serv.handlePacket(conn, udpAddr, packet.Bytes())
	Expect(err).To(MatchError("dropping small packet with unknown version"))
	// nothing may have been sent back
	Expect(conn.dataWritten.Len()).To(BeZero())
})
})
It("setups with the right values", func() {
var connStateCallback ConnStateCallback = func(_ Session, _ ConnState) {}
supportedVersions := []protocol.VersionNumber{1, 3, 5}
config := Config{
TLSConfig: &tls.Config{},
ConnState: connStateCallback,
Versions: supportedVersions,
}
ln, err := Listen(conn, &config)
Expect(err).ToNot(HaveOccurred())
server := ln.(*server)
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
Expect(server.sessions).ToNot(BeNil())
Expect(server.scfg).ToNot(BeNil())
Expect(server.config.ConnState).ToNot(BeNil())
Expect(server.config.Versions).To(Equal(supportedVersions))
})
It("fills in default values if options are not set in the Config", func() {
	// only TLSConfig is set; Versions is left at its zero value
	ln, err := Listen(conn, &Config{TLSConfig: &tls.Config{}})
	Expect(err).ToNot(HaveOccurred())
	Expect(ln.(*server).config.Versions).To(Equal(protocol.SupportedVersions))
})
It("listens on a given address", func() {
	addr := "127.0.0.1:13579"
	ln, err := ListenAddr(addr, config)
	Expect(err).ToNot(HaveOccurred())
	// Release the listener when the spec is done, so the fixed port does not
	// stay bound for the rest of the test run.
	defer ln.Close()
	serv := ln.(*server)
	Expect(serv.Addr().String()).To(Equal(addr))
})
It("errors if given an invalid address", func() {
	// the address has no port
	_, err := ListenAddr("127.0.0.1", config)
	Expect(err).To(BeAssignableToTypeOf(&net.AddrError{}))
})
// The description is distinct from the preceding spec so that a failure
// report identifies which of the two address checks broke.
It("errors if given an address it cannot listen on", func() {
	// well-formed address; the spec expects listening on it to fail with a *net.OpError
	addr := "1.1.1.1:1111"
	_, err := ListenAddr(addr, config)
	Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
})
It("setups and responds with version negotiation", func() {
config.Versions = []protocol.VersionNumber{99}
b := &bytes.Buffer{}
hdr := PublicHeader{
VersionFlag: true,
ConnectionID: 0x1337,
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen2,
}
hdr.Write(b, 13 /* not a valid QUIC version */, protocol.PerspectiveClient)
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
conn.dataToRead = b.Bytes()
conn.dataReadFrom = udpAddr
ln, err := Listen(conn, config)
Expect(err).ToNot(HaveOccurred())
var returned bool
go func() {
ln.Serve()
returned = true
}()
Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero())
Expect(conn.dataWrittenTo).To(Equal(udpAddr))
b = &bytes.Buffer{}
utils.WriteUint32(b, protocol.VersionNumberToTag(99))
expected := append(
[]byte{0x9, 0x37, 0x13, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
b.Bytes()...,
)
Expect(conn.dataWritten.Bytes()).To(Equal(expected))
Expect(returned).To(BeFalse())
})
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
conn.dataReadFrom = udpAddr
conn.dataToRead = []byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01}
ln, err := Listen(conn, config)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
err := ln.Serve()
Expect(err).ToNot(HaveOccurred())
}()
Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero())
Expect(conn.dataWrittenTo).To(Equal(udpAddr))
Expect(conn.dataWritten.Bytes()[0] & 0x02).ToNot(BeZero()) // check that the ResetFlag is set
Expect(ln.(*server).sessions).To(BeEmpty())
})
})