pass a logging.Tracer to the packet handler map

This commit is contained in:
Marten Seemann 2020-07-10 15:23:24 +07:00
parent dc245ca6a3
commit 2f63bc0731
9 changed files with 73 additions and 48 deletions

View file

@ -165,7 +165,7 @@ func dialContext(
return nil, errors.New("quic: tls.Config not set") return nil, errors.New("quic: tls.Config not set")
} }
config = populateClientConfig(config, createdPacketConn) config = populateClientConfig(config, createdPacketConn)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey) packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -131,7 +131,7 @@ var _ = Describe("Client", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Add(gomock.Any(), gomock.Any())
manager.EXPECT().Destroy() manager.EXPECT().Destroy()
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
remoteAddrChan := make(chan string, 1) remoteAddrChan := make(chan string, 1)
newClientSession = func( newClientSession = func(
@ -164,7 +164,7 @@ var _ = Describe("Client", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Add(gomock.Any(), gomock.Any())
manager.EXPECT().Destroy() manager.EXPECT().Destroy()
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
hostnameChan := make(chan string, 1) hostnameChan := make(chan string, 1)
newClientSession = func( newClientSession = func(
@ -197,7 +197,7 @@ var _ = Describe("Client", func() {
It("allows passing host without port as server name", func() { It("allows passing host without port as server name", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
hostnameChan := make(chan string, 1) hostnameChan := make(chan string, 1)
newClientSession = func( newClientSession = func(
@ -236,7 +236,7 @@ var _ = Describe("Client", func() {
It("returns after the handshake is complete", func() { It("returns after the handshake is complete", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
run := make(chan struct{}) run := make(chan struct{})
newClientSession = func( newClientSession = func(
@ -278,7 +278,7 @@ var _ = Describe("Client", func() {
It("returns early sessions", func() { It("returns early sessions", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
readyChan := make(chan struct{}) readyChan := make(chan struct{})
done := make(chan struct{}) done := make(chan struct{})
@ -327,7 +327,7 @@ var _ = Describe("Client", func() {
It("returns an error that occurs while waiting for the handshake to complete", func() { It("returns an error that occurs while waiting for the handshake to complete", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
testErr := errors.New("early handshake error") testErr := errors.New("early handshake error")
newClientSession = func( newClientSession = func(
@ -365,7 +365,7 @@ var _ = Describe("Client", func() {
It("closes the session when the context is canceled", func() { It("closes the session when the context is canceled", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
sessionRunning := make(chan struct{}) sessionRunning := make(chan struct{})
defer close(sessionRunning) defer close(sessionRunning)
@ -419,7 +419,7 @@ var _ = Describe("Client", func() {
} }
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Add(gomock.Any(), gomock.Any())
var conn connection var conn connection
@ -497,7 +497,7 @@ var _ = Describe("Client", func() {
It("errors when the Config contains an invalid version", func() { It("errors when the Config contains an invalid version", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
version := protocol.VersionNumber(0x1234) version := protocol.VersionNumber(0x1234)
_, err := Dial(packetConn, nil, "localhost:1234", tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) _, err := Dial(packetConn, nil, "localhost:1234", tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
@ -540,7 +540,7 @@ var _ = Describe("Client", func() {
It("creates new sessions with the right parameters", func() { It("creates new sessions with the right parameters", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any()) manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
c := make(chan struct{}) c := make(chan struct{})
@ -584,7 +584,7 @@ var _ = Describe("Client", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any()).Times(2) manager.EXPECT().Add(connID, gomock.Any()).Times(2)
manager.EXPECT().Destroy() manager.EXPECT().Destroy()
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
initialVersion := cl.version initialVersion := cl.version

View file

@ -19,10 +19,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/lucas-clemente/quic-go/logging"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/logging"
"github.com/lucas-clemente/quic-go/qlog" "github.com/lucas-clemente/quic-go/qlog"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -91,6 +90,19 @@ var (
tlsConfig *tls.Config tlsConfig *tls.Config
tlsConfigLongChain *tls.Config tlsConfigLongChain *tls.Config
tlsClientConfig *tls.Config tlsClientConfig *tls.Config
tracer = qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
role := "server"
if p == logging.PerspectiveClient {
role = "client"
}
filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role)
fmt.Fprintf(GinkgoWriter, "Creating %s.\n", filename)
f, err := os.Create(filename)
Expect(err).ToNot(HaveOccurred())
bw := bufio.NewWriter(f)
return utils.NewBufferedWriteCloser(bw, f)
})
) )
// read the logfile command line flag // read the logfile command line flag
@ -254,18 +266,7 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
if !enableQlog { if !enableQlog {
return conf return conf
} }
conf.Tracer = qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser { conf.Tracer = tracer
role := "server"
if p == logging.PerspectiveClient {
role = "client"
}
filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role)
fmt.Fprintf(GinkgoWriter, "Creating %s.\n", filename)
f, err := os.Create(filename)
Expect(err).ToNot(HaveOccurred())
bw := bufio.NewWriter(f)
return utils.NewBufferedWriteCloser(bw, f)
})
return conf return conf
} }

View file

@ -9,6 +9,7 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
logging "github.com/lucas-clemente/quic-go/logging"
) )
// MockMultiplexer is a mock of Multiplexer interface // MockMultiplexer is a mock of Multiplexer interface
@ -35,18 +36,18 @@ func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder {
} }
// AddConn mocks base method // AddConn mocks base method
func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 []byte) (packetHandlerManager, error) { func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 []byte, arg3 logging.Tracer) (packetHandlerManager, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2) ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(packetHandlerManager) ret0, _ := ret[0].(packetHandlerManager)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// AddConn indicates an expected call of AddConn // AddConn indicates an expected call of AddConn
func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2 interface{}) *gomock.Call { func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2, arg3)
} }
// RemoveConn mocks base method // RemoveConn mocks base method

View file

@ -7,6 +7,7 @@ import (
"sync" "sync"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/logging"
) )
var ( var (
@ -15,13 +16,14 @@ var (
) )
type multiplexer interface { type multiplexer interface {
AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte) (packetHandlerManager, error) AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error)
RemoveConn(net.PacketConn) error RemoveConn(net.PacketConn) error
} }
type connManager struct { type connManager struct {
connIDLen int connIDLen int
statelessResetKey []byte statelessResetKey []byte
tracer logging.Tracer
manager packetHandlerManager manager packetHandlerManager
} }
@ -31,7 +33,7 @@ type connMultiplexer struct {
mutex sync.Mutex mutex sync.Mutex
conns map[string] /* LocalAddr().String() */ connManager conns map[string] /* LocalAddr().String() */ connManager
newPacketHandlerManager func(net.PacketConn, int, []byte, utils.Logger) packetHandlerManager // so it can be replaced in the tests newPacketHandlerManager func(net.PacketConn, int, []byte, logging.Tracer, utils.Logger) packetHandlerManager // so it can be replaced in the tests
logger utils.Logger logger utils.Logger
} }
@ -53,6 +55,7 @@ func (m *connMultiplexer) AddConn(
c net.PacketConn, c net.PacketConn,
connIDLen int, connIDLen int,
statelessResetKey []byte, statelessResetKey []byte,
tracer logging.Tracer,
) (packetHandlerManager, error) { ) (packetHandlerManager, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -60,19 +63,24 @@ func (m *connMultiplexer) AddConn(
connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String() connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String()
p, ok := m.conns[connIndex] p, ok := m.conns[connIndex]
if !ok { if !ok {
manager := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, m.logger) manager := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger)
p = connManager{ p = connManager{
connIDLen: connIDLen, connIDLen: connIDLen,
statelessResetKey: statelessResetKey, statelessResetKey: statelessResetKey,
manager: manager, manager: manager,
tracer: tracer,
} }
m.conns[connIndex] = p m.conns[connIndex] = p
} } else {
if p.connIDLen != connIDLen { if p.connIDLen != connIDLen {
return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
} }
if statelessResetKey != nil && !bytes.Equal(p.statelessResetKey, statelessResetKey) { if statelessResetKey != nil && !bytes.Equal(p.statelessResetKey, statelessResetKey) {
return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn")
}
if tracer != p.tracer {
return nil, fmt.Errorf("cannot use different tracers on the same packet conn")
}
} }
return p.manager, nil return p.manager, nil
} }

View file

@ -3,6 +3,8 @@ package quic
import ( import (
"net" "net"
"github.com/lucas-clemente/quic-go/internal/mocks"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -15,7 +17,7 @@ type testConn struct {
var _ = Describe("Client Multiplexer", func() { var _ = Describe("Client Multiplexer", func() {
It("adds a new packet conn ", func() { It("adds a new packet conn ", func() {
conn := newMockPacketConn() conn := newMockPacketConn()
_, err := getMultiplexer().AddConn(conn, 8, nil) _, err := getMultiplexer().AddConn(conn, 8, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
@ -23,27 +25,36 @@ var _ = Describe("Client Multiplexer", func() {
pconn := newMockPacketConn() pconn := newMockPacketConn()
pconn.addr = &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321} pconn.addr = &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
conn := testConn{PacketConn: pconn} conn := testConn{PacketConn: pconn}
_, err := getMultiplexer().AddConn(conn, 8, nil) tracer := mocks.NewMockTracer(mockCtrl)
_, err := getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
conn.counter++ conn.counter++
_, err = getMultiplexer().AddConn(conn, 8, nil) _, err = getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1)) Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1))
}) })
It("errors when adding an existing conn with a different connection ID length", func() { It("errors when adding an existing conn with a different connection ID length", func() {
conn := newMockPacketConn() conn := newMockPacketConn()
_, err := getMultiplexer().AddConn(conn, 5, nil) _, err := getMultiplexer().AddConn(conn, 5, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = getMultiplexer().AddConn(conn, 6, nil) _, err = getMultiplexer().AddConn(conn, 6, nil, nil)
Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs")) Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs"))
}) })
It("errors when adding an existing conn with a different stateless rest key", func() { It("errors when adding an existing conn with a different stateless rest key", func() {
conn := newMockPacketConn() conn := newMockPacketConn()
_, err := getMultiplexer().AddConn(conn, 7, []byte("foobar")) _, err := getMultiplexer().AddConn(conn, 7, []byte("foobar"), nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = getMultiplexer().AddConn(conn, 7, []byte("raboof")) _, err = getMultiplexer().AddConn(conn, 7, []byte("raboof"), nil)
Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn")) Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn"))
}) })
It("errors when adding an existing conn with different tracers", func() {
conn := newMockPacketConn()
_, err := getMultiplexer().AddConn(conn, 7, nil, mocks.NewMockTracer(mockCtrl))
Expect(err).ToNot(HaveOccurred())
_, err = getMultiplexer().AddConn(conn, 7, nil, mocks.NewMockTracer(mockCtrl))
Expect(err).To(MatchError("cannot use different tracers on the same packet conn"))
})
}) })

View file

@ -13,6 +13,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/logging"
) )
type statelessResetErr struct { type statelessResetErr struct {
@ -46,6 +47,7 @@ type packetHandlerMap struct {
statelessResetMutex sync.Mutex statelessResetMutex sync.Mutex
statelessResetHasher hash.Hash statelessResetHasher hash.Hash
tracer logging.Tracer
logger utils.Logger logger utils.Logger
} }
@ -55,6 +57,7 @@ func newPacketHandlerMap(
conn net.PacketConn, conn net.PacketConn,
connIDLen int, connIDLen int,
statelessResetKey []byte, statelessResetKey []byte,
tracer logging.Tracer,
logger utils.Logger, logger utils.Logger,
) packetHandlerManager { ) packetHandlerManager {
m := &packetHandlerMap{ m := &packetHandlerMap{
@ -66,6 +69,7 @@ func newPacketHandlerMap(
deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout,
statelessResetEnabled: len(statelessResetKey) > 0, statelessResetEnabled: len(statelessResetKey) > 0,
statelessResetHasher: hmac.New(sha256.New, statelessResetKey), statelessResetHasher: hmac.New(sha256.New, statelessResetKey),
tracer: tracer,
logger: logger, logger: logger,
} }
go m.listen() go m.listen()

View file

@ -50,7 +50,7 @@ var _ = Describe("Packet Handler Map", func() {
JustBeforeEach(func() { JustBeforeEach(func() {
conn = newMockPacketConn() conn = newMockPacketConn()
handler = newPacketHandlerMap(conn, connIDLen, statelessResetKey, utils.DefaultLogger).(*packetHandlerMap) handler = newPacketHandlerMap(conn, connIDLen, statelessResetKey, nil, utils.DefaultLogger).(*packetHandlerMap)
}) })
AfterEach(func() { AfterEach(func() {

View file

@ -178,7 +178,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
} }
} }
sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey) sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer)
if err != nil { if err != nil {
return nil, err return nil, err
} }