also use the multiplexer for the server

This commit is contained in:
Marten Seemann 2018-07-20 08:26:36 -04:00
parent c8d20e86d7
commit ad5a3e2fa0
15 changed files with 631 additions and 512 deletions

View file

@ -544,9 +544,22 @@ func (c *client) Close() error {
return c.session.Close()
}
func (c *client) destroy(e error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.session == nil {
return
}
c.session.destroy(e)
}
func (c *client) GetVersion() protocol.VersionNumber {
c.mutex.Lock()
v := c.version
c.mutex.Unlock()
return v
}
func (c *client) GetPerspective() protocol.Perspective {
return protocol.PerspectiveClient
}

View file

@ -9,6 +9,11 @@ const (
PerspectiveClient Perspective = 2
)
// Opposite returns the perspective of the peer
func (p Perspective) Opposite() Perspective {
return 3 - p
}
func (p Perspective) String() string {
switch p {
case PerspectiveServer:

View file

@ -11,4 +11,9 @@ var _ = Describe("Perspective", func() {
Expect(PerspectiveServer.String()).To(Equal("Server"))
Expect(Perspective(0).String()).To(Equal("invalid perspective"))
})
It("returns the opposite", func() {
Expect(PerspectiveClient.Opposite()).To(Equal(PerspectiveServer))
Expect(PerspectiveServer.Opposite()).To(Equal(PerspectiveClient))
})
})

View file

@ -44,29 +44,14 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1)
}
// Close mocks base method
func (m *MockPacketHandlerManager) Close() error {
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
// CloseServer mocks base method
func (m *MockPacketHandlerManager) CloseServer() {
m.ctrl.Call(m, "CloseServer")
}
// Close indicates an expected call of Close
func (mr *MockPacketHandlerManagerMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close))
}
// Get mocks base method
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
ret := m.ctrl.Call(m, "Get", arg0)
ret0, _ := ret[0].(packetHandler)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// Get indicates an expected call of Get
func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
// CloseServer indicates an expected call of CloseServer
func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
}
// Remove mocks base method
@ -78,3 +63,13 @@ func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) {
func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0)
}
// SetServer mocks base method
func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) {
m.ctrl.Call(m, "SetServer", arg0)
}
// SetServer indicates an expected call of SetServer
func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0)
}

View file

@ -0,0 +1,91 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: PacketHandler)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockPacketHandler is a mock of PacketHandler interface
type MockPacketHandler struct {
ctrl *gomock.Controller
recorder *MockPacketHandlerMockRecorder
}
// MockPacketHandlerMockRecorder is the mock recorder for MockPacketHandler
type MockPacketHandlerMockRecorder struct {
mock *MockPacketHandler
}
// NewMockPacketHandler creates a new mock instance
func NewMockPacketHandler(ctrl *gomock.Controller) *MockPacketHandler {
mock := &MockPacketHandler{ctrl: ctrl}
mock.recorder = &MockPacketHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockPacketHandler) EXPECT() *MockPacketHandlerMockRecorder {
return m.recorder
}
// Close mocks base method
func (m *MockPacketHandler) Close() error {
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockPacketHandlerMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandler)(nil).Close))
}
// GetPerspective mocks base method
func (m *MockPacketHandler) GetPerspective() protocol.Perspective {
ret := m.ctrl.Call(m, "GetPerspective")
ret0, _ := ret[0].(protocol.Perspective)
return ret0
}
// GetPerspective indicates an expected call of GetPerspective
func (mr *MockPacketHandlerMockRecorder) GetPerspective() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPerspective", reflect.TypeOf((*MockPacketHandler)(nil).GetPerspective))
}
// GetVersion mocks base method
func (m *MockPacketHandler) GetVersion() protocol.VersionNumber {
ret := m.ctrl.Call(m, "GetVersion")
ret0, _ := ret[0].(protocol.VersionNumber)
return ret0
}
// GetVersion indicates an expected call of GetVersion
func (mr *MockPacketHandlerMockRecorder) GetVersion() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockPacketHandler)(nil).GetVersion))
}
// destroy mocks base method
func (m *MockPacketHandler) destroy(arg0 error) {
m.ctrl.Call(m, "destroy", arg0)
}
// destroy indicates an expected call of destroy
func (mr *MockPacketHandlerMockRecorder) destroy(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0)
}
// handlePacket mocks base method
func (m *MockPacketHandler) handlePacket(arg0 *receivedPacket) {
m.ctrl.Call(m, "handlePacket", arg0)
}
// handlePacket indicates an expected call of handlePacket
func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockPacketHandler)(nil).handlePacket), arg0)
}

View file

@ -0,0 +1,56 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: UnknownPacketHandler)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockUnknownPacketHandler is a mock of UnknownPacketHandler interface
type MockUnknownPacketHandler struct {
ctrl *gomock.Controller
recorder *MockUnknownPacketHandlerMockRecorder
}
// MockUnknownPacketHandlerMockRecorder is the mock recorder for MockUnknownPacketHandler
type MockUnknownPacketHandlerMockRecorder struct {
mock *MockUnknownPacketHandler
}
// NewMockUnknownPacketHandler creates a new mock instance
func NewMockUnknownPacketHandler(ctrl *gomock.Controller) *MockUnknownPacketHandler {
mock := &MockUnknownPacketHandler{ctrl: ctrl}
mock.recorder = &MockUnknownPacketHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorder {
return m.recorder
}
// closeWithError mocks base method
func (m *MockUnknownPacketHandler) closeWithError(arg0 error) error {
ret := m.ctrl.Call(m, "closeWithError", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// closeWithError indicates an expected call of closeWithError
func (mr *MockUnknownPacketHandlerMockRecorder) closeWithError(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).closeWithError), arg0)
}
// handlePacket mocks base method
func (m *MockUnknownPacketHandler) handlePacket(arg0 *receivedPacket) {
m.ctrl.Call(m, "handlePacket", arg0)
}
// handlePacket indicates an expected call of handlePacket
func (mr *MockUnknownPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockUnknownPacketHandler)(nil).handlePacket), arg0)
}

View file

@ -13,6 +13,8 @@ package quic
//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD"
//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner"
//go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession QuicSession"
//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler PacketHandler"
//go:generate sh -c "./mockgen_private.sh quic mock_unknown_packet_handler_test.go github.com/lucas-clemente/quic-go unknownPacketHandler UnknownPacketHandler"
//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager PacketHandlerManager"
//go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/lucas-clemente/quic-go multiplexer Multiplexer"
//go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'"

View file

@ -28,7 +28,7 @@ type connMultiplexer struct {
mutex sync.Mutex
conns map[net.PacketConn]connManager
newPacketHandlerManager func(net.PacketConn, int, utils.Logger, bool) packetHandlerManager // so it can be replaced in the tests
newPacketHandlerManager func(net.PacketConn, int, utils.Logger) packetHandlerManager // so it can be replaced in the tests
logger utils.Logger
}
@ -52,7 +52,7 @@ func (m *connMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandle
p, ok := m.conns[c]
if !ok {
manager := m.newPacketHandlerManager(c, connIDLen, m.logger, true)
manager := m.newPacketHandlerManager(c, connIDLen, m.logger)
p = connManager{connIDLen: connIDLen, manager: manager}
m.conns[c] = p
}

View file

@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
"net"
"strings"
"sync"
"time"
@ -24,6 +23,7 @@ type packetHandlerMap struct {
connIDLen int
handlers map[string] /* string(ConnectionID)*/ packetHandler
server unknownPacketHandler
closed bool
deleteClosedSessionsAfter time.Duration
@ -33,8 +33,7 @@ type packetHandlerMap struct {
var _ packetHandlerManager = &packetHandlerMap{}
// TODO(#561): remove the listen flag
func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger, listen bool) packetHandlerManager {
func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager {
m := &packetHandlerMap{
conn: conn,
connIDLen: connIDLen,
@ -42,19 +41,10 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
logger: logger,
}
if listen {
go m.listen()
}
go m.listen()
return m
}
func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
h.mutex.RLock()
sess, ok := h.handlers[string(id)]
h.mutex.RUnlock()
return sess, ok
}
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
h.mutex.Lock()
h.handlers[string(id)] = handler
@ -62,18 +52,47 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
}
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
h.removeByConnectionIDAsString(string(id))
}
func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
h.mutex.Lock()
h.handlers[string(id)] = nil
h.handlers[id] = nil
h.mutex.Unlock()
time.AfterFunc(h.deleteClosedSessionsAfter, func() {
h.mutex.Lock()
delete(h.handlers, string(id))
delete(h.handlers, id)
h.mutex.Unlock()
})
}
func (h *packetHandlerMap) Close() error {
func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
h.mutex.Lock()
h.server = s
h.mutex.Unlock()
}
func (h *packetHandlerMap) CloseServer() {
h.mutex.Lock()
h.server = nil
var wg sync.WaitGroup
for id, handler := range h.handlers {
if handler != nil && handler.GetPerspective() == protocol.PerspectiveServer {
wg.Add(1)
go func(id string, handler packetHandler) {
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
_ = handler.Close()
h.removeByConnectionIDAsString(id)
wg.Done()
}(id, handler)
}
}
h.mutex.Unlock()
wg.Wait()
}
func (h *packetHandlerMap) close(e error) error {
h.mutex.Lock()
if h.closed {
h.mutex.Unlock()
@ -86,12 +105,15 @@ func (h *packetHandlerMap) Close() error {
if handler != nil {
wg.Add(1)
go func(handler packetHandler) {
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
_ = handler.Close()
handler.destroy(e)
wg.Done()
}(handler)
}
}
if h.server != nil {
h.server.closeWithError(e)
}
h.mutex.Unlock()
wg.Wait()
return nil
@ -105,9 +127,7 @@ func (h *packetHandlerMap) listen() {
// If it does, we only read a truncated packet, which will then end up undecryptable
n, addr, err := h.conn.ReadFrom(data)
if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
h.Close()
}
h.close(err)
return
}
data = data[:n]
@ -127,15 +147,33 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
if err != nil {
return fmt.Errorf("error parsing invariant header: %s", err)
}
handler, ok := h.Get(iHdr.DestConnectionID)
if !ok {
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
if handler == nil {
h.mutex.RLock()
handler, ok := h.handlers[string(iHdr.DestConnectionID)]
server := h.server
h.mutex.RUnlock()
var sentBy protocol.Perspective
var version protocol.VersionNumber
var handlePacket func(*receivedPacket)
if ok && handler == nil {
// Late packet for closed session
return nil
}
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, handler.GetVersion())
if !ok {
if server == nil { // no server set
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
handlePacket = server.handlePacket
sentBy = protocol.PerspectiveClient
version = iHdr.Version
} else {
sentBy = handler.GetPerspective().Opposite()
version = handler.GetVersion()
handlePacket = handler.handlePacket
}
hdr, err := iHdr.Parse(r, sentBy, version)
if err != nil {
return fmt.Errorf("error parsing header: %s", err)
}
@ -150,7 +188,7 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
// TODO(#1312): implement parsing of compound packets
}
handler.handlePacket(&receivedPacket{
handlePacket(&receivedPacket{
remoteAddr: addr,
header: hdr,
data: packetData,

View file

@ -2,6 +2,7 @@ package quic
import (
"bytes"
"errors"
"time"
"github.com/golang/mock/gomock"
@ -18,66 +19,38 @@ var _ = Describe("Packet Handler Map", func() {
conn *mockPacketConn
)
getPacket := func(connID protocol.ConnectionID) []byte {
buf := &bytes.Buffer{}
err := (&wire.Header{
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
}).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
return buf.Bytes()
}
BeforeEach(func() {
conn = newMockPacketConn()
handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger, true).(*packetHandlerMap)
})
It("adds and gets", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
sess := &mockSession{}
handler.Add(connID, sess)
session, ok := handler.Get(connID)
Expect(ok).To(BeTrue())
Expect(session).To(Equal(sess))
})
It("deletes", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
handler.Add(connID, &mockSession{})
handler.Remove(connID)
session, ok := handler.Get(connID)
Expect(ok).To(BeTrue())
Expect(session).To(BeNil())
})
It("deletes nil session entries after a wait time", func() {
handler.deleteClosedSessionsAfter = 25 * time.Millisecond
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
handler.Add(connID, &mockSession{})
handler.Remove(connID)
Eventually(func() bool {
_, ok := handler.Get(connID)
return ok
}).Should(BeFalse())
handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger).(*packetHandlerMap)
})
It("closes", func() {
sess1 := NewMockQuicSession(mockCtrl)
sess1.EXPECT().Close()
sess2 := NewMockQuicSession(mockCtrl)
sess2.EXPECT().Close()
testErr := errors.New("test error ")
sess1 := NewMockPacketHandler(mockCtrl)
sess1.EXPECT().destroy(testErr)
sess2 := NewMockPacketHandler(mockCtrl)
sess2.EXPECT().destroy(testErr)
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1)
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2)
handler.Close()
handler.close(testErr)
})
Context("handling packets", func() {
getPacket := func(connID protocol.ConnectionID) []byte {
buf := &bytes.Buffer{}
err := (&wire.Header{
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
}).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
return buf.Bytes()
}
It("handles packets for different packet handlers on the same packet conn", func() {
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
packetHandler1 := NewMockQuicSession(mockCtrl)
packetHandler2 := NewMockQuicSession(mockCtrl)
packetHandler1 := NewMockPacketHandler(mockCtrl)
packetHandler2 := NewMockPacketHandler(mockCtrl)
handledPacket1 := make(chan struct{})
handledPacket2 := make(chan struct{})
packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
@ -85,11 +58,13 @@ var _ = Describe("Packet Handler Map", func() {
close(handledPacket1)
})
packetHandler1.EXPECT().GetVersion()
packetHandler1.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.header.DestConnectionID).To(Equal(connID2))
close(handledPacket2)
})
packetHandler2.EXPECT().GetVersion()
packetHandler2.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
handler.Add(connID1, packetHandler1)
handler.Add(connID2, packetHandler2)
@ -99,8 +74,8 @@ var _ = Describe("Packet Handler Map", func() {
Eventually(handledPacket2).Should(BeClosed())
// makes the listen go routine return
packetHandler1.EXPECT().Close().AnyTimes()
packetHandler2.EXPECT().Close().AnyTimes()
packetHandler1.EXPECT().destroy(gomock.Any()).AnyTimes()
packetHandler2.EXPECT().destroy(gomock.Any()).AnyTimes()
close(conn.dataToRead)
})
@ -110,10 +85,20 @@ var _ = Describe("Packet Handler Map", func() {
Expect(err.Error()).To(ContainSubstring("error parsing invariant header:"))
})
It("deletes nil session entries after a wait time", func() {
handler.deleteClosedSessionsAfter = 10 * time.Millisecond
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
handler.Add(connID, NewMockPacketHandler(mockCtrl))
handler.Remove(connID)
Eventually(func() error {
return handler.handlePacket(nil, getPacket(connID))
}).Should(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
})
It("ignores packets arriving late for closed sessions", func() {
handler.deleteClosedSessionsAfter = time.Hour
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
handler.Add(connID, NewMockQuicSession(mockCtrl))
handler.Add(connID, NewMockPacketHandler(mockCtrl))
handler.Remove(connID)
err := handler.handlePacket(nil, getPacket(connID))
Expect(err).ToNot(HaveOccurred())
@ -127,8 +112,9 @@ var _ = Describe("Packet Handler Map", func() {
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
packetHandler := NewMockQuicSession(mockCtrl)
packetHandler := NewMockPacketHandler(mockCtrl)
packetHandler.EXPECT().GetVersion().Return(versionIETFFrames)
packetHandler.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
handler.Add(connID, packetHandler)
hdr := &wire.Header{
IsLongHeader: true,
@ -148,8 +134,9 @@ var _ = Describe("Packet Handler Map", func() {
It("cuts packets at the Payload Length", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
packetHandler := NewMockQuicSession(mockCtrl)
packetHandler := NewMockPacketHandler(mockCtrl)
packetHandler.EXPECT().GetVersion().Return(versionIETFFrames)
packetHandler.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
handler.Add(connID, packetHandler)
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.data).To(HaveLen(456))
@ -172,8 +159,9 @@ var _ = Describe("Packet Handler Map", func() {
It("closes the packet handlers when reading from the conn fails", func() {
done := make(chan struct{})
packetHandler := NewMockQuicSession(mockCtrl)
packetHandler.EXPECT().Close().Do(func() {
packetHandler := NewMockPacketHandler(mockCtrl)
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) {
Expect(e).To(HaveOccurred())
close(done)
})
handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)
@ -181,4 +169,38 @@ var _ = Describe("Packet Handler Map", func() {
Eventually(done).Should(BeClosed())
})
})
Context("running a server", func() {
It("adds a server", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.header.DestConnectionID).To(Equal(connID))
})
handler.SetServer(server)
Expect(handler.handlePacket(nil, p)).To(Succeed())
})
It("closes all server sessions", func() {
clientSess := NewMockPacketHandler(mockCtrl)
clientSess.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
serverSess := NewMockPacketHandler(mockCtrl)
serverSess.EXPECT().GetPerspective().Return(protocol.PerspectiveServer)
serverSess.EXPECT().Close()
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess)
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess)
handler.CloseServer()
})
It("stops handling packets with unknown connection IDs after the server is closed", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
handler.SetServer(server)
handler.CloseServer()
Expect(handler.handlePacket(nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788"))
})
})
})

234
server.go
View file

@ -1,7 +1,6 @@
package quic
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
@ -14,21 +13,27 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
// packetHandler handles packets
type packetHandler interface {
handlePacket(*receivedPacket)
GetVersion() protocol.VersionNumber
io.Closer
destroy(error)
GetVersion() protocol.VersionNumber
GetPerspective() protocol.Perspective
}
type unknownPacketHandler interface {
handlePacket(*receivedPacket)
closeWithError(error) error
}
type packetHandlerManager interface {
Add(protocol.ConnectionID, packetHandler)
Get(protocol.ConnectionID) (packetHandler, bool)
SetServer(unknownPacketHandler)
Remove(protocol.ConnectionID)
io.Closer
CloseServer()
}
type quicSession interface {
@ -84,6 +89,7 @@ type server struct {
}
var _ Listener = &server{}
var _ unknownPacketHandler = &server{}
// ListenAddr creates a QUIC server listening on a given address.
// The tls.Config must not be nil, the quic.Config may be nil.
@ -125,7 +131,10 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
}
}
logger := utils.DefaultLogger.WithPrefix("server")
sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength)
if err != nil {
return nil, err
}
s := &server{
conn: conn,
tlsConf: tlsConf,
@ -133,11 +142,11 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
certChain: certChain,
scfg: scfg,
newSession: newSession,
sessionHandler: newPacketHandlerMap(conn, config.ConnectionIDLength, logger, false),
sessionHandler: sessionHandler,
sessionQueue: make(chan Session, 5),
errorChan: make(chan struct{}),
supportsTLS: supportsTLS,
logger: logger,
logger: utils.DefaultLogger.WithPrefix("server"),
}
s.setup()
if supportsTLS {
@ -145,7 +154,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
return nil, err
}
}
go s.serve()
sessionHandler.SetServer(s)
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s, nil
}
@ -176,7 +185,8 @@ func (s *server) setupTLS() error {
case tlsSession := <-sessionChan:
// The connection ID is a randomly chosen 8 byte value.
// It is safe to assume that it doesn't collide with other randomly chosen values.
s.sessionHandler.Add(tlsSession.connID, tlsSession.sess)
serverSession := newServerSession(tlsSession.sess, s.config, s.logger)
s.sessionHandler.Add(tlsSession.connID, serverSession)
}
}
}()
@ -263,27 +273,6 @@ func populateServerConfig(config *Config) *Config {
}
}
// serve listens on an existing PacketConn
func (s *server) serve() {
for {
data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize]
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
n, remoteAddr, err := s.conn.ReadFrom(data)
if err != nil {
s.serverError = err
close(s.errorChan)
_ = s.Close()
return
}
data = data[:n]
if err := s.handlePacket(remoteAddr, data); err != nil {
s.logger.Errorf("error handling packet: %s", err.Error())
}
}
}
// Accept returns newly openend sessions
func (s *server) Accept() (Session, error) {
var sess Session
@ -297,10 +286,13 @@ func (s *server) Accept() (Session, error) {
// Close the server
func (s *server) Close() error {
s.sessionHandler.Close()
err := s.conn.Close()
<-s.errorChan // wait for serve() to return
return err
s.sessionHandler.CloseServer()
// TODO: close the conn if this server was started with ListenAddr() (but not with Listen(net.PacketConn))
if s.serverError == nil {
s.serverError = errors.New("server closed")
}
close(s.errorChan)
return nil
}
// Addr returns the server's network address
@ -308,157 +300,65 @@ func (s *server) Addr() net.Addr {
return s.conn.LocalAddr()
}
func (s *server) handlePacket(remoteAddr net.Addr, packet []byte) error {
rcvTime := time.Now()
r := bytes.NewReader(packet)
iHdr, err := wire.ParseInvariantHeader(r, s.config.ConnectionIDLength)
if err != nil {
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
}
session, sessionKnown := s.sessionHandler.Get(iHdr.DestConnectionID)
if sessionKnown && session == nil {
// Late packet for closed session
return nil
}
version := protocol.VersionUnknown
if sessionKnown {
version = session.GetVersion()
}
hdr, err := iHdr.Parse(r, protocol.PerspectiveClient, version)
if err != nil {
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
}
hdr.Raw = packet[:len(packet)-r.Len()]
packetData := packet[len(packet)-r.Len():]
if hdr.IsPublicHeader {
return s.handleGQUICPacket(session, hdr, packetData, remoteAddr, rcvTime)
}
return s.handleIETFQUICPacket(session, hdr, packetData, remoteAddr, rcvTime)
func (s *server) closeWithError(e error) error {
s.serverError = e
return s.Close()
}
func (s *server) handleIETFQUICPacket(
session packetHandler,
hdr *wire.Header,
packetData []byte,
remoteAddr net.Addr,
rcvTime time.Time,
) error {
if hdr.IsLongHeader {
if !s.supportsTLS {
return errors.New("Received an IETF QUIC Long Header")
}
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
}
packetData = packetData[:int(hdr.PayloadLen)]
// TODO(#1312): implement parsing of compound packets
switch hdr.Type {
case protocol.PacketTypeInitial:
go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData)
return nil
case protocol.PacketTypeHandshake:
// nothing to do here. Packet will be passed to the session.
default:
// Note that this also drops 0-RTT packets.
return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
}
func (s *server) handlePacket(p *receivedPacket) {
if err := s.handlePacketImpl(p); err != nil {
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
}
if session == nil {
s.logger.Debugf("Received %s packet for unknown connection %s.", hdr.Type, hdr.DestConnectionID)
return nil
}
session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
data: packetData,
rcvTime: rcvTime,
})
return nil
}
func (s *server) handleGQUICPacket(
session packetHandler,
hdr *wire.Header,
packetData []byte,
remoteAddr net.Addr,
rcvTime time.Time,
) error {
// ignore all Public Reset packets
if hdr.ResetFlag {
s.logger.Infof("Received unexpected Public Reset for connection %s.", hdr.DestConnectionID)
func (s *server) handlePacketImpl(p *receivedPacket) error {
hdr := p.header
version := hdr.Version
if hdr.Type == protocol.PacketTypeInitial {
go s.serverTLS.HandleInitial(p.remoteAddr, hdr, p.data)
return nil
}
sessionKnown := session != nil
// If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset
// This should only happen after a server restart, when we still receive packets for connections that we lost the state for.
if !sessionKnown && !hdr.VersionFlag {
_, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), remoteAddr)
if !hdr.VersionFlag {
_, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), p.remoteAddr)
return err
}
// a session is only created once the client sent a supported version
// if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated
// it is safe to drop it
if sessionKnown && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
return nil
// This is (potentially) a Client Hello.
// Make sure it has the minimum required size before spending any more ressources on it.
if len(p.data) < protocol.MinClientHelloSize {
return errors.New("dropping small packet for unknown connection")
}
// send a Version Negotiation Packet if the client is speaking a different protocol version
// since the client send a Public Header (only gQUIC has a Version Flag), we need to send a gQUIC Version Negotiation Packet
if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
// drop packets that are too small to be valid first packets
if len(packetData) < protocol.MinClientHelloSize {
return errors.New("dropping small packet with unknown version")
}
s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version)
_, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions), remoteAddr)
if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, version) {
s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", version)
_, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions), p.remoteAddr)
return err
}
if !sessionKnown {
// This is (potentially) a Client Hello.
// Make sure it has the minimum required size before spending any more ressources on it.
if len(packetData) < protocol.MinClientHelloSize {
return errors.New("dropping small packet for unknown connection")
}
version := hdr.Version
if !protocol.IsSupportedVersion(s.config.Versions, version) {
return errors.New("Server BUG: negotiated version not supported")
}
s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, version, remoteAddr)
sess, err := s.newSession(
&conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionRunner,
version,
hdr.DestConnectionID,
s.scfg,
s.tlsConf,
s.config,
s.logger,
)
if err != nil {
return err
}
s.sessionHandler.Add(hdr.DestConnectionID, sess)
go sess.run()
session = sess
if !protocol.IsSupportedVersion(s.config.Versions, version) {
return errors.New("Server BUG: negotiated version not supported")
}
session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
data: packetData,
rcvTime: rcvTime,
})
s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, version, p.remoteAddr)
sess, err := s.newSession(
&conn{pconn: s.conn, currentAddr: p.remoteAddr},
s.sessionRunner,
version,
hdr.DestConnectionID,
s.scfg,
s.tlsConf,
s.config,
s.logger,
)
if err != nil {
return err
}
s.sessionHandler.Add(hdr.DestConnectionID, newServerSession(sess, s.config, s.logger))
go sess.run()
sess.handlePacket(p)
return nil
}

63
server_session.go Normal file
View file

@ -0,0 +1,63 @@
package quic
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type serverSession struct {
quicSession
config *Config
logger utils.Logger
}
var _ packetHandler = &serverSession{}
func newServerSession(sess quicSession, config *Config, logger utils.Logger) packetHandler {
return &serverSession{
quicSession: sess,
config: config,
logger: logger,
}
}
func (s *serverSession) handlePacket(p *receivedPacket) {
if err := s.handlePacketImpl(p); err != nil {
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
}
}
func (s *serverSession) handlePacketImpl(p *receivedPacket) error {
hdr := p.header
// ignore all Public Reset packets
if hdr.ResetFlag {
return fmt.Errorf("Received unexpected Public Reset for connection %s", hdr.DestConnectionID)
}
// Probably an old packet that was sent by the client before the version was negotiated.
// It is safe to drop it.
if (hdr.VersionFlag || hdr.IsLongHeader) && hdr.Version != s.quicSession.GetVersion() {
return nil
}
if hdr.IsLongHeader {
switch hdr.Type {
case protocol.PacketTypeHandshake:
// nothing to do here. Packet will be passed to the session.
default:
// Note that this also drops 0-RTT packets.
return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
}
}
s.quicSession.handlePacket(p)
return nil
}
func (s *serverSession) GetPerspective() protocol.Perspective {
return protocol.PerspectiveServer
}

101
server_session_test.go Normal file
View file

@ -0,0 +1,101 @@
package quic
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Server Session", func() {
var (
qsess *MockQuicSession
sess *serverSession
)
BeforeEach(func() {
qsess = NewMockQuicSession(mockCtrl)
sess = newServerSession(qsess, &Config{}, utils.DefaultLogger).(*serverSession)
})
It("handles packets", func() {
p := &receivedPacket{
header: &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}},
}
qsess.EXPECT().handlePacket(p)
sess.handlePacket(p)
})
It("ignores Public Resets", func() {
p := &receivedPacket{
header: &wire.Header{
ResetFlag: true,
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
},
}
err := sess.handlePacketImpl(p)
Expect(err).To(MatchError("Received unexpected Public Reset for connection 0xdeadbeef"))
})
It("ignores delayed packets with mismatching versions, for gQUIC", func() {
qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
// don't EXPECT any calls to handlePacket()
p := &receivedPacket{
header: &wire.Header{
VersionFlag: true,
Version: protocol.VersionNumber(123),
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
},
}
err := sess.handlePacketImpl(p)
Expect(err).ToNot(HaveOccurred())
})
It("ignores delayed packets with mismatching versions, for IETF QUIC", func() {
qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
// don't EXPECT any calls to handlePacket()
p := &receivedPacket{
header: &wire.Header{
IsLongHeader: true,
Version: protocol.VersionNumber(123),
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
},
}
err := sess.handlePacketImpl(p)
Expect(err).ToNot(HaveOccurred())
})
It("ignores packets with the wrong Long Header type", func() {
qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
p := &receivedPacket{
header: &wire.Header{
IsLongHeader: true,
Type: protocol.PacketType0RTT,
Version: protocol.VersionNumber(100),
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
},
}
err := sess.handlePacketImpl(p)
Expect(err).To(MatchError("Received unsupported packet type: 0-RTT Protected"))
})
It("passes on Handshake packets", func() {
p := &receivedPacket{
header: &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Version: protocol.VersionNumber(100),
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
},
}
qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
qsess.EXPECT().handlePacket(p)
Expect(sess.handlePacketImpl(p)).To(Succeed())
})
It("has the right perspective", func() {
Expect(sess.GetPerspective()).To(Equal(protocol.PerspectiveServer))
})
})

View file

@ -14,7 +14,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@ -27,6 +26,8 @@ type mockSession struct {
runner sessionRunner
}
func (s *mockSession) GetPerspective() protocol.Perspective { panic("not implemented") }
var _ = Describe("Server", func() {
var (
conn *mockPacketConn
@ -89,7 +90,7 @@ var _ = Describe("Server", func() {
Context("with mock session", func() {
var (
serv *server
firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID)
firstPacket *receivedPacket
connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
sessions = make([]*MockQuicSession, 0)
sessionHandler *MockPacketHandlerManager
@ -126,9 +127,16 @@ var _ = Describe("Server", func() {
serv.setup()
b := &bytes.Buffer{}
utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]))
firstPacket = []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
firstPacket = append(append(firstPacket, b.Bytes()...), 0x01)
firstPacket = append(firstPacket, bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)...) // add padding
firstPacket = &receivedPacket{
header: &wire.Header{
VersionFlag: true,
Version: serv.config.Versions[0],
DestConnectionID: protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6},
PacketNumber: 1,
},
data: bytes.Repeat([]byte{0}, protocol.MinClientHelloSize),
rcvTime: time.Now(),
}
})
AfterEach(func() {
@ -150,12 +158,10 @@ var _ = Describe("Server", func() {
s.EXPECT().run().Do(func() { close(run) })
sessions = append(sessions, s)
sessionHandler.EXPECT().Get(connID)
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
Expect(sess.(*mockSession).connID).To(Equal(connID))
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(cid protocol.ConnectionID, _ packetHandler) {
Expect(cid).To(Equal(connID))
})
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.handlePacketImpl(firstPacket)).To(Succeed())
Eventually(run).Should(BeClosed())
})
@ -165,7 +171,8 @@ var _ = Describe("Server", func() {
err := serv.setupTLS()
Expect(err).ToNot(HaveOccurred())
added := make(chan struct{})
sessionHandler.EXPECT().Add(connID, sess).Do(func(protocol.ConnectionID, packetHandler) {
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, ph packetHandler) {
Expect(ph.GetPerspective()).To(Equal(protocol.PerspectiveServer))
close(added)
})
serv.serverTLS.sessionChan <- tlsSession{
@ -184,17 +191,15 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := serv.Accept()
_, err := serv.Accept()
Expect(err).ToNot(HaveOccurred())
Expect(sess.(*mockSession).connID).To(Equal(connID))
close(done)
}()
sessionHandler.EXPECT().Get(connID)
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
Consistently(done).ShouldNot(BeClosed())
sess.(*mockSession).runner.onHandshakeComplete(sess.(Session))
sess.(*serverSession).quicSession.(*mockSession).runner.onHandshakeComplete(sess.(Session))
})
err := serv.handlePacket(nil, firstPacket)
err := serv.handlePacketImpl(firstPacket)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Eventually(run).Should(BeClosed())
@ -212,45 +217,20 @@ var _ = Describe("Server", func() {
serv.Accept()
close(done)
}()
sessionHandler.EXPECT().Get(connID)
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(protocol.ConnectionID, packetHandler) {
run <- errors.New("handshake error")
})
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.handlePacketImpl(firstPacket)).To(Succeed())
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
sessionHandler.EXPECT().Close()
close(serv.errorChan)
serv.Close()
Eventually(done).Should(BeClosed())
})
It("assigns packets to existing sessions", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
sess.EXPECT().GetVersion()
sessionHandler.EXPECT().Get(connID).Return(sess, true)
err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
Expect(err).ToNot(HaveOccurred())
})
It("closes the sessionHandler and the connection when Close is called", func() {
go func() {
defer GinkgoRecover()
serv.serve()
}()
// close the server
sessionHandler.EXPECT().Close().AnyTimes()
It("closes the sessionHandler when Close is called", func() {
sessionHandler.EXPECT().CloseServer()
Expect(serv.Close()).To(Succeed())
Expect(conn.closed).To(BeTrue())
})
It("ignores packets for closed sessions", func() {
sessionHandler.EXPECT().Get(connID).Return(nil, true)
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
})
It("works if no quic.Config is given", func(done Done) {
@ -264,163 +244,56 @@ var _ = Describe("Server", func() {
ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred())
var returned bool
go func() {
defer GinkgoRecover()
_, err := ln.Accept()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
returned = true
}()
ln.Close()
Eventually(func() bool { return returned }).Should(BeTrue())
})
It("errors when encountering a connection error", func() {
testErr := errors.New("connection error")
conn.readErr = testErr
sessionHandler.EXPECT().Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.serve()
ln.Accept()
close(done)
}()
_, err := serv.Accept()
Expect(err).To(MatchError(testErr))
ln.Close()
Eventually(done).Should(BeClosed())
})
It("ignores delayed packets with mismatching versions", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().GetVersion()
// don't EXPECT any handlePacket() calls to this session
sessionHandler.EXPECT().Get(connID).Return(sess, true)
b := &bytes.Buffer{}
// add an unsupported version
data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]+1))
data = append(append(data, b.Bytes()...), 0x01)
err := serv.handlePacket(nil, data)
Expect(err).ToNot(HaveOccurred())
// if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn
Expect(conn.dataWritten.Bytes()).To(BeEmpty())
It("returns Accept when it is closed", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept()
Expect(err).To(MatchError("server closed"))
close(done)
}()
sessionHandler.EXPECT().CloseServer()
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
It("errors on invalid public header", func() {
err := serv.handlePacket(nil, nil)
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
})
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
sessionHandler.EXPECT().Get(connID).Return(sess, true)
serv.supportsTLS = true
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 1000,
SrcConnectionID: connID,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
Expect(err).To(MatchError("packet payload (456 bytes) is smaller than the expected payload length (1000 bytes)"))
})
It("cuts packets at the payload length", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) {
Expect(packet.data).To(HaveLen(123))
})
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
sessionHandler.EXPECT().Get(connID).Return(sess, true)
serv.supportsTLS = true
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
Expect(err).ToNot(HaveOccurred())
})
It("drops packets with invalid packet types", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
sessionHandler.EXPECT().Get(connID).Return(sess, true)
serv.supportsTLS = true
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
Expect(err).To(MatchError("Received unsupported packet type: Retry"))
})
It("ignores Public Resets", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
sessionHandler.EXPECT().Get(connID).Return(sess, true)
err := serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337))
Expect(err).ToNot(HaveOccurred())
It("returns Accept with the right error when closeWithError is called", func() {
testErr := errors.New("connection error")
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept()
Expect(err).To(MatchError(testErr))
close(done)
}()
sessionHandler.EXPECT().CloseServer()
serv.closeWithError(testErr)
Eventually(done).Should(BeClosed())
})
It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {
config.Versions = []protocol.VersionNumber{99}
b := &bytes.Buffer{}
hdr := wire.Header{
VersionFlag: true,
DestConnectionID: connID,
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen2,
p := &receivedPacket{
header: &wire.Header{
VersionFlag: true,
DestConnectionID: connID,
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen2,
},
data: make([]byte, protocol.MinClientHelloSize),
}
Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed())
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
serv.conn = conn
sessionHandler.EXPECT().Get(connID)
err := serv.handlePacket(nil, b.Bytes())
Expect(serv.handlePacketImpl(p)).To(Succeed())
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
Expect(err).ToNot(HaveOccurred())
})
It("doesn't respond with a version negotiation packet if the first packet is too small", func() {
b := &bytes.Buffer{}
hdr := wire.Header{
VersionFlag: true,
DestConnectionID: connID,
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen2,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed())
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small
serv.conn = conn
sessionHandler.EXPECT().Get(connID)
err := serv.handlePacket(udpAddr, b.Bytes())
Expect(err).To(MatchError("dropping small packet with unknown version"))
Expect(conn.dataWritten.Len()).Should(BeZero())
})
})
@ -523,8 +396,11 @@ var _ = Describe("Server", func() {
})
It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() {
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
config.Versions = append(config.Versions, protocol.VersionTLS)
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred())
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
b := &bytes.Buffer{}
hdr := wire.Header{
Type: protocol.PacketTypeInitial,
@ -536,13 +412,10 @@ var _ = Describe("Server", func() {
Version: 0x1234,
PayloadLen: protocol.MinInitialPacketSize,
}
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)).To(Succeed())
b.Write(bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) // add a fake CHLO
conn.dataToRead <- b.Bytes()
conn.dataReadFrom = udpAddr
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@ -568,51 +441,6 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})
It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() {
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
b := &bytes.Buffer{}
hdr := wire.Header{
Type: protocol.PacketTypeInitial,
IsLongHeader: true,
DestConnectionID: connID,
SrcConnectionID: connID,
PacketNumber: 0x55,
PacketNumberLen: protocol.PacketNumberLen1,
Version: protocol.VersionTLS,
}
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
conn.dataToRead <- b.Bytes()
conn.dataReadFrom = udpAddr
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
})
It("ignores non-Initial Long Header packets for unknown connections", func() {
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
b := &bytes.Buffer{}
hdr := wire.Header{
Type: protocol.PacketTypeHandshake,
IsLongHeader: true,
DestConnectionID: connID,
SrcConnectionID: connID,
PacketNumber: 0x55,
PacketNumberLen: protocol.PacketNumberLen1,
Version: protocol.VersionTLS,
}
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
conn.dataToRead <- b.Bytes()
conn.dataReadFrom = udpAddr
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
})
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
conn.dataReadFrom = udpAddr
conn.dataToRead <- []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}

View file

@ -17,7 +17,7 @@ import (
type tlsSession struct {
connID protocol.ConnectionID
sess packetHandler
sess quicSession
}
type serverTLS struct {
@ -126,7 +126,7 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea
return err
}
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, protocol.ConnectionID, error) {
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (quicSession, protocol.ConnectionID, error) {
if hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
return nil, nil, errors.New("dropping Initial packet with too short connection ID")
}
@ -164,7 +164,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
return sess, connID, nil
}
func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, protocol.ConnectionID, error) {
func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (quicSession, protocol.ConnectionID, error) {
version := hdr.Version
bc := handshake.NewCryptoStreamConn(remoteAddr)
bc.AddDataForReading(frame.Data)