mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 04:07:35 +03:00
also use the multiplexer for the server
This commit is contained in:
parent
c8d20e86d7
commit
ad5a3e2fa0
15 changed files with 631 additions and 512 deletions
13
client.go
13
client.go
|
@ -544,9 +544,22 @@ func (c *client) Close() error {
|
|||
return c.session.Close()
|
||||
}
|
||||
|
||||
func (c *client) destroy(e error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if c.session == nil {
|
||||
return
|
||||
}
|
||||
c.session.destroy(e)
|
||||
}
|
||||
|
||||
func (c *client) GetVersion() protocol.VersionNumber {
|
||||
c.mutex.Lock()
|
||||
v := c.version
|
||||
c.mutex.Unlock()
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *client) GetPerspective() protocol.Perspective {
|
||||
return protocol.PerspectiveClient
|
||||
}
|
||||
|
|
|
@ -9,6 +9,11 @@ const (
|
|||
PerspectiveClient Perspective = 2
|
||||
)
|
||||
|
||||
// Opposite returns the perspective of the peer
|
||||
func (p Perspective) Opposite() Perspective {
|
||||
return 3 - p
|
||||
}
|
||||
|
||||
func (p Perspective) String() string {
|
||||
switch p {
|
||||
case PerspectiveServer:
|
||||
|
|
|
@ -11,4 +11,9 @@ var _ = Describe("Perspective", func() {
|
|||
Expect(PerspectiveServer.String()).To(Equal("Server"))
|
||||
Expect(Perspective(0).String()).To(Equal("invalid perspective"))
|
||||
})
|
||||
|
||||
It("returns the opposite", func() {
|
||||
Expect(PerspectiveClient.Opposite()).To(Equal(PerspectiveServer))
|
||||
Expect(PerspectiveServer.Opposite()).To(Equal(PerspectiveClient))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -44,29 +44,14 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1)
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockPacketHandlerManager) Close() error {
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
// CloseServer mocks base method
|
||||
func (m *MockPacketHandlerManager) CloseServer() {
|
||||
m.ctrl.Call(m, "CloseServer")
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) Close() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close))
|
||||
}
|
||||
|
||||
// Get mocks base method
|
||||
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
|
||||
ret := m.ctrl.Call(m, "Get", arg0)
|
||||
ret0, _ := ret[0].(packetHandler)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
|
||||
// CloseServer indicates an expected call of CloseServer
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
|
@ -78,3 +63,13 @@ func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) {
|
|||
func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0)
|
||||
}
|
||||
|
||||
// SetServer mocks base method
|
||||
func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) {
|
||||
m.ctrl.Call(m, "SetServer", arg0)
|
||||
}
|
||||
|
||||
// SetServer indicates an expected call of SetServer
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0)
|
||||
}
|
||||
|
|
91
mock_packet_handler_test.go
Normal file
91
mock_packet_handler_test.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go (interfaces: PacketHandler)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockPacketHandler is a mock of PacketHandler interface
|
||||
type MockPacketHandler struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPacketHandlerMockRecorder
|
||||
}
|
||||
|
||||
// MockPacketHandlerMockRecorder is the mock recorder for MockPacketHandler
|
||||
type MockPacketHandlerMockRecorder struct {
|
||||
mock *MockPacketHandler
|
||||
}
|
||||
|
||||
// NewMockPacketHandler creates a new mock instance
|
||||
func NewMockPacketHandler(ctrl *gomock.Controller) *MockPacketHandler {
|
||||
mock := &MockPacketHandler{ctrl: ctrl}
|
||||
mock.recorder = &MockPacketHandlerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockPacketHandler) EXPECT() *MockPacketHandlerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockPacketHandler) Close() error {
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockPacketHandlerMockRecorder) Close() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandler)(nil).Close))
|
||||
}
|
||||
|
||||
// GetPerspective mocks base method
|
||||
func (m *MockPacketHandler) GetPerspective() protocol.Perspective {
|
||||
ret := m.ctrl.Call(m, "GetPerspective")
|
||||
ret0, _ := ret[0].(protocol.Perspective)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetPerspective indicates an expected call of GetPerspective
|
||||
func (mr *MockPacketHandlerMockRecorder) GetPerspective() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPerspective", reflect.TypeOf((*MockPacketHandler)(nil).GetPerspective))
|
||||
}
|
||||
|
||||
// GetVersion mocks base method
|
||||
func (m *MockPacketHandler) GetVersion() protocol.VersionNumber {
|
||||
ret := m.ctrl.Call(m, "GetVersion")
|
||||
ret0, _ := ret[0].(protocol.VersionNumber)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetVersion indicates an expected call of GetVersion
|
||||
func (mr *MockPacketHandlerMockRecorder) GetVersion() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockPacketHandler)(nil).GetVersion))
|
||||
}
|
||||
|
||||
// destroy mocks base method
|
||||
func (m *MockPacketHandler) destroy(arg0 error) {
|
||||
m.ctrl.Call(m, "destroy", arg0)
|
||||
}
|
||||
|
||||
// destroy indicates an expected call of destroy
|
||||
func (mr *MockPacketHandlerMockRecorder) destroy(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0)
|
||||
}
|
||||
|
||||
// handlePacket mocks base method
|
||||
func (m *MockPacketHandler) handlePacket(arg0 *receivedPacket) {
|
||||
m.ctrl.Call(m, "handlePacket", arg0)
|
||||
}
|
||||
|
||||
// handlePacket indicates an expected call of handlePacket
|
||||
func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockPacketHandler)(nil).handlePacket), arg0)
|
||||
}
|
56
mock_unknown_packet_handler_test.go
Normal file
56
mock_unknown_packet_handler_test.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go (interfaces: UnknownPacketHandler)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockUnknownPacketHandler is a mock of UnknownPacketHandler interface
|
||||
type MockUnknownPacketHandler struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockUnknownPacketHandlerMockRecorder
|
||||
}
|
||||
|
||||
// MockUnknownPacketHandlerMockRecorder is the mock recorder for MockUnknownPacketHandler
|
||||
type MockUnknownPacketHandlerMockRecorder struct {
|
||||
mock *MockUnknownPacketHandler
|
||||
}
|
||||
|
||||
// NewMockUnknownPacketHandler creates a new mock instance
|
||||
func NewMockUnknownPacketHandler(ctrl *gomock.Controller) *MockUnknownPacketHandler {
|
||||
mock := &MockUnknownPacketHandler{ctrl: ctrl}
|
||||
mock.recorder = &MockUnknownPacketHandlerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// closeWithError mocks base method
|
||||
func (m *MockUnknownPacketHandler) closeWithError(arg0 error) error {
|
||||
ret := m.ctrl.Call(m, "closeWithError", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// closeWithError indicates an expected call of closeWithError
|
||||
func (mr *MockUnknownPacketHandlerMockRecorder) closeWithError(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).closeWithError), arg0)
|
||||
}
|
||||
|
||||
// handlePacket mocks base method
|
||||
func (m *MockUnknownPacketHandler) handlePacket(arg0 *receivedPacket) {
|
||||
m.ctrl.Call(m, "handlePacket", arg0)
|
||||
}
|
||||
|
||||
// handlePacket indicates an expected call of handlePacket
|
||||
func (mr *MockUnknownPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockUnknownPacketHandler)(nil).handlePacket), arg0)
|
||||
}
|
|
@ -13,6 +13,8 @@ package quic
|
|||
//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession QuicSession"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler PacketHandler"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_unknown_packet_handler_test.go github.com/lucas-clemente/quic-go unknownPacketHandler UnknownPacketHandler"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager PacketHandlerManager"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/lucas-clemente/quic-go multiplexer Multiplexer"
|
||||
//go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'"
|
||||
|
|
|
@ -28,7 +28,7 @@ type connMultiplexer struct {
|
|||
mutex sync.Mutex
|
||||
|
||||
conns map[net.PacketConn]connManager
|
||||
newPacketHandlerManager func(net.PacketConn, int, utils.Logger, bool) packetHandlerManager // so it can be replaced in the tests
|
||||
newPacketHandlerManager func(net.PacketConn, int, utils.Logger) packetHandlerManager // so it can be replaced in the tests
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ func (m *connMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandle
|
|||
|
||||
p, ok := m.conns[c]
|
||||
if !ok {
|
||||
manager := m.newPacketHandlerManager(c, connIDLen, m.logger, true)
|
||||
manager := m.newPacketHandlerManager(c, connIDLen, m.logger)
|
||||
p = connManager{connIDLen: connIDLen, manager: manager}
|
||||
m.conns[c] = p
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -24,6 +23,7 @@ type packetHandlerMap struct {
|
|||
connIDLen int
|
||||
|
||||
handlers map[string] /* string(ConnectionID)*/ packetHandler
|
||||
server unknownPacketHandler
|
||||
closed bool
|
||||
|
||||
deleteClosedSessionsAfter time.Duration
|
||||
|
@ -33,8 +33,7 @@ type packetHandlerMap struct {
|
|||
|
||||
var _ packetHandlerManager = &packetHandlerMap{}
|
||||
|
||||
// TODO(#561): remove the listen flag
|
||||
func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger, listen bool) packetHandlerManager {
|
||||
func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager {
|
||||
m := &packetHandlerMap{
|
||||
conn: conn,
|
||||
connIDLen: connIDLen,
|
||||
|
@ -42,19 +41,10 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger
|
|||
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
|
||||
logger: logger,
|
||||
}
|
||||
if listen {
|
||||
go m.listen()
|
||||
}
|
||||
go m.listen()
|
||||
return m
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
|
||||
h.mutex.RLock()
|
||||
sess, ok := h.handlers[string(id)]
|
||||
h.mutex.RUnlock()
|
||||
return sess, ok
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
|
||||
h.mutex.Lock()
|
||||
h.handlers[string(id)] = handler
|
||||
|
@ -62,18 +52,47 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
|
|||
}
|
||||
|
||||
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
||||
h.removeByConnectionIDAsString(string(id))
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
|
||||
h.mutex.Lock()
|
||||
h.handlers[string(id)] = nil
|
||||
h.handlers[id] = nil
|
||||
h.mutex.Unlock()
|
||||
|
||||
time.AfterFunc(h.deleteClosedSessionsAfter, func() {
|
||||
h.mutex.Lock()
|
||||
delete(h.handlers, string(id))
|
||||
delete(h.handlers, id)
|
||||
h.mutex.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Close() error {
|
||||
func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
|
||||
h.mutex.Lock()
|
||||
h.server = s
|
||||
h.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) CloseServer() {
|
||||
h.mutex.Lock()
|
||||
h.server = nil
|
||||
var wg sync.WaitGroup
|
||||
for id, handler := range h.handlers {
|
||||
if handler != nil && handler.GetPerspective() == protocol.PerspectiveServer {
|
||||
wg.Add(1)
|
||||
go func(id string, handler packetHandler) {
|
||||
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
|
||||
_ = handler.Close()
|
||||
h.removeByConnectionIDAsString(id)
|
||||
wg.Done()
|
||||
}(id, handler)
|
||||
}
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) close(e error) error {
|
||||
h.mutex.Lock()
|
||||
if h.closed {
|
||||
h.mutex.Unlock()
|
||||
|
@ -86,12 +105,15 @@ func (h *packetHandlerMap) Close() error {
|
|||
if handler != nil {
|
||||
wg.Add(1)
|
||||
go func(handler packetHandler) {
|
||||
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
|
||||
_ = handler.Close()
|
||||
handler.destroy(e)
|
||||
wg.Done()
|
||||
}(handler)
|
||||
}
|
||||
}
|
||||
|
||||
if h.server != nil {
|
||||
h.server.closeWithError(e)
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
wg.Wait()
|
||||
return nil
|
||||
|
@ -105,9 +127,7 @@ func (h *packetHandlerMap) listen() {
|
|||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||
n, addr, err := h.conn.ReadFrom(data)
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
h.Close()
|
||||
}
|
||||
h.close(err)
|
||||
return
|
||||
}
|
||||
data = data[:n]
|
||||
|
@ -127,15 +147,33 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("error parsing invariant header: %s", err)
|
||||
}
|
||||
handler, ok := h.Get(iHdr.DestConnectionID)
|
||||
if !ok {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
|
||||
}
|
||||
if handler == nil {
|
||||
|
||||
h.mutex.RLock()
|
||||
handler, ok := h.handlers[string(iHdr.DestConnectionID)]
|
||||
server := h.server
|
||||
h.mutex.RUnlock()
|
||||
|
||||
var sentBy protocol.Perspective
|
||||
var version protocol.VersionNumber
|
||||
var handlePacket func(*receivedPacket)
|
||||
if ok && handler == nil {
|
||||
// Late packet for closed session
|
||||
return nil
|
||||
}
|
||||
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, handler.GetVersion())
|
||||
if !ok {
|
||||
if server == nil { // no server set
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
|
||||
}
|
||||
handlePacket = server.handlePacket
|
||||
sentBy = protocol.PerspectiveClient
|
||||
version = iHdr.Version
|
||||
} else {
|
||||
sentBy = handler.GetPerspective().Opposite()
|
||||
version = handler.GetVersion()
|
||||
handlePacket = handler.handlePacket
|
||||
}
|
||||
|
||||
hdr, err := iHdr.Parse(r, sentBy, version)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing header: %s", err)
|
||||
}
|
||||
|
@ -150,7 +188,7 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
|
|||
// TODO(#1312): implement parsing of compound packets
|
||||
}
|
||||
|
||||
handler.handlePacket(&receivedPacket{
|
||||
handlePacket(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
data: packetData,
|
||||
|
|
|
@ -2,6 +2,7 @@ package quic
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
@ -18,66 +19,38 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
conn *mockPacketConn
|
||||
)
|
||||
|
||||
getPacket := func(connID protocol.ConnectionID) []byte {
|
||||
buf := &bytes.Buffer{}
|
||||
err := (&wire.Header{
|
||||
DestConnectionID: connID,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
conn = newMockPacketConn()
|
||||
handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger, true).(*packetHandlerMap)
|
||||
})
|
||||
|
||||
It("adds and gets", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||
sess := &mockSession{}
|
||||
handler.Add(connID, sess)
|
||||
session, ok := handler.Get(connID)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(session).To(Equal(sess))
|
||||
})
|
||||
|
||||
It("deletes", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||
handler.Add(connID, &mockSession{})
|
||||
handler.Remove(connID)
|
||||
session, ok := handler.Get(connID)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(session).To(BeNil())
|
||||
})
|
||||
|
||||
It("deletes nil session entries after a wait time", func() {
|
||||
handler.deleteClosedSessionsAfter = 25 * time.Millisecond
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||
handler.Add(connID, &mockSession{})
|
||||
handler.Remove(connID)
|
||||
Eventually(func() bool {
|
||||
_, ok := handler.Get(connID)
|
||||
return ok
|
||||
}).Should(BeFalse())
|
||||
handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger).(*packetHandlerMap)
|
||||
})
|
||||
|
||||
It("closes", func() {
|
||||
sess1 := NewMockQuicSession(mockCtrl)
|
||||
sess1.EXPECT().Close()
|
||||
sess2 := NewMockQuicSession(mockCtrl)
|
||||
sess2.EXPECT().Close()
|
||||
testErr := errors.New("test error ")
|
||||
sess1 := NewMockPacketHandler(mockCtrl)
|
||||
sess1.EXPECT().destroy(testErr)
|
||||
sess2 := NewMockPacketHandler(mockCtrl)
|
||||
sess2.EXPECT().destroy(testErr)
|
||||
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1)
|
||||
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2)
|
||||
handler.Close()
|
||||
handler.close(testErr)
|
||||
})
|
||||
|
||||
Context("handling packets", func() {
|
||||
getPacket := func(connID protocol.ConnectionID) []byte {
|
||||
buf := &bytes.Buffer{}
|
||||
err := (&wire.Header{
|
||||
DestConnectionID: connID,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
It("handles packets for different packet handlers on the same packet conn", func() {
|
||||
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
packetHandler1 := NewMockQuicSession(mockCtrl)
|
||||
packetHandler2 := NewMockQuicSession(mockCtrl)
|
||||
packetHandler1 := NewMockPacketHandler(mockCtrl)
|
||||
packetHandler2 := NewMockPacketHandler(mockCtrl)
|
||||
handledPacket1 := make(chan struct{})
|
||||
handledPacket2 := make(chan struct{})
|
||||
packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
|
@ -85,11 +58,13 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
close(handledPacket1)
|
||||
})
|
||||
packetHandler1.EXPECT().GetVersion()
|
||||
packetHandler1.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
|
||||
packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
Expect(p.header.DestConnectionID).To(Equal(connID2))
|
||||
close(handledPacket2)
|
||||
})
|
||||
packetHandler2.EXPECT().GetVersion()
|
||||
packetHandler2.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
|
||||
handler.Add(connID1, packetHandler1)
|
||||
handler.Add(connID2, packetHandler2)
|
||||
|
||||
|
@ -99,8 +74,8 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
Eventually(handledPacket2).Should(BeClosed())
|
||||
|
||||
// makes the listen go routine return
|
||||
packetHandler1.EXPECT().Close().AnyTimes()
|
||||
packetHandler2.EXPECT().Close().AnyTimes()
|
||||
packetHandler1.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
packetHandler2.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
close(conn.dataToRead)
|
||||
})
|
||||
|
||||
|
@ -110,10 +85,20 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
Expect(err.Error()).To(ContainSubstring("error parsing invariant header:"))
|
||||
})
|
||||
|
||||
It("deletes nil session entries after a wait time", func() {
|
||||
handler.deleteClosedSessionsAfter = 10 * time.Millisecond
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
||||
handler.Remove(connID)
|
||||
Eventually(func() error {
|
||||
return handler.handlePacket(nil, getPacket(connID))
|
||||
}).Should(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
|
||||
})
|
||||
|
||||
It("ignores packets arriving late for closed sessions", func() {
|
||||
handler.deleteClosedSessionsAfter = time.Hour
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
handler.Add(connID, NewMockQuicSession(mockCtrl))
|
||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
||||
handler.Remove(connID)
|
||||
err := handler.handlePacket(nil, getPacket(connID))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -127,8 +112,9 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
|
||||
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
packetHandler := NewMockQuicSession(mockCtrl)
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
packetHandler.EXPECT().GetVersion().Return(versionIETFFrames)
|
||||
packetHandler.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
|
||||
handler.Add(connID, packetHandler)
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
|
@ -148,8 +134,9 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
|
||||
It("cuts packets at the Payload Length", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
packetHandler := NewMockQuicSession(mockCtrl)
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
packetHandler.EXPECT().GetVersion().Return(versionIETFFrames)
|
||||
packetHandler.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
|
||||
handler.Add(connID, packetHandler)
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
Expect(p.data).To(HaveLen(456))
|
||||
|
@ -172,8 +159,9 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
|
||||
It("closes the packet handlers when reading from the conn fails", func() {
|
||||
done := make(chan struct{})
|
||||
packetHandler := NewMockQuicSession(mockCtrl)
|
||||
packetHandler.EXPECT().Close().Do(func() {
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) {
|
||||
Expect(e).To(HaveOccurred())
|
||||
close(done)
|
||||
})
|
||||
handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)
|
||||
|
@ -181,4 +169,38 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("running a server", func() {
|
||||
It("adds a server", func() {
|
||||
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
|
||||
p := getPacket(connID)
|
||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
||||
server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
Expect(p.header.DestConnectionID).To(Equal(connID))
|
||||
})
|
||||
handler.SetServer(server)
|
||||
Expect(handler.handlePacket(nil, p)).To(Succeed())
|
||||
})
|
||||
|
||||
It("closes all server sessions", func() {
|
||||
clientSess := NewMockPacketHandler(mockCtrl)
|
||||
clientSess.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
|
||||
serverSess := NewMockPacketHandler(mockCtrl)
|
||||
serverSess.EXPECT().GetPerspective().Return(protocol.PerspectiveServer)
|
||||
serverSess.EXPECT().Close()
|
||||
|
||||
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess)
|
||||
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess)
|
||||
handler.CloseServer()
|
||||
})
|
||||
|
||||
It("stops handling packets with unknown connection IDs after the server is closed", func() {
|
||||
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
|
||||
p := getPacket(connID)
|
||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
||||
handler.SetServer(server)
|
||||
handler.CloseServer()
|
||||
Expect(handler.handlePacket(nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
234
server.go
234
server.go
|
@ -1,7 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -14,21 +13,27 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// packetHandler handles packets
|
||||
type packetHandler interface {
|
||||
handlePacket(*receivedPacket)
|
||||
GetVersion() protocol.VersionNumber
|
||||
io.Closer
|
||||
destroy(error)
|
||||
GetVersion() protocol.VersionNumber
|
||||
GetPerspective() protocol.Perspective
|
||||
}
|
||||
|
||||
type unknownPacketHandler interface {
|
||||
handlePacket(*receivedPacket)
|
||||
closeWithError(error) error
|
||||
}
|
||||
|
||||
type packetHandlerManager interface {
|
||||
Add(protocol.ConnectionID, packetHandler)
|
||||
Get(protocol.ConnectionID) (packetHandler, bool)
|
||||
SetServer(unknownPacketHandler)
|
||||
Remove(protocol.ConnectionID)
|
||||
io.Closer
|
||||
CloseServer()
|
||||
}
|
||||
|
||||
type quicSession interface {
|
||||
|
@ -84,6 +89,7 @@ type server struct {
|
|||
}
|
||||
|
||||
var _ Listener = &server{}
|
||||
var _ unknownPacketHandler = &server{}
|
||||
|
||||
// ListenAddr creates a QUIC server listening on a given address.
|
||||
// The tls.Config must not be nil, the quic.Config may be nil.
|
||||
|
@ -125,7 +131,10 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
|
|||
}
|
||||
}
|
||||
|
||||
logger := utils.DefaultLogger.WithPrefix("server")
|
||||
sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &server{
|
||||
conn: conn,
|
||||
tlsConf: tlsConf,
|
||||
|
@ -133,11 +142,11 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
|
|||
certChain: certChain,
|
||||
scfg: scfg,
|
||||
newSession: newSession,
|
||||
sessionHandler: newPacketHandlerMap(conn, config.ConnectionIDLength, logger, false),
|
||||
sessionHandler: sessionHandler,
|
||||
sessionQueue: make(chan Session, 5),
|
||||
errorChan: make(chan struct{}),
|
||||
supportsTLS: supportsTLS,
|
||||
logger: logger,
|
||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||
}
|
||||
s.setup()
|
||||
if supportsTLS {
|
||||
|
@ -145,7 +154,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
go s.serve()
|
||||
sessionHandler.SetServer(s)
|
||||
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
|
||||
return s, nil
|
||||
}
|
||||
|
@ -176,7 +185,8 @@ func (s *server) setupTLS() error {
|
|||
case tlsSession := <-sessionChan:
|
||||
// The connection ID is a randomly chosen 8 byte value.
|
||||
// It is safe to assume that it doesn't collide with other randomly chosen values.
|
||||
s.sessionHandler.Add(tlsSession.connID, tlsSession.sess)
|
||||
serverSession := newServerSession(tlsSession.sess, s.config, s.logger)
|
||||
s.sessionHandler.Add(tlsSession.connID, serverSession)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
@ -263,27 +273,6 @@ func populateServerConfig(config *Config) *Config {
|
|||
}
|
||||
}
|
||||
|
||||
// serve listens on an existing PacketConn
|
||||
func (s *server) serve() {
|
||||
for {
|
||||
data := *getPacketBuffer()
|
||||
data = data[:protocol.MaxReceivePacketSize]
|
||||
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
||||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||
n, remoteAddr, err := s.conn.ReadFrom(data)
|
||||
if err != nil {
|
||||
s.serverError = err
|
||||
close(s.errorChan)
|
||||
_ = s.Close()
|
||||
return
|
||||
}
|
||||
data = data[:n]
|
||||
if err := s.handlePacket(remoteAddr, data); err != nil {
|
||||
s.logger.Errorf("error handling packet: %s", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Accept returns newly openend sessions
|
||||
func (s *server) Accept() (Session, error) {
|
||||
var sess Session
|
||||
|
@ -297,10 +286,13 @@ func (s *server) Accept() (Session, error) {
|
|||
|
||||
// Close the server
|
||||
func (s *server) Close() error {
|
||||
s.sessionHandler.Close()
|
||||
err := s.conn.Close()
|
||||
<-s.errorChan // wait for serve() to return
|
||||
return err
|
||||
s.sessionHandler.CloseServer()
|
||||
// TODO: close the conn if this server was started with ListenAddr() (but not with Listen(net.PacketConn))
|
||||
if s.serverError == nil {
|
||||
s.serverError = errors.New("server closed")
|
||||
}
|
||||
close(s.errorChan)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Addr returns the server's network address
|
||||
|
@ -308,157 +300,65 @@ func (s *server) Addr() net.Addr {
|
|||
return s.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (s *server) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
||||
rcvTime := time.Now()
|
||||
|
||||
r := bytes.NewReader(packet)
|
||||
iHdr, err := wire.ParseInvariantHeader(r, s.config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
|
||||
}
|
||||
session, sessionKnown := s.sessionHandler.Get(iHdr.DestConnectionID)
|
||||
if sessionKnown && session == nil {
|
||||
// Late packet for closed session
|
||||
return nil
|
||||
}
|
||||
version := protocol.VersionUnknown
|
||||
if sessionKnown {
|
||||
version = session.GetVersion()
|
||||
}
|
||||
hdr, err := iHdr.Parse(r, protocol.PerspectiveClient, version)
|
||||
if err != nil {
|
||||
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
|
||||
}
|
||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||
packetData := packet[len(packet)-r.Len():]
|
||||
|
||||
if hdr.IsPublicHeader {
|
||||
return s.handleGQUICPacket(session, hdr, packetData, remoteAddr, rcvTime)
|
||||
}
|
||||
return s.handleIETFQUICPacket(session, hdr, packetData, remoteAddr, rcvTime)
|
||||
func (s *server) closeWithError(e error) error {
|
||||
s.serverError = e
|
||||
return s.Close()
|
||||
}
|
||||
|
||||
func (s *server) handleIETFQUICPacket(
|
||||
session packetHandler,
|
||||
hdr *wire.Header,
|
||||
packetData []byte,
|
||||
remoteAddr net.Addr,
|
||||
rcvTime time.Time,
|
||||
) error {
|
||||
if hdr.IsLongHeader {
|
||||
if !s.supportsTLS {
|
||||
return errors.New("Received an IETF QUIC Long Header")
|
||||
}
|
||||
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
|
||||
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
|
||||
}
|
||||
packetData = packetData[:int(hdr.PayloadLen)]
|
||||
// TODO(#1312): implement parsing of compound packets
|
||||
|
||||
switch hdr.Type {
|
||||
case protocol.PacketTypeInitial:
|
||||
go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData)
|
||||
return nil
|
||||
case protocol.PacketTypeHandshake:
|
||||
// nothing to do here. Packet will be passed to the session.
|
||||
default:
|
||||
// Note that this also drops 0-RTT packets.
|
||||
return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
|
||||
}
|
||||
func (s *server) handlePacket(p *receivedPacket) {
|
||||
if err := s.handlePacketImpl(p); err != nil {
|
||||
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
s.logger.Debugf("Received %s packet for unknown connection %s.", hdr.Type, hdr.DestConnectionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
session.handlePacket(&receivedPacket{
|
||||
remoteAddr: remoteAddr,
|
||||
header: hdr,
|
||||
data: packetData,
|
||||
rcvTime: rcvTime,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *server) handleGQUICPacket(
|
||||
session packetHandler,
|
||||
hdr *wire.Header,
|
||||
packetData []byte,
|
||||
remoteAddr net.Addr,
|
||||
rcvTime time.Time,
|
||||
) error {
|
||||
// ignore all Public Reset packets
|
||||
if hdr.ResetFlag {
|
||||
s.logger.Infof("Received unexpected Public Reset for connection %s.", hdr.DestConnectionID)
|
||||
func (s *server) handlePacketImpl(p *receivedPacket) error {
|
||||
hdr := p.header
|
||||
version := hdr.Version
|
||||
|
||||
if hdr.Type == protocol.PacketTypeInitial {
|
||||
go s.serverTLS.HandleInitial(p.remoteAddr, hdr, p.data)
|
||||
return nil
|
||||
}
|
||||
|
||||
sessionKnown := session != nil
|
||||
|
||||
// If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset
|
||||
// This should only happen after a server restart, when we still receive packets for connections that we lost the state for.
|
||||
if !sessionKnown && !hdr.VersionFlag {
|
||||
_, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), remoteAddr)
|
||||
if !hdr.VersionFlag {
|
||||
_, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), p.remoteAddr)
|
||||
return err
|
||||
}
|
||||
|
||||
// a session is only created once the client sent a supported version
|
||||
// if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated
|
||||
// it is safe to drop it
|
||||
if sessionKnown && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
|
||||
return nil
|
||||
// This is (potentially) a Client Hello.
|
||||
// Make sure it has the minimum required size before spending any more ressources on it.
|
||||
if len(p.data) < protocol.MinClientHelloSize {
|
||||
return errors.New("dropping small packet for unknown connection")
|
||||
}
|
||||
|
||||
// send a Version Negotiation Packet if the client is speaking a different protocol version
|
||||
// since the client send a Public Header (only gQUIC has a Version Flag), we need to send a gQUIC Version Negotiation Packet
|
||||
if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
|
||||
// drop packets that are too small to be valid first packets
|
||||
if len(packetData) < protocol.MinClientHelloSize {
|
||||
return errors.New("dropping small packet with unknown version")
|
||||
}
|
||||
s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version)
|
||||
_, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions), remoteAddr)
|
||||
if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, version) {
|
||||
s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", version)
|
||||
_, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions), p.remoteAddr)
|
||||
return err
|
||||
}
|
||||
|
||||
if !sessionKnown {
|
||||
// This is (potentially) a Client Hello.
|
||||
// Make sure it has the minimum required size before spending any more ressources on it.
|
||||
if len(packetData) < protocol.MinClientHelloSize {
|
||||
return errors.New("dropping small packet for unknown connection")
|
||||
}
|
||||
|
||||
version := hdr.Version
|
||||
if !protocol.IsSupportedVersion(s.config.Versions, version) {
|
||||
return errors.New("Server BUG: negotiated version not supported")
|
||||
}
|
||||
|
||||
s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, version, remoteAddr)
|
||||
sess, err := s.newSession(
|
||||
&conn{pconn: s.conn, currentAddr: remoteAddr},
|
||||
s.sessionRunner,
|
||||
version,
|
||||
hdr.DestConnectionID,
|
||||
s.scfg,
|
||||
s.tlsConf,
|
||||
s.config,
|
||||
s.logger,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.sessionHandler.Add(hdr.DestConnectionID, sess)
|
||||
|
||||
go sess.run()
|
||||
session = sess
|
||||
if !protocol.IsSupportedVersion(s.config.Versions, version) {
|
||||
return errors.New("Server BUG: negotiated version not supported")
|
||||
}
|
||||
|
||||
session.handlePacket(&receivedPacket{
|
||||
remoteAddr: remoteAddr,
|
||||
header: hdr,
|
||||
data: packetData,
|
||||
rcvTime: rcvTime,
|
||||
})
|
||||
s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, version, p.remoteAddr)
|
||||
sess, err := s.newSession(
|
||||
&conn{pconn: s.conn, currentAddr: p.remoteAddr},
|
||||
s.sessionRunner,
|
||||
version,
|
||||
hdr.DestConnectionID,
|
||||
s.scfg,
|
||||
s.tlsConf,
|
||||
s.config,
|
||||
s.logger,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.sessionHandler.Add(hdr.DestConnectionID, newServerSession(sess, s.config, s.logger))
|
||||
go sess.run()
|
||||
sess.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
|
|
63
server_session.go
Normal file
63
server_session.go
Normal file
|
@ -0,0 +1,63 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type serverSession struct {
|
||||
quicSession
|
||||
|
||||
config *Config
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ packetHandler = &serverSession{}
|
||||
|
||||
func newServerSession(sess quicSession, config *Config, logger utils.Logger) packetHandler {
|
||||
return &serverSession{
|
||||
quicSession: sess,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handlePacket(p *receivedPacket) {
|
||||
if err := s.handlePacketImpl(p); err != nil {
|
||||
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handlePacketImpl(p *receivedPacket) error {
|
||||
hdr := p.header
|
||||
// ignore all Public Reset packets
|
||||
if hdr.ResetFlag {
|
||||
return fmt.Errorf("Received unexpected Public Reset for connection %s", hdr.DestConnectionID)
|
||||
}
|
||||
|
||||
// Probably an old packet that was sent by the client before the version was negotiated.
|
||||
// It is safe to drop it.
|
||||
if (hdr.VersionFlag || hdr.IsLongHeader) && hdr.Version != s.quicSession.GetVersion() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if hdr.IsLongHeader {
|
||||
switch hdr.Type {
|
||||
case protocol.PacketTypeHandshake:
|
||||
// nothing to do here. Packet will be passed to the session.
|
||||
default:
|
||||
// Note that this also drops 0-RTT packets.
|
||||
return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
|
||||
}
|
||||
}
|
||||
|
||||
s.quicSession.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serverSession) GetPerspective() protocol.Perspective {
|
||||
return protocol.PerspectiveServer
|
||||
}
|
101
server_session_test.go
Normal file
101
server_session_test.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Server Session", func() {
|
||||
var (
|
||||
qsess *MockQuicSession
|
||||
sess *serverSession
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
qsess = NewMockQuicSession(mockCtrl)
|
||||
sess = newServerSession(qsess, &Config{}, utils.DefaultLogger).(*serverSession)
|
||||
})
|
||||
|
||||
It("handles packets", func() {
|
||||
p := &receivedPacket{
|
||||
header: &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}},
|
||||
}
|
||||
qsess.EXPECT().handlePacket(p)
|
||||
sess.handlePacket(p)
|
||||
})
|
||||
|
||||
It("ignores Public Resets", func() {
|
||||
p := &receivedPacket{
|
||||
header: &wire.Header{
|
||||
ResetFlag: true,
|
||||
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
},
|
||||
}
|
||||
err := sess.handlePacketImpl(p)
|
||||
Expect(err).To(MatchError("Received unexpected Public Reset for connection 0xdeadbeef"))
|
||||
})
|
||||
|
||||
It("ignores delayed packets with mismatching versions, for gQUIC", func() {
|
||||
qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
|
||||
// don't EXPECT any calls to handlePacket()
|
||||
p := &receivedPacket{
|
||||
header: &wire.Header{
|
||||
VersionFlag: true,
|
||||
Version: protocol.VersionNumber(123),
|
||||
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
},
|
||||
}
|
||||
err := sess.handlePacketImpl(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("ignores delayed packets with mismatching versions, for IETF QUIC", func() {
|
||||
qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
|
||||
// don't EXPECT any calls to handlePacket()
|
||||
p := &receivedPacket{
|
||||
header: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Version: protocol.VersionNumber(123),
|
||||
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
},
|
||||
}
|
||||
err := sess.handlePacketImpl(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("ignores packets with the wrong Long Header type", func() {
|
||||
qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
|
||||
p := &receivedPacket{
|
||||
header: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketType0RTT,
|
||||
Version: protocol.VersionNumber(100),
|
||||
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
},
|
||||
}
|
||||
err := sess.handlePacketImpl(p)
|
||||
Expect(err).To(MatchError("Received unsupported packet type: 0-RTT Protected"))
|
||||
})
|
||||
|
||||
It("passes on Handshake packets", func() {
|
||||
p := &receivedPacket{
|
||||
header: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
Version: protocol.VersionNumber(100),
|
||||
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
},
|
||||
}
|
||||
qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
|
||||
qsess.EXPECT().handlePacket(p)
|
||||
Expect(sess.handlePacketImpl(p)).To(Succeed())
|
||||
})
|
||||
|
||||
It("has the right perspective", func() {
|
||||
Expect(sess.GetPerspective()).To(Equal(protocol.PerspectiveServer))
|
||||
})
|
||||
})
|
302
server_test.go
302
server_test.go
|
@ -14,7 +14,6 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
@ -27,6 +26,8 @@ type mockSession struct {
|
|||
runner sessionRunner
|
||||
}
|
||||
|
||||
func (s *mockSession) GetPerspective() protocol.Perspective { panic("not implemented") }
|
||||
|
||||
var _ = Describe("Server", func() {
|
||||
var (
|
||||
conn *mockPacketConn
|
||||
|
@ -89,7 +90,7 @@ var _ = Describe("Server", func() {
|
|||
Context("with mock session", func() {
|
||||
var (
|
||||
serv *server
|
||||
firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID)
|
||||
firstPacket *receivedPacket
|
||||
connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
||||
sessions = make([]*MockQuicSession, 0)
|
||||
sessionHandler *MockPacketHandlerManager
|
||||
|
@ -126,9 +127,16 @@ var _ = Describe("Server", func() {
|
|||
serv.setup()
|
||||
b := &bytes.Buffer{}
|
||||
utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]))
|
||||
firstPacket = []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
||||
firstPacket = append(append(firstPacket, b.Bytes()...), 0x01)
|
||||
firstPacket = append(firstPacket, bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)...) // add padding
|
||||
firstPacket = &receivedPacket{
|
||||
header: &wire.Header{
|
||||
VersionFlag: true,
|
||||
Version: serv.config.Versions[0],
|
||||
DestConnectionID: protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6},
|
||||
PacketNumber: 1,
|
||||
},
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinClientHelloSize),
|
||||
rcvTime: time.Now(),
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
@ -150,12 +158,10 @@ var _ = Describe("Server", func() {
|
|||
s.EXPECT().run().Do(func() { close(run) })
|
||||
sessions = append(sessions, s)
|
||||
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
|
||||
Expect(sess.(*mockSession).connID).To(Equal(connID))
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(cid protocol.ConnectionID, _ packetHandler) {
|
||||
Expect(cid).To(Equal(connID))
|
||||
})
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.handlePacketImpl(firstPacket)).To(Succeed())
|
||||
Eventually(run).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -165,7 +171,8 @@ var _ = Describe("Server", func() {
|
|||
err := serv.setupTLS()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
added := make(chan struct{})
|
||||
sessionHandler.EXPECT().Add(connID, sess).Do(func(protocol.ConnectionID, packetHandler) {
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, ph packetHandler) {
|
||||
Expect(ph.GetPerspective()).To(Equal(protocol.PerspectiveServer))
|
||||
close(added)
|
||||
})
|
||||
serv.serverTLS.sessionChan <- tlsSession{
|
||||
|
@ -184,17 +191,15 @@ var _ = Describe("Server", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := serv.Accept()
|
||||
_, err := serv.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.(*mockSession).connID).To(Equal(connID))
|
||||
close(done)
|
||||
}()
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
sess.(*mockSession).runner.onHandshakeComplete(sess.(Session))
|
||||
sess.(*serverSession).quicSession.(*mockSession).runner.onHandshakeComplete(sess.(Session))
|
||||
})
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
err := serv.handlePacketImpl(firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(done).Should(BeClosed())
|
||||
Eventually(run).Should(BeClosed())
|
||||
|
@ -212,45 +217,20 @@ var _ = Describe("Server", func() {
|
|||
serv.Accept()
|
||||
close(done)
|
||||
}()
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(protocol.ConnectionID, packetHandler) {
|
||||
run <- errors.New("handshake error")
|
||||
})
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.handlePacketImpl(firstPacket)).To(Succeed())
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
sessionHandler.EXPECT().Close()
|
||||
close(serv.errorChan)
|
||||
serv.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("assigns packets to existing sessions", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
sess.EXPECT().GetVersion()
|
||||
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("closes the sessionHandler and the connection when Close is called", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serv.serve()
|
||||
}()
|
||||
// close the server
|
||||
sessionHandler.EXPECT().Close().AnyTimes()
|
||||
It("closes the sessionHandler when Close is called", func() {
|
||||
sessionHandler.EXPECT().CloseServer()
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Expect(conn.closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("ignores packets for closed sessions", func() {
|
||||
sessionHandler.EXPECT().Get(connID).Return(nil, true)
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("works if no quic.Config is given", func(done Done) {
|
||||
|
@ -264,163 +244,56 @@ var _ = Describe("Server", func() {
|
|||
ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var returned bool
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := ln.Accept()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
|
||||
returned = true
|
||||
}()
|
||||
ln.Close()
|
||||
Eventually(func() bool { return returned }).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("errors when encountering a connection error", func() {
|
||||
testErr := errors.New("connection error")
|
||||
conn.readErr = testErr
|
||||
sessionHandler.EXPECT().Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serv.serve()
|
||||
ln.Accept()
|
||||
close(done)
|
||||
}()
|
||||
_, err := serv.Accept()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
ln.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("ignores delayed packets with mismatching versions", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().GetVersion()
|
||||
// don't EXPECT any handlePacket() calls to this session
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
// add an unsupported version
|
||||
data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
||||
utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]+1))
|
||||
data = append(append(data, b.Bytes()...), 0x01)
|
||||
err := serv.handlePacket(nil, data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn
|
||||
Expect(conn.dataWritten.Bytes()).To(BeEmpty())
|
||||
It("returns Accept when it is closed", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := serv.Accept()
|
||||
Expect(err).To(MatchError("server closed"))
|
||||
close(done)
|
||||
}()
|
||||
sessionHandler.EXPECT().CloseServer()
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors on invalid public header", func() {
|
||||
err := serv.handlePacket(nil, nil)
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
|
||||
})
|
||||
|
||||
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
|
||||
serv.supportsTLS = true
|
||||
b := &bytes.Buffer{}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
PayloadLen: 1000,
|
||||
SrcConnectionID: connID,
|
||||
DestConnectionID: connID,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
|
||||
Expect(err).To(MatchError("packet payload (456 bytes) is smaller than the expected payload length (1000 bytes)"))
|
||||
})
|
||||
|
||||
It("cuts packets at the payload length", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) {
|
||||
Expect(packet.data).To(HaveLen(123))
|
||||
})
|
||||
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
|
||||
serv.supportsTLS = true
|
||||
b := &bytes.Buffer{}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
PayloadLen: 123,
|
||||
SrcConnectionID: connID,
|
||||
DestConnectionID: connID,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("drops packets with invalid packet types", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
|
||||
serv.supportsTLS = true
|
||||
b := &bytes.Buffer{}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
PayloadLen: 123,
|
||||
SrcConnectionID: connID,
|
||||
DestConnectionID: connID,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
|
||||
Expect(err).To(MatchError("Received unsupported packet type: Retry"))
|
||||
})
|
||||
|
||||
It("ignores Public Resets", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
|
||||
err := serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
It("returns Accept with the right error when closeWithError is called", func() {
|
||||
testErr := errors.New("connection error")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := serv.Accept()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
sessionHandler.EXPECT().CloseServer()
|
||||
serv.closeWithError(testErr)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {
|
||||
config.Versions = []protocol.VersionNumber{99}
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
VersionFlag: true,
|
||||
DestConnectionID: connID,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
p := &receivedPacket{
|
||||
header: &wire.Header{
|
||||
VersionFlag: true,
|
||||
DestConnectionID: connID,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
},
|
||||
data: make([]byte, protocol.MinClientHelloSize),
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed())
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
|
||||
serv.conn = conn
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
err := serv.handlePacket(nil, b.Bytes())
|
||||
Expect(serv.handlePacketImpl(p)).To(Succeed())
|
||||
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("doesn't respond with a version negotiation packet if the first packet is too small", func() {
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
VersionFlag: true,
|
||||
DestConnectionID: connID,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed())
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small
|
||||
serv.conn = conn
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
err := serv.handlePacket(udpAddr, b.Bytes())
|
||||
Expect(err).To(MatchError("dropping small packet with unknown version"))
|
||||
Expect(conn.dataWritten.Len()).Should(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -523,8 +396,11 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() {
|
||||
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
config.Versions = append(config.Versions, protocol.VersionTLS)
|
||||
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
|
@ -536,13 +412,10 @@ var _ = Describe("Server", func() {
|
|||
Version: 0x1234,
|
||||
PayloadLen: protocol.MinInitialPacketSize,
|
||||
}
|
||||
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)).To(Succeed())
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) // add a fake CHLO
|
||||
conn.dataToRead <- b.Bytes()
|
||||
conn.dataReadFrom = udpAddr
|
||||
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -568,51 +441,6 @@ var _ = Describe("Server", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() {
|
||||
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
IsLongHeader: true,
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
PacketNumber: 0x55,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
|
||||
conn.dataToRead <- b.Bytes()
|
||||
conn.dataReadFrom = udpAddr
|
||||
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
|
||||
})
|
||||
|
||||
It("ignores non-Initial Long Header packets for unknown connections", func() {
|
||||
connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
IsLongHeader: true,
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
PacketNumber: 0x55,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn.dataToRead <- b.Bytes()
|
||||
conn.dataReadFrom = udpAddr
|
||||
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
|
||||
})
|
||||
|
||||
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
|
||||
conn.dataReadFrom = udpAddr
|
||||
conn.dataToRead <- []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}
|
||||
|
|
|
@ -17,7 +17,7 @@ import (
|
|||
|
||||
type tlsSession struct {
|
||||
connID protocol.ConnectionID
|
||||
sess packetHandler
|
||||
sess quicSession
|
||||
}
|
||||
|
||||
type serverTLS struct {
|
||||
|
@ -126,7 +126,7 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, protocol.ConnectionID, error) {
|
||||
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (quicSession, protocol.ConnectionID, error) {
|
||||
if hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||
return nil, nil, errors.New("dropping Initial packet with too short connection ID")
|
||||
}
|
||||
|
@ -164,7 +164,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
|
|||
return sess, connID, nil
|
||||
}
|
||||
|
||||
func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, protocol.ConnectionID, error) {
|
||||
func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (quicSession, protocol.ConnectionID, error) {
|
||||
version := hdr.Version
|
||||
bc := handshake.NewCryptoStreamConn(remoteAddr)
|
||||
bc.AddDataForReading(frame.Data)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue