also use the multiplexer for the server

This commit is contained in:
Marten Seemann 2018-07-20 08:26:36 -04:00
parent c8d20e86d7
commit ad5a3e2fa0
15 changed files with 631 additions and 512 deletions

View file

@ -14,7 +14,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@ -27,6 +26,8 @@ type mockSession struct {
runner sessionRunner
}
func (s *mockSession) GetPerspective() protocol.Perspective { panic("not implemented") }
var _ = Describe("Server", func() {
var (
conn *mockPacketConn
@ -89,7 +90,7 @@ var _ = Describe("Server", func() {
Context("with mock session", func() {
var (
serv *server
firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID)
firstPacket *receivedPacket
connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
sessions = make([]*MockQuicSession, 0)
sessionHandler *MockPacketHandlerManager
@ -126,9 +127,16 @@ var _ = Describe("Server", func() {
serv.setup()
b := &bytes.Buffer{}
utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]))
firstPacket = []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
firstPacket = append(append(firstPacket, b.Bytes()...), 0x01)
firstPacket = append(firstPacket, bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)...) // add padding
firstPacket = &receivedPacket{
header: &wire.Header{
VersionFlag: true,
Version: serv.config.Versions[0],
DestConnectionID: protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6},
PacketNumber: 1,
},
data: bytes.Repeat([]byte{0}, protocol.MinClientHelloSize),
rcvTime: time.Now(),
}
})
AfterEach(func() {
@ -150,12 +158,10 @@ var _ = Describe("Server", func() {
s.EXPECT().run().Do(func() { close(run) })
sessions = append(sessions, s)
sessionHandler.EXPECT().Get(connID)
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
Expect(sess.(*mockSession).connID).To(Equal(connID))
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(cid protocol.ConnectionID, _ packetHandler) {
Expect(cid).To(Equal(connID))
})
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.handlePacketImpl(firstPacket)).To(Succeed())
Eventually(run).Should(BeClosed())
})
@ -165,7 +171,8 @@ var _ = Describe("Server", func() {
err := serv.setupTLS()
Expect(err).ToNot(HaveOccurred())
added := make(chan struct{})
sessionHandler.EXPECT().Add(connID, sess).Do(func(protocol.ConnectionID, packetHandler) {
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, ph packetHandler) {
Expect(ph.GetPerspective()).To(Equal(protocol.PerspectiveServer))
close(added)
})
serv.serverTLS.sessionChan <- tlsSession{
@ -184,17 +191,15 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := serv.Accept()
_, err := serv.Accept()
Expect(err).ToNot(HaveOccurred())
Expect(sess.(*mockSession).connID).To(Equal(connID))
close(done)
}()
sessionHandler.EXPECT().Get(connID)
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
Consistently(done).ShouldNot(BeClosed())
sess.(*mockSession).runner.onHandshakeComplete(sess.(Session))
sess.(*serverSession).quicSession.(*mockSession).runner.onHandshakeComplete(sess.(Session))
})
err := serv.handlePacket(nil, firstPacket)
err := serv.handlePacketImpl(firstPacket)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Eventually(run).Should(BeClosed())
@ -212,45 +217,20 @@ var _ = Describe("Server", func() {
serv.Accept()
close(done)
}()
sessionHandler.EXPECT().Get(connID)
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(protocol.ConnectionID, packetHandler) {
run <- errors.New("handshake error")
})
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.handlePacketImpl(firstPacket)).To(Succeed())
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
sessionHandler.EXPECT().Close()
close(serv.errorChan)
serv.Close()
Eventually(done).Should(BeClosed())
})
It("assigns packets to existing sessions", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
sess.EXPECT().GetVersion()
sessionHandler.EXPECT().Get(connID).Return(sess, true)
err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
Expect(err).ToNot(HaveOccurred())
})
It("closes the sessionHandler and the connection when Close is called", func() {
go func() {
defer GinkgoRecover()
serv.serve()
}()
// close the server
sessionHandler.EXPECT().Close().AnyTimes()
It("closes the sessionHandler when Close is called", func() {
sessionHandler.EXPECT().CloseServer()
Expect(serv.Close()).To(Succeed())
Expect(conn.closed).To(BeTrue())
})
It("ignores packets for closed sessions", func() {
sessionHandler.EXPECT().Get(connID).Return(nil, true)
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
})
It("works if no quic.Config is given", func(done Done) {
@ -264,163 +244,56 @@ var _ = Describe("Server", func() {
ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred())
var returned bool
go func() {
defer GinkgoRecover()
_, err := ln.Accept()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
returned = true
}()
ln.Close()
Eventually(func() bool { return returned }).Should(BeTrue())
})
It("errors when encountering a connection error", func() {
testErr := errors.New("connection error")
conn.readErr = testErr
sessionHandler.EXPECT().Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.serve()
ln.Accept()
close(done)
}()
_, err := serv.Accept()
Expect(err).To(MatchError(testErr))
ln.Close()
Eventually(done).Should(BeClosed())
})
It("ignores delayed packets with mismatching versions", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().GetVersion()
// don't EXPECT any handlePacket() calls to this session
sessionHandler.EXPECT().Get(connID).Return(sess, true)
b := &bytes.Buffer{}
// add an unsupported version
data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]+1))
data = append(append(data, b.Bytes()...), 0x01)
err := serv.handlePacket(nil, data)
Expect(err).ToNot(HaveOccurred())
// if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn
Expect(conn.dataWritten.Bytes()).To(BeEmpty())
It("returns Accept when it is closed", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept()
Expect(err).To(MatchError("server closed"))
close(done)
}()
sessionHandler.EXPECT().CloseServer()
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
It("errors on invalid public header", func() {
err := serv.handlePacket(nil, nil)
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
})
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
sessionHandler.EXPECT().Get(connID).Return(sess, true)
serv.supportsTLS = true
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 1000,
SrcConnectionID: connID,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
Expect(err).To(MatchError("packet payload (456 bytes) is smaller than the expected payload length (1000 bytes)"))
})
It("cuts packets at the payload length", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) {
Expect(packet.data).To(HaveLen(123))
})
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
sessionHandler.EXPECT().Get(connID).Return(sess, true)
serv.supportsTLS = true
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
Expect(err).ToNot(HaveOccurred())
})
It("drops packets with invalid packet types", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
sessionHandler.EXPECT().Get(connID).Return(sess, true)
serv.supportsTLS = true
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
Expect(err).To(MatchError("Received unsupported packet type: Retry"))
})
It("ignores Public Resets", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
sessionHandler.EXPECT().Get(connID).Return(sess, true)
err := serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337))
Expect(err).ToNot(HaveOccurred())
It("returns Accept with the right error when closeWithError is called", func() {
testErr := errors.New("connection error")
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept()
Expect(err).To(MatchError(testErr))
close(done)
}()
sessionHandler.EXPECT().CloseServer()
serv.closeWithError(testErr)
Eventually(done).Should(BeClosed())
})
It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {
config.Versions = []protocol.VersionNumber{99}
b := &bytes.Buffer{}
hdr := wire.Header{
VersionFlag: true,
DestConnectionID: connID,
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen2,
p := &receivedPacket{
header: &wire.Header{
VersionFlag: true,
DestConnectionID: connID,
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen2,
},
data: make([]byte, protocol.MinClientHelloSize),
}
Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed())
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
serv.conn = conn
sessionHandler.EXPECT().Get(connID)
err := serv.handlePacket(nil, b.Bytes())
Expect(serv.handlePacketImpl(p)).To(Succeed())
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
Expect(err).ToNot(HaveOccurred())
})
It("doesn't respond with a version negotiation packet if the first packet is too small", func() {
b := &bytes.Buffer{}
hdr := wire.Header{
VersionFlag: true,
DestConnectionID: connID,
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen2,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed())
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small
serv.conn = conn
sessionHandler.EXPECT().Get(connID)
err := serv.handlePacket(udpAddr, b.Bytes())
Expect(err).To(MatchError("dropping small packet with unknown version"))
Expect(conn.dataWritten.Len()).Should(BeZero())
})
})
@ -523,8 +396,11 @@ var _ = Describe("Server", func() {
})
It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() {
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
config.Versions = append(config.Versions, protocol.VersionTLS)
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred())
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
b := &bytes.Buffer{}
hdr := wire.Header{
Type: protocol.PacketTypeInitial,
@ -536,13 +412,10 @@ var _ = Describe("Server", func() {
Version: 0x1234,
PayloadLen: protocol.MinInitialPacketSize,
}
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)).To(Succeed())
b.Write(bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) // add a fake CHLO
conn.dataToRead <- b.Bytes()
conn.dataReadFrom = udpAddr
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@ -568,51 +441,6 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})
It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() {
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
b := &bytes.Buffer{}
hdr := wire.Header{
Type: protocol.PacketTypeInitial,
IsLongHeader: true,
DestConnectionID: connID,
SrcConnectionID: connID,
PacketNumber: 0x55,
PacketNumberLen: protocol.PacketNumberLen1,
Version: protocol.VersionTLS,
}
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
conn.dataToRead <- b.Bytes()
conn.dataReadFrom = udpAddr
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
})
It("ignores non-Initial Long Header packets for unknown connections", func() {
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
b := &bytes.Buffer{}
hdr := wire.Header{
Type: protocol.PacketTypeHandshake,
IsLongHeader: true,
DestConnectionID: connID,
SrcConnectionID: connID,
PacketNumber: 0x55,
PacketNumberLen: protocol.PacketNumberLen1,
Version: protocol.VersionTLS,
}
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
conn.dataToRead <- b.Bytes()
conn.dataReadFrom = udpAddr
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
})
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
conn.dataReadFrom = udpAddr
conn.dataToRead <- []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}