move listening from the multiplexer to the packet handler map

This commit is contained in:
Marten Seemann 2018-07-17 12:24:11 -04:00
parent 7e2adfe19d
commit c8d20e86d7
8 changed files with 255 additions and 291 deletions

View file

@ -19,14 +19,15 @@ import (
type client struct {
mutex sync.Mutex
pconn net.PacketConn
conn connection
conn connection
// If the client is created with DialAddr, we create a packet conn.
// If it is started with Dial, we take a packet conn as a parameter.
createdPacketConn bool
hostname string
packetHandlers packetHandlerManager
receivedRetry bool
versionNegotiated bool // has the server accepted our version
@ -123,22 +124,18 @@ func dialContext(
createdPacketConn bool,
) (Session, error) {
config = populateClientConfig(config, createdPacketConn)
multiplexer := getMultiplexer()
manager, err := multiplexer.AddConn(pconn, config.ConnectionIDLength)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
if err != nil {
return nil, err
}
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove, createdPacketConn)
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, packetHandlers.Remove, createdPacketConn)
if err != nil {
return nil, err
}
if err := multiplexer.AddHandler(pconn, c.srcConnID, c); err != nil {
return nil, err
}
c.packetHandlers = packetHandlers
c.packetHandlers.Add(c.srcConnID, c)
if config.RequestConnectionIDOmission {
if err := multiplexer.AddHandler(pconn, protocol.ConnectionID{}, c); err != nil {
return nil, err
}
c.packetHandlers.Add(protocol.ConnectionID{}, c)
}
if err := c.dial(ctx); err != nil {
return nil, err
@ -180,7 +177,6 @@ func newClient(
onClose = closeCallback
}
c := &client{
pconn: pconn,
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
createdPacketConn: createdPacketConn,
hostname: hostname,
@ -484,9 +480,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
c.initialVersion = c.version
c.version = newVersion
c.generateConnectionIDs()
if err := getMultiplexer().AddHandler(c.pconn, c.srcConnID, c); err != nil {
return err
}
c.packetHandlers.Add(c.srcConnID, c)
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.destroy(errCloseSessionForNewVersion)

View file

@ -107,8 +107,8 @@ var _ = Describe("Client", func() {
It("resolves the address", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
if os.Getenv("APPVEYOR") == "True" {
Skip("This test is flaky on AppVeyor.")
@ -138,8 +138,8 @@ var _ = Describe("Client", func() {
It("uses the tls.Config.ServerName as the hostname, if present", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
hostnameChan := make(chan string, 1)
newClientSession = func(
@ -166,8 +166,8 @@ var _ = Describe("Client", func() {
It("returns after the handshake is complete", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
run := make(chan struct{})
newClientSession = func(
@ -195,8 +195,8 @@ var _ = Describe("Client", func() {
It("returns an error that occurs while waiting for the connection to become secure", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
testErr := errors.New("early handshake error")
newClientSession = func(
@ -222,8 +222,8 @@ var _ = Describe("Client", func() {
It("closes the session when the context is canceled", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
sessionRunning := make(chan struct{})
defer close(sessionRunning)
@ -261,9 +261,9 @@ var _ = Describe("Client", func() {
It("removes closed sessions from the multiplexer", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
manager.EXPECT().Remove(connID)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
var runner sessionRunner
sess := NewMockQuicSession(mockCtrl)
@ -293,7 +293,7 @@ var _ = Describe("Client", func() {
It("closes the connection when it was created by DialAddr", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
manager.EXPECT().Add(gomock.Any(), gomock.Any())
var conn connection
run := make(chan struct{})
@ -414,8 +414,8 @@ var _ = Describe("Client", func() {
Context("gQUIC", func() {
It("errors if it can't create a session", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
testErr := errors.New("error creating session")
newClientSession = func(
@ -440,8 +440,8 @@ var _ = Describe("Client", func() {
Context("IETF QUIC", func() {
It("creates new TLS sessions with the right parameters", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
c := make(chan struct{})
@ -496,8 +496,8 @@ var _ = Describe("Client", func() {
It("returns an error that occurs during version negotiation", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
testErr := errors.New("early handshake error")
newClientSession = func(
@ -537,7 +537,9 @@ var _ = Describe("Client", func() {
})
It("changes the version after receiving a version negotiation packet", func() {
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
phm := NewMockPacketHandlerManager(mockCtrl)
phm.EXPECT().Add(connID, gomock.Any())
cl.packetHandlers = phm
version1 := protocol.Version39
version2 := protocol.Version39 + 1
@ -580,8 +582,9 @@ var _ = Describe("Client", func() {
})
It("only accepts one version negotiation packet", func() {
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
phm := NewMockPacketHandlerManager(mockCtrl)
phm.EXPECT().Add(connID, gomock.Any())
cl.packetHandlers = phm
version1 := protocol.Version39
version2 := protocol.Version39 + 1
version3 := protocol.Version39 + 2
@ -647,7 +650,10 @@ var _ = Describe("Client", func() {
})
It("changes to the version preferred by the quic.Config", func() {
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
phm := NewMockPacketHandlerManager(mockCtrl)
phm.EXPECT().Add(connID, gomock.Any())
cl.packetHandlers = phm
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().destroy(errCloseSessionForNewVersion)
cl.session = sess
@ -726,8 +732,8 @@ var _ = Describe("Client", func() {
It("creates new gQUIC sessions with the right parameters", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
config := &Config{Versions: protocol.SupportedVersions}
c := make(chan struct{})
@ -740,7 +746,7 @@ var _ = Describe("Client", func() {
_ sessionRunner,
hostnameP string,
versionP protocol.VersionNumber,
_ protocol.ConnectionID,
connIDP protocol.ConnectionID,
_ *tls.Config,
configP *Config,
_ protocol.VersionNumber,
@ -751,11 +757,13 @@ var _ = Describe("Client", func() {
hostname = hostnameP
version = versionP
conf = configP
connID = connIDP
close(c)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run()
return sess, nil
}
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
Eventually(c).Should(BeClosed())
@ -767,8 +775,8 @@ var _ = Describe("Client", func() {
It("creates a new session when the server performs a retry", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
cl.config = config
@ -801,8 +809,8 @@ var _ = Describe("Client", func() {
It("only accepts one Retry packet", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
sess1 := NewMockQuicSession(mockCtrl)

View file

@ -9,7 +9,6 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockMultiplexer is a mock of Multiplexer interface
@ -47,15 +46,3 @@ func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int) (packetHandlerM
func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1)
}
// AddHandler mocks base method
func (m *MockMultiplexer) AddHandler(arg0 net.PacketConn, arg1 protocol.ConnectionID, arg2 packetHandler) error {
ret := m.ctrl.Call(m, "AddHandler", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// AddHandler indicates an expected call of AddHandler
func (mr *MockMultiplexerMockRecorder) AddHandler(arg0, arg1, arg2 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddHandler", reflect.TypeOf((*MockMultiplexer)(nil).AddHandler), arg0, arg1, arg2)
}

View file

@ -1,17 +1,11 @@
package quic
import (
"bytes"
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
var (
@ -21,7 +15,6 @@ var (
type multiplexer interface {
AddConn(net.PacketConn, int) (packetHandlerManager, error)
AddHandler(net.PacketConn, protocol.ConnectionID, packetHandler) error
}
type connManager struct {
@ -35,7 +28,7 @@ type connMultiplexer struct {
mutex sync.Mutex
conns map[net.PacketConn]connManager
newPacketHandlerManager func() packetHandlerManager // so it can be replaced in the tests
newPacketHandlerManager func(net.PacketConn, int, utils.Logger, bool) packetHandlerManager // so it can be replaced in the tests
logger utils.Logger
}
@ -59,89 +52,12 @@ func (m *connMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandle
p, ok := m.conns[c]
if !ok {
manager := m.newPacketHandlerManager()
manager := m.newPacketHandlerManager(c, connIDLen, m.logger, true)
p = connManager{connIDLen: connIDLen, manager: manager}
m.conns[c] = p
// If we didn't know this packet conn before, listen for incoming packets
// and dispatch them to the right sessions.
go m.listen(c, &p)
}
if p.connIDLen != connIDLen {
return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
}
return p.manager, nil
}
func (m *connMultiplexer) AddHandler(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) error {
m.mutex.Lock()
defer m.mutex.Unlock()
p, ok := m.conns[c]
if !ok {
return errors.New("unknown packet conn %s")
}
p.manager.Add(connID, handler)
return nil
}
func (m *connMultiplexer) listen(c net.PacketConn, p *connManager) {
for {
data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize]
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
n, addr, err := c.ReadFrom(data)
if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
p.manager.Close()
}
return
}
data = data[:n]
if err := m.handlePacket(addr, data, p); err != nil {
m.logger.Debugf("error handling packet from %s: %s", addr, err)
}
}
}
func (m *connMultiplexer) handlePacket(addr net.Addr, data []byte, p *connManager) error {
rcvTime := time.Now()
r := bytes.NewReader(data)
iHdr, err := wire.ParseInvariantHeader(r, p.connIDLen)
// drop the packet if we can't parse the header
if err != nil {
return fmt.Errorf("error parsing invariant header: %s", err)
}
handler, ok := p.manager.Get(iHdr.DestConnectionID)
if !ok {
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
if handler == nil {
// Late packet for closed session
return nil
}
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, handler.GetVersion())
if err != nil {
return fmt.Errorf("error parsing header: %s", err)
}
hdr.Raw = data[:len(data)-r.Len()]
packetData := data[len(data)-r.Len():]
if hdr.IsLongHeader {
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
}
packetData = packetData[:int(hdr.PayloadLen)]
// TODO(#1312): implement parsing of compound packets
}
handler.handlePacket(&receivedPacket{
remoteAddr: addr,
header: hdr,
data: packetData,
rcvTime: rcvTime,
})
return nil
}

View file

@ -1,46 +1,15 @@
package quic
import (
"bytes"
"errors"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Client Multiplexer", func() {
getPacket := func(connID protocol.ConnectionID) []byte {
buf := &bytes.Buffer{}
err := (&wire.Header{
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
}).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
return buf.Bytes()
}
It("adds a new packet conn and handles packets", func() {
It("adds a new packet conn ", func() {
conn := newMockPacketConn()
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
packetHandler := NewMockQuicSession(mockCtrl)
handledPacket := make(chan struct{})
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.header.DestConnectionID).To(Equal(connID))
close(handledPacket)
})
packetHandler.EXPECT().GetVersion()
getMultiplexer().AddConn(conn, 8)
err := getMultiplexer().AddHandler(conn, connID, packetHandler)
_, err := getMultiplexer().AddConn(conn, 8)
Expect(err).ToNot(HaveOccurred())
conn.dataToRead <- getPacket(connID)
Eventually(handledPacket).Should(BeClosed())
// makes the listen go routine return
packetHandler.EXPECT().Close().AnyTimes()
close(conn.dataToRead)
})
It("errors when adding an existing conn with a different connection ID length", func() {
@ -51,124 +20,4 @@ var _ = Describe("Client Multiplexer", func() {
Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs"))
})
It("errors when adding a handler for an unknown conn", func() {
conn := newMockPacketConn()
err := getMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4}, NewMockQuicSession(mockCtrl))
Expect(err).ToNot(MatchError("unknown packet conn"))
})
It("handles packets for different packet handlers on the same packet conn", func() {
conn := newMockPacketConn()
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
packetHandler1 := NewMockQuicSession(mockCtrl)
packetHandler2 := NewMockQuicSession(mockCtrl)
handledPacket1 := make(chan struct{})
handledPacket2 := make(chan struct{})
packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.header.DestConnectionID).To(Equal(connID1))
close(handledPacket1)
})
packetHandler1.EXPECT().GetVersion()
packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.header.DestConnectionID).To(Equal(connID2))
close(handledPacket2)
})
packetHandler2.EXPECT().GetVersion()
getMultiplexer().AddConn(conn, connID1.Len())
Expect(getMultiplexer().AddHandler(conn, connID1, packetHandler1)).To(Succeed())
Expect(getMultiplexer().AddHandler(conn, connID2, packetHandler2)).To(Succeed())
conn.dataToRead <- getPacket(connID1)
conn.dataToRead <- getPacket(connID2)
Eventually(handledPacket1).Should(BeClosed())
Eventually(handledPacket2).Should(BeClosed())
// makes the listen go routine return
packetHandler1.EXPECT().Close().AnyTimes()
packetHandler2.EXPECT().Close().AnyTimes()
close(conn.dataToRead)
})
It("drops unparseable packets", func() {
err := getMultiplexer().(*connMultiplexer).handlePacket(nil, []byte("invalid"), &connManager{connIDLen: 8})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error parsing invariant header:"))
})
It("ignores packets arriving late for closed sessions", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Get(connID).Return(nil, true)
err := getMultiplexer().(*connMultiplexer).handlePacket(nil, getPacket(connID), &connManager{manager: manager, connIDLen: 8})
Expect(err).ToNot(HaveOccurred())
})
It("drops packets for unknown receivers", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Get(connID).Return(nil, false)
err := getMultiplexer().(*connMultiplexer).handlePacket(nil, getPacket(connID), &connManager{manager: manager, connIDLen: 8})
Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
})
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 1000,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
buf := &bytes.Buffer{}
Expect(hdr.Write(buf, protocol.PerspectiveServer, versionIETFFrames)).To(Succeed())
buf.Write(bytes.Repeat([]byte{0}, 500))
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().GetVersion().Return(versionIETFFrames)
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Get(connID).Return(sess, true)
err := getMultiplexer().(*connMultiplexer).handlePacket(nil, buf.Bytes(), &connManager{manager: manager, connIDLen: 8})
Expect(err).To(MatchError("packet payload (500 bytes) is smaller than the expected payload length (1000 bytes)"))
})
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 456,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
buf := &bytes.Buffer{}
Expect(hdr.Write(buf, protocol.PerspectiveServer, versionIETFFrames)).To(Succeed())
buf.Write(bytes.Repeat([]byte{0}, 500))
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().GetVersion().Return(versionIETFFrames)
sess.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.data).To(HaveLen(456))
})
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Get(connID).Return(sess, true)
err := getMultiplexer().(*connMultiplexer).handlePacket(nil, buf.Bytes(), &connManager{manager: manager, connIDLen: 8})
Expect(err).ToNot(HaveOccurred())
})
It("closes the packet handlers when reading from the conn fails", func() {
conn := newMockPacketConn()
conn.readErr = errors.New("test error")
done := make(chan struct{})
packetHandler := NewMockQuicSession(mockCtrl)
packetHandler.EXPECT().Close().Do(func() {
close(done)
})
getMultiplexer().AddConn(conn, 8)
Expect(getMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)).To(Succeed())
Eventually(done).Should(BeClosed())
})
})

View file

@ -1,10 +1,16 @@
package quic
import (
"bytes"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// The packetHandlerMap stores packetHandlers, identified by connection ID.
@ -14,19 +20,32 @@ import (
type packetHandlerMap struct {
mutex sync.RWMutex
conn net.PacketConn
connIDLen int
handlers map[string] /* string(ConnectionID)*/ packetHandler
closed bool
deleteClosedSessionsAfter time.Duration
logger utils.Logger
}
var _ packetHandlerManager = &packetHandlerMap{}
func newPacketHandlerMap() packetHandlerManager {
return &packetHandlerMap{
// TODO(#561): remove the listen flag
func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger, listen bool) packetHandlerManager {
m := &packetHandlerMap{
conn: conn,
connIDLen: connIDLen,
handlers: make(map[string]packetHandler),
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
logger: logger,
}
if listen {
go m.listen()
}
return m
}
func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
@ -77,3 +96,65 @@ func (h *packetHandlerMap) Close() error {
wg.Wait()
return nil
}
func (h *packetHandlerMap) listen() {
for {
data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize]
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
n, addr, err := h.conn.ReadFrom(data)
if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
h.Close()
}
return
}
data = data[:n]
if err := h.handlePacket(addr, data); err != nil {
h.logger.Debugf("error handling packet from %s: %s", addr, err)
}
}
}
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
rcvTime := time.Now()
r := bytes.NewReader(data)
iHdr, err := wire.ParseInvariantHeader(r, h.connIDLen)
// drop the packet if we can't parse the header
if err != nil {
return fmt.Errorf("error parsing invariant header: %s", err)
}
handler, ok := h.Get(iHdr.DestConnectionID)
if !ok {
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
if handler == nil {
// Late packet for closed session
return nil
}
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, handler.GetVersion())
if err != nil {
return fmt.Errorf("error parsing header: %s", err)
}
hdr.Raw = data[:len(data)-r.Len()]
packetData := data[len(data)-r.Len():]
if hdr.IsLongHeader {
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
}
packetData = packetData[:int(hdr.PayloadLen)]
// TODO(#1312): implement parsing of compound packets
}
handler.handlePacket(&receivedPacket{
remoteAddr: addr,
header: hdr,
data: packetData,
rcvTime: rcvTime,
})
return nil
}

View file

@ -1,18 +1,26 @@
package quic
import (
"bytes"
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Packet Handler Map", func() {
var handler *packetHandlerMap
var (
handler *packetHandlerMap
conn *mockPacketConn
)
BeforeEach(func() {
handler = newPacketHandlerMap().(*packetHandlerMap)
conn = newMockPacketConn()
handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger, true).(*packetHandlerMap)
})
It("adds and gets", func() {
@ -53,4 +61,124 @@ var _ = Describe("Packet Handler Map", func() {
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2)
handler.Close()
})
Context("handling packets", func() {
getPacket := func(connID protocol.ConnectionID) []byte {
buf := &bytes.Buffer{}
err := (&wire.Header{
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
}).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
return buf.Bytes()
}
It("handles packets for different packet handlers on the same packet conn", func() {
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
packetHandler1 := NewMockQuicSession(mockCtrl)
packetHandler2 := NewMockQuicSession(mockCtrl)
handledPacket1 := make(chan struct{})
handledPacket2 := make(chan struct{})
packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.header.DestConnectionID).To(Equal(connID1))
close(handledPacket1)
})
packetHandler1.EXPECT().GetVersion()
packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.header.DestConnectionID).To(Equal(connID2))
close(handledPacket2)
})
packetHandler2.EXPECT().GetVersion()
handler.Add(connID1, packetHandler1)
handler.Add(connID2, packetHandler2)
conn.dataToRead <- getPacket(connID1)
conn.dataToRead <- getPacket(connID2)
Eventually(handledPacket1).Should(BeClosed())
Eventually(handledPacket2).Should(BeClosed())
// makes the listen go routine return
packetHandler1.EXPECT().Close().AnyTimes()
packetHandler2.EXPECT().Close().AnyTimes()
close(conn.dataToRead)
})
It("drops unparseable packets", func() {
err := handler.handlePacket(nil, []byte("invalid"))
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error parsing invariant header:"))
})
It("ignores packets arriving late for closed sessions", func() {
handler.deleteClosedSessionsAfter = time.Hour
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
handler.Add(connID, NewMockQuicSession(mockCtrl))
handler.Remove(connID)
err := handler.handlePacket(nil, getPacket(connID))
Expect(err).ToNot(HaveOccurred())
})
It("drops packets for unknown receivers", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
err := handler.handlePacket(nil, getPacket(connID))
Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
})
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
packetHandler := NewMockQuicSession(mockCtrl)
packetHandler.EXPECT().GetVersion().Return(versionIETFFrames)
handler.Add(connID, packetHandler)
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 1000,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
buf := &bytes.Buffer{}
Expect(hdr.Write(buf, protocol.PerspectiveServer, versionIETFFrames)).To(Succeed())
buf.Write(bytes.Repeat([]byte{0}, 500))
err := handler.handlePacket(nil, buf.Bytes())
Expect(err).To(MatchError("packet payload (500 bytes) is smaller than the expected payload length (1000 bytes)"))
})
It("cuts packets at the Payload Length", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
packetHandler := NewMockQuicSession(mockCtrl)
packetHandler.EXPECT().GetVersion().Return(versionIETFFrames)
handler.Add(connID, packetHandler)
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.data).To(HaveLen(456))
})
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 456,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
buf := &bytes.Buffer{}
Expect(hdr.Write(buf, protocol.PerspectiveServer, versionIETFFrames)).To(Succeed())
buf.Write(bytes.Repeat([]byte{0}, 500))
err := handler.handlePacket(nil, buf.Bytes())
Expect(err).ToNot(HaveOccurred())
})
It("closes the packet handlers when reading from the conn fails", func() {
done := make(chan struct{})
packetHandler := NewMockQuicSession(mockCtrl)
packetHandler.EXPECT().Close().Do(func() {
close(done)
})
handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)
conn.Close()
Eventually(done).Should(BeClosed())
})
})
})

View file

@ -125,6 +125,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
}
}
logger := utils.DefaultLogger.WithPrefix("server")
s := &server{
conn: conn,
tlsConf: tlsConf,
@ -132,11 +133,11 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
certChain: certChain,
scfg: scfg,
newSession: newSession,
sessionHandler: newPacketHandlerMap(),
sessionHandler: newPacketHandlerMap(conn, config.ConnectionIDLength, logger, false),
sessionQueue: make(chan Session, 5),
errorChan: make(chan struct{}),
supportsTLS: supportsTLS,
logger: utils.DefaultLogger.WithPrefix("server"),
logger: logger,
}
s.setup()
if supportsTLS {