mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
also use the multiplexer for the server
This commit is contained in:
parent
c8d20e86d7
commit
ad5a3e2fa0
15 changed files with 631 additions and 512 deletions
302
server_test.go
302
server_test.go
|
@ -14,7 +14,6 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
@ -27,6 +26,8 @@ type mockSession struct {
|
|||
runner sessionRunner
|
||||
}
|
||||
|
||||
func (s *mockSession) GetPerspective() protocol.Perspective { panic("not implemented") }
|
||||
|
||||
var _ = Describe("Server", func() {
|
||||
var (
|
||||
conn *mockPacketConn
|
||||
|
@ -89,7 +90,7 @@ var _ = Describe("Server", func() {
|
|||
Context("with mock session", func() {
|
||||
var (
|
||||
serv *server
|
||||
firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID)
|
||||
firstPacket *receivedPacket
|
||||
connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
||||
sessions = make([]*MockQuicSession, 0)
|
||||
sessionHandler *MockPacketHandlerManager
|
||||
|
@ -126,9 +127,16 @@ var _ = Describe("Server", func() {
|
|||
serv.setup()
|
||||
b := &bytes.Buffer{}
|
||||
utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]))
|
||||
firstPacket = []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
||||
firstPacket = append(append(firstPacket, b.Bytes()...), 0x01)
|
||||
firstPacket = append(firstPacket, bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)...) // add padding
|
||||
firstPacket = &receivedPacket{
|
||||
header: &wire.Header{
|
||||
VersionFlag: true,
|
||||
Version: serv.config.Versions[0],
|
||||
DestConnectionID: protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6},
|
||||
PacketNumber: 1,
|
||||
},
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinClientHelloSize),
|
||||
rcvTime: time.Now(),
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
@ -150,12 +158,10 @@ var _ = Describe("Server", func() {
|
|||
s.EXPECT().run().Do(func() { close(run) })
|
||||
sessions = append(sessions, s)
|
||||
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
|
||||
Expect(sess.(*mockSession).connID).To(Equal(connID))
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(cid protocol.ConnectionID, _ packetHandler) {
|
||||
Expect(cid).To(Equal(connID))
|
||||
})
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.handlePacketImpl(firstPacket)).To(Succeed())
|
||||
Eventually(run).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -165,7 +171,8 @@ var _ = Describe("Server", func() {
|
|||
err := serv.setupTLS()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
added := make(chan struct{})
|
||||
sessionHandler.EXPECT().Add(connID, sess).Do(func(protocol.ConnectionID, packetHandler) {
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, ph packetHandler) {
|
||||
Expect(ph.GetPerspective()).To(Equal(protocol.PerspectiveServer))
|
||||
close(added)
|
||||
})
|
||||
serv.serverTLS.sessionChan <- tlsSession{
|
||||
|
@ -184,17 +191,15 @@ var _ = Describe("Server", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := serv.Accept()
|
||||
_, err := serv.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.(*mockSession).connID).To(Equal(connID))
|
||||
close(done)
|
||||
}()
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
sess.(*mockSession).runner.onHandshakeComplete(sess.(Session))
|
||||
sess.(*serverSession).quicSession.(*mockSession).runner.onHandshakeComplete(sess.(Session))
|
||||
})
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
err := serv.handlePacketImpl(firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(done).Should(BeClosed())
|
||||
Eventually(run).Should(BeClosed())
|
||||
|
@ -212,45 +217,20 @@ var _ = Describe("Server", func() {
|
|||
serv.Accept()
|
||||
close(done)
|
||||
}()
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(protocol.ConnectionID, packetHandler) {
|
||||
run <- errors.New("handshake error")
|
||||
})
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.handlePacketImpl(firstPacket)).To(Succeed())
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
sessionHandler.EXPECT().Close()
|
||||
close(serv.errorChan)
|
||||
serv.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("assigns packets to existing sessions", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
sess.EXPECT().GetVersion()
|
||||
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("closes the sessionHandler and the connection when Close is called", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serv.serve()
|
||||
}()
|
||||
// close the server
|
||||
sessionHandler.EXPECT().Close().AnyTimes()
|
||||
It("closes the sessionHandler when Close is called", func() {
|
||||
sessionHandler.EXPECT().CloseServer()
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Expect(conn.closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("ignores packets for closed sessions", func() {
|
||||
sessionHandler.EXPECT().Get(connID).Return(nil, true)
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("works if no quic.Config is given", func(done Done) {
|
||||
|
@ -264,163 +244,56 @@ var _ = Describe("Server", func() {
|
|||
ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var returned bool
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := ln.Accept()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
|
||||
returned = true
|
||||
}()
|
||||
ln.Close()
|
||||
Eventually(func() bool { return returned }).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("errors when encountering a connection error", func() {
|
||||
testErr := errors.New("connection error")
|
||||
conn.readErr = testErr
|
||||
sessionHandler.EXPECT().Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serv.serve()
|
||||
ln.Accept()
|
||||
close(done)
|
||||
}()
|
||||
_, err := serv.Accept()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
ln.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("ignores delayed packets with mismatching versions", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().GetVersion()
|
||||
// don't EXPECT any handlePacket() calls to this session
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
// add an unsupported version
|
||||
data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
||||
utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]+1))
|
||||
data = append(append(data, b.Bytes()...), 0x01)
|
||||
err := serv.handlePacket(nil, data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn
|
||||
Expect(conn.dataWritten.Bytes()).To(BeEmpty())
|
||||
It("returns Accept when it is closed", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := serv.Accept()
|
||||
Expect(err).To(MatchError("server closed"))
|
||||
close(done)
|
||||
}()
|
||||
sessionHandler.EXPECT().CloseServer()
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors on invalid public header", func() {
|
||||
err := serv.handlePacket(nil, nil)
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
|
||||
})
|
||||
|
||||
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
|
||||
serv.supportsTLS = true
|
||||
b := &bytes.Buffer{}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
PayloadLen: 1000,
|
||||
SrcConnectionID: connID,
|
||||
DestConnectionID: connID,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
|
||||
Expect(err).To(MatchError("packet payload (456 bytes) is smaller than the expected payload length (1000 bytes)"))
|
||||
})
|
||||
|
||||
It("cuts packets at the payload length", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) {
|
||||
Expect(packet.data).To(HaveLen(123))
|
||||
})
|
||||
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
|
||||
serv.supportsTLS = true
|
||||
b := &bytes.Buffer{}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
PayloadLen: 123,
|
||||
SrcConnectionID: connID,
|
||||
DestConnectionID: connID,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("drops packets with invalid packet types", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
|
||||
serv.supportsTLS = true
|
||||
b := &bytes.Buffer{}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
PayloadLen: 123,
|
||||
SrcConnectionID: connID,
|
||||
DestConnectionID: connID,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
|
||||
Expect(err).To(MatchError("Received unsupported packet type: Retry"))
|
||||
})
|
||||
|
||||
It("ignores Public Resets", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
|
||||
err := serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
It("returns Accept with the right error when closeWithError is called", func() {
|
||||
testErr := errors.New("connection error")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := serv.Accept()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
sessionHandler.EXPECT().CloseServer()
|
||||
serv.closeWithError(testErr)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {
|
||||
config.Versions = []protocol.VersionNumber{99}
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
VersionFlag: true,
|
||||
DestConnectionID: connID,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
p := &receivedPacket{
|
||||
header: &wire.Header{
|
||||
VersionFlag: true,
|
||||
DestConnectionID: connID,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
},
|
||||
data: make([]byte, protocol.MinClientHelloSize),
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed())
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
|
||||
serv.conn = conn
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
err := serv.handlePacket(nil, b.Bytes())
|
||||
Expect(serv.handlePacketImpl(p)).To(Succeed())
|
||||
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("doesn't respond with a version negotiation packet if the first packet is too small", func() {
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
VersionFlag: true,
|
||||
DestConnectionID: connID,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed())
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small
|
||||
serv.conn = conn
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
err := serv.handlePacket(udpAddr, b.Bytes())
|
||||
Expect(err).To(MatchError("dropping small packet with unknown version"))
|
||||
Expect(conn.dataWritten.Len()).Should(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -523,8 +396,11 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() {
|
||||
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
config.Versions = append(config.Versions, protocol.VersionTLS)
|
||||
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
|
@ -536,13 +412,10 @@ var _ = Describe("Server", func() {
|
|||
Version: 0x1234,
|
||||
PayloadLen: protocol.MinInitialPacketSize,
|
||||
}
|
||||
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)).To(Succeed())
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) // add a fake CHLO
|
||||
conn.dataToRead <- b.Bytes()
|
||||
conn.dataReadFrom = udpAddr
|
||||
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -568,51 +441,6 @@ var _ = Describe("Server", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() {
|
||||
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
IsLongHeader: true,
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
PacketNumber: 0x55,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
|
||||
conn.dataToRead <- b.Bytes()
|
||||
conn.dataReadFrom = udpAddr
|
||||
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
|
||||
})
|
||||
|
||||
It("ignores non-Initial Long Header packets for unknown connections", func() {
|
||||
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
IsLongHeader: true,
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
PacketNumber: 0x55,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn.dataToRead <- b.Bytes()
|
||||
conn.dataReadFrom = udpAddr
|
||||
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
|
||||
})
|
||||
|
||||
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
|
||||
conn.dataReadFrom = udpAddr
|
||||
conn.dataToRead <- []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue