use a mock net.PacketConn in tests

This commit is contained in:
Marten Seemann 2020-09-27 14:42:11 +07:00
parent ebe051b2cc
commit a65274942c
9 changed files with 577 additions and 442 deletions

View file

@ -1,7 +1,6 @@
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
@ -12,7 +11,6 @@ import (
mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/logging"
"github.com/lucas-clemente/quic-go/quictrace"
@ -25,7 +23,7 @@ import (
var _ = Describe("Client", func() {
var (
cl *client
packetConn *mockPacketConn
packetConn *MockPacketConn
addr net.Addr
connID protocol.ConnectionID
mockMultiplexer *MockMultiplexer
@ -51,17 +49,6 @@ var _ = Describe("Client", func() {
) quicSession
)
// generate a packet sent by the server that accepts the QUIC version suggested by the client
acceptClientVersionPacket := func(connID protocol.ConnectionID) []byte {
b := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 1,
PacketNumberLen: 1,
}).Write(b, protocol.VersionWhatever)).To(Succeed())
return b.Bytes()
}
BeforeEach(func() {
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
@ -73,9 +60,8 @@ var _ = Describe("Client", func() {
Eventually(areSessionsRunning).Should(BeFalse())
// sess = NewMockQuicSession(mockCtrl)
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = newMockPacketConn()
packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
packetConn.dataReadFrom = addr
packetConn = NewMockPacketConn(mockCtrl)
packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
cl = &client{
srcConnID: connID,
destConnID: connID,
@ -221,7 +207,7 @@ var _ = Describe("Client", func() {
sess.EXPECT().run()
return sess
}
tracer.EXPECT().StartedConnection(packetConn.addr, addr, protocol.VersionTLS, gomock.Any(), gomock.Any())
tracer.EXPECT().StartedConnection(packetConn.LocalAddr(), addr, protocol.VersionTLS, gomock.Any(), gomock.Any())
_, err := Dial(
packetConn,
addr,
@ -350,7 +336,6 @@ var _ = Describe("Client", func() {
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess
}
packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID)
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
_, err := Dial(
packetConn,

View file

@ -4,16 +4,23 @@ import (
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Basic Conn Test", func() {
It("reads a packet", func() {
c := newMockPacketConn()
c := NewMockPacketConn(mockCtrl)
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
c.dataReadFrom = addr
c.dataToRead <- []byte("foobar")
c.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
data := []byte("foobar")
Expect(b).To(HaveLen(int(protocol.MaxReceivePacketSize)))
return copy(b, data), addr, nil
})
conn, err := wrapConn(c)
Expect(err).ToNot(HaveOccurred())

137
mock_packetconn_test.go Normal file
View file

@ -0,0 +1,137 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: net (interfaces: PacketConn)
// Package quic is a generated GoMock package.
package quic
import (
net "net"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
)
// MockPacketConn is a mock of PacketConn interface
type MockPacketConn struct {
ctrl *gomock.Controller
recorder *MockPacketConnMockRecorder
}
// MockPacketConnMockRecorder is the mock recorder for MockPacketConn
type MockPacketConnMockRecorder struct {
mock *MockPacketConn
}
// NewMockPacketConn creates a new mock instance
func NewMockPacketConn(ctrl *gomock.Controller) *MockPacketConn {
mock := &MockPacketConn{ctrl: ctrl}
mock.recorder = &MockPacketConnMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockPacketConn) EXPECT() *MockPacketConnMockRecorder {
return m.recorder
}
// Close mocks base method
func (m *MockPacketConn) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockPacketConnMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketConn)(nil).Close))
}
// LocalAddr mocks base method
func (m *MockPacketConn) LocalAddr() net.Addr {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LocalAddr")
ret0, _ := ret[0].(net.Addr)
return ret0
}
// LocalAddr indicates an expected call of LocalAddr
func (mr *MockPacketConnMockRecorder) LocalAddr() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockPacketConn)(nil).LocalAddr))
}
// ReadFrom mocks base method
func (m *MockPacketConn) ReadFrom(arg0 []byte) (int, net.Addr, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadFrom", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(net.Addr)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// ReadFrom indicates an expected call of ReadFrom
func (mr *MockPacketConnMockRecorder) ReadFrom(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFrom", reflect.TypeOf((*MockPacketConn)(nil).ReadFrom), arg0)
}
// SetDeadline mocks base method
func (m *MockPacketConn) SetDeadline(arg0 time.Time) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetDeadline indicates an expected call of SetDeadline
func (mr *MockPacketConnMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetDeadline), arg0)
}
// SetReadDeadline mocks base method
func (m *MockPacketConn) SetReadDeadline(arg0 time.Time) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetReadDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetReadDeadline indicates an expected call of SetReadDeadline
func (mr *MockPacketConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetReadDeadline), arg0)
}
// SetWriteDeadline mocks base method
func (m *MockPacketConn) SetWriteDeadline(arg0 time.Time) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetWriteDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetWriteDeadline indicates an expected call of SetWriteDeadline
func (mr *MockPacketConnMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetWriteDeadline), arg0)
}
// WriteTo mocks base method
func (m *MockPacketConn) WriteTo(arg0 []byte, arg1 net.Addr) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WriteTo", arg0, arg1)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// WriteTo indicates an expected call of WriteTo
func (mr *MockPacketConnMockRecorder) WriteTo(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockPacketConn)(nil).WriteTo), arg0, arg1)
}

View file

@ -21,3 +21,4 @@ package quic
//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager"
//go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/lucas-clemente/quic-go multiplexer"
//go:generate sh -c "mockgen -package quic -self_package github.com/lucas-clemente/quic-go -destination mock_token_store_test.go github.com/lucas-clemente/quic-go TokenStore && goimports -w mock_token_store_test.go"
//go:generate sh -c "mockgen -package quic -self_package github.com/lucas-clemente/quic-go -destination mock_packetconn_test.go net PacketConn && goimports -w mock_packetconn_test.go"

View file

@ -64,7 +64,8 @@ func (m *connMultiplexer) AddConn(
m.mutex.Lock()
defer m.mutex.Unlock()
connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String()
addr := c.LocalAddr()
connIndex := addr.Network() + " " + addr.String()
p, ok := m.conns[connIndex]
if !ok {
manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger)

View file

@ -3,6 +3,7 @@ package quic
import (
"net"
"github.com/golang/mock/gomock"
mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging"
. "github.com/onsi/ginkgo"
@ -14,16 +15,19 @@ type testConn struct {
net.PacketConn
}
var _ = Describe("Client Multiplexer", func() {
var _ = Describe("Multiplexer", func() {
It("adds a new packet conn ", func() {
conn := newMockPacketConn()
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234})
_, err := getMultiplexer().AddConn(conn, 8, nil, nil)
Expect(err).ToNot(HaveOccurred())
})
It("recognizes when the same connection is added twice", func() {
pconn := newMockPacketConn()
pconn.addr = &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
pconn := NewMockPacketConn(mockCtrl)
pconn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2)
pconn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn := testConn{PacketConn: pconn}
tracer := mocklogging.NewMockTracer(mockCtrl)
_, err := getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer)
@ -35,7 +39,9 @@ var _ = Describe("Client Multiplexer", func() {
})
It("errors when adding an existing conn with a different connection ID length", func() {
conn := newMockPacketConn()
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
_, err := getMultiplexer().AddConn(conn, 5, nil, nil)
Expect(err).ToNot(HaveOccurred())
_, err = getMultiplexer().AddConn(conn, 6, nil, nil)
@ -43,7 +49,9 @@ var _ = Describe("Client Multiplexer", func() {
})
It("errors when adding an existing conn with a different stateless rest key", func() {
conn := newMockPacketConn()
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
_, err := getMultiplexer().AddConn(conn, 7, []byte("foobar"), nil)
Expect(err).ToNot(HaveOccurred())
_, err = getMultiplexer().AddConn(conn, 7, []byte("raboof"), nil)
@ -51,7 +59,9 @@ var _ = Describe("Client Multiplexer", func() {
})
It("errors when adding an existing conn with different tracers", func() {
conn := newMockPacketConn()
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
_, err := getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl))
Expect(err).ToNot(HaveOccurred())
_, err = getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl))

View file

@ -20,10 +20,17 @@ import (
)
var _ = Describe("Packet Handler Map", func() {
type packetToRead struct {
addr net.Addr
data []byte
err error
}
var (
handler *packetHandlerMap
conn *mockPacketConn
tracer *mocklogging.MockTracer
handler *packetHandlerMap
conn *MockPacketConn
tracer *mocklogging.MockTracer
packetChan chan packetToRead
connIDLen int
statelessResetKey []byte
@ -52,28 +59,24 @@ var _ = Describe("Packet Handler Map", func() {
statelessResetKey = nil
connIDLen = 0
tracer = mocklogging.NewMockTracer(mockCtrl)
packetChan = make(chan packetToRead, 10)
})
JustBeforeEach(func() {
conn = newMockPacketConn()
conn = NewMockPacketConn(mockCtrl)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
p, ok := <-packetChan
if !ok {
return 0, nil, errors.New("closed")
}
return copy(b, p.data), p.addr, p.err
}).AnyTimes()
phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger)
Expect(err).ToNot(HaveOccurred())
handler = phm.(*packetHandlerMap)
})
AfterEach(func() {
// delete sessions and the server before closing
// They might be mock implementations, and we'd have to register the expected calls before otherwise.
handler.mutex.Lock()
for connID := range handler.handlers {
delete(handler.handlers, connID)
}
handler.server = nil
handler.mutex.Unlock()
handler.Destroy()
Eventually(handler.listening).Should(BeClosed())
})
It("closes", func() {
getMultiplexer() // make the sync.Once execute
// replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer
@ -94,284 +97,307 @@ var _ = Describe("Packet Handler Map", func() {
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2)
mockMultiplexer.EXPECT().RemoveConn(gomock.Any())
handler.close(testErr)
close(packetChan)
Eventually(handler.listening).Should(BeClosed())
})
Context("handling packets", func() {
BeforeEach(func() {
connIDLen = 5
Context("other operations", func() {
AfterEach(func() {
// delete sessions and the server before closing
// They might be mock implementations, and we'd have to register the expected calls before otherwise.
handler.mutex.Lock()
for connID := range handler.handlers {
delete(handler.handlers, connID)
}
handler.server = nil
handler.mutex.Unlock()
conn.EXPECT().Close().MaxTimes(1)
close(packetChan)
handler.Destroy()
Eventually(handler.listening).Should(BeClosed())
})
It("handles packets for different packet handlers on the same packet conn", func() {
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
packetHandler1 := NewMockPacketHandler(mockCtrl)
packetHandler2 := NewMockPacketHandler(mockCtrl)
handledPacket1 := make(chan struct{})
handledPacket2 := make(chan struct{})
packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(connID1))
close(handledPacket1)
Context("handling packets", func() {
BeforeEach(func() {
connIDLen = 5
})
packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(connID2))
close(handledPacket2)
})
handler.Add(connID1, packetHandler1)
handler.Add(connID2, packetHandler2)
conn.dataToRead <- getPacket(connID1)
conn.dataToRead <- getPacket(connID2)
Eventually(handledPacket1).Should(BeClosed())
Eventually(handledPacket2).Should(BeClosed())
})
It("drops unparseable packets", func() {
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: []byte{0, 1, 2, 3},
})
})
It("deletes removed sessions immediately", func() {
handler.deleteRetiredSessionsAfter = time.Hour
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
handler.Add(connID, NewMockPacketHandler(mockCtrl))
handler.Remove(connID)
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
// don't EXPECT any calls to handlePacket of the MockPacketHandler
})
It("deletes retired session entries after a wait time", func() {
handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond)
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
sess := NewMockPacketHandler(mockCtrl)
handler.Add(connID, sess)
handler.Retire(connID)
time.Sleep(scaleDuration(30 * time.Millisecond))
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
// don't EXPECT any calls to handlePacket of the MockPacketHandler
})
It("passes packets arriving late for closed sessions to that session", func() {
handler.deleteRetiredSessionsAfter = time.Hour
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
packetHandler := NewMockPacketHandler(mockCtrl)
handled := make(chan struct{})
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
close(handled)
})
handler.Add(connID, packetHandler)
handler.Retire(connID)
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
Eventually(handled).Should(BeClosed())
})
It("drops packets for unknown receivers", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
})
It("closes the packet handlers when reading from the conn fails", func() {
done := make(chan struct{})
packetHandler := NewMockPacketHandler(mockCtrl)
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) {
Expect(e).To(HaveOccurred())
close(done)
})
handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)
conn.Close()
Eventually(done).Should(BeClosed())
})
It("says if a connection ID is already taken", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
})
It("says if a connection ID is already taken, for AddWithConnID", func() {
clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
newConnID1 := protocol.ConnectionID{1, 2, 3, 4}
newConnID2 := protocol.ConnectionID{4, 3, 2, 1}
Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue())
Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse())
})
})
Context("running a server", func() {
It("adds a server", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
cid, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(cid).To(Equal(connID))
})
handler.SetServer(server)
handler.handlePacket(&receivedPacket{data: p})
})
It("closes all server sessions", func() {
clientSess := NewMockPacketHandler(mockCtrl)
clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
serverSess := NewMockPacketHandler(mockCtrl)
serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
serverSess.EXPECT().shutdown()
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess)
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess)
handler.CloseServer()
})
It("stops handling packets with unknown connection IDs after the server is closed", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
// don't EXPECT any calls to server.handlePacket
handler.SetServer(server)
handler.CloseServer()
handler.handlePacket(&receivedPacket{data: p})
})
})
Context("stateless resets", func() {
BeforeEach(func() {
connIDLen = 5
})
Context("handling", func() {
It("handles stateless resets", func() {
packetHandler := NewMockPacketHandler(mockCtrl)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, packetHandler)
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
destroyed := make(chan struct{})
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred())
var resetErr statelessResetErr
Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(resetErr.token).To(Equal(token))
close(destroyed)
It("handles packets for different packet handlers on the same packet conn", func() {
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
packetHandler1 := NewMockPacketHandler(mockCtrl)
packetHandler2 := NewMockPacketHandler(mockCtrl)
handledPacket1 := make(chan struct{})
handledPacket2 := make(chan struct{})
packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(connID1))
close(handledPacket1)
})
conn.dataToRead <- packet
Eventually(destroyed).Should(BeClosed())
})
It("handles stateless resets for 0-length connection IDs", func() {
handler.connIDLen = 0
packetHandler := NewMockPacketHandler(mockCtrl)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, packetHandler)
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
destroyed := make(chan struct{})
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred())
var resetErr statelessResetErr
Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(resetErr.token).To(Equal(token))
close(destroyed)
packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(connID2))
close(handledPacket2)
})
conn.dataToRead <- packet
Eventually(destroyed).Should(BeClosed())
handler.Add(connID1, packetHandler1)
handler.Add(connID2, packetHandler2)
packetChan <- packetToRead{data: getPacket(connID1)}
packetChan <- packetToRead{data: getPacket(connID2)}
Eventually(handledPacket1).Should(BeClosed())
Eventually(handledPacket2).Should(BeClosed())
})
It("retires reset tokens", func() {
It("drops unparseable packets", func() {
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: []byte{0, 1, 2, 3},
})
})
It("deletes removed sessions immediately", func() {
handler.deleteRetiredSessionsAfter = time.Hour
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
handler.Add(connID, NewMockPacketHandler(mockCtrl))
handler.Remove(connID)
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
// don't EXPECT any calls to handlePacket of the MockPacketHandler
})
It("deletes retired session entries after a wait time", func() {
handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond)
connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42}
packetHandler := NewMockPacketHandler(mockCtrl)
handler.Add(connID, packetHandler)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, NewMockPacketHandler(mockCtrl))
handler.RetireResetToken(token)
packetHandler.EXPECT().handlePacket(gomock.Any())
p := append([]byte{0x40} /* short header packet */, connID.Bytes()...)
p = append(p, make([]byte, 50)...)
p = append(p, token[:]...)
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
sess := NewMockPacketHandler(mockCtrl)
handler.Add(connID, sess)
handler.Retire(connID)
time.Sleep(scaleDuration(30 * time.Millisecond))
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
// don't EXPECT any calls to handlePacket of the MockPacketHandler
})
It("passes packets arriving late for closed sessions to that session", func() {
handler.deleteRetiredSessionsAfter = time.Hour
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
packetHandler := NewMockPacketHandler(mockCtrl)
handled := make(chan struct{})
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
close(handled)
})
handler.Add(connID, packetHandler)
handler.Retire(connID)
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
Eventually(handled).Should(BeClosed())
})
It("drops packets for unknown receivers", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
})
It("closes the packet handlers when reading from the conn fails", func() {
done := make(chan struct{})
packetHandler := NewMockPacketHandler(mockCtrl)
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) {
Expect(e).To(HaveOccurred())
close(done)
})
handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)
packetChan <- packetToRead{err: errors.New("read failed")}
Eventually(done).Should(BeClosed())
})
It("says if a connection ID is already taken", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
})
It("says if a connection ID is already taken, for AddWithConnID", func() {
clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
newConnID1 := protocol.ConnectionID{1, 2, 3, 4}
newConnID2 := protocol.ConnectionID{4, 3, 2, 1}
Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue())
Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse())
})
})
Context("running a server", func() {
It("adds a server", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
cid, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(cid).To(Equal(connID))
})
handler.SetServer(server)
handler.handlePacket(&receivedPacket{data: p})
})
It("ignores packets too small to contain a stateless reset", func() {
handler.connIDLen = 0
packetHandler := NewMockPacketHandler(mockCtrl)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, packetHandler)
packet := append([]byte{0x40} /* short header packet */, token[:15]...)
done := make(chan struct{})
// don't EXPECT any calls here, but register the closing of the done channel
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) {
close(done)
}).AnyTimes()
conn.dataToRead <- packet
Consistently(done).ShouldNot(BeClosed())
It("closes all server sessions", func() {
clientSess := NewMockPacketHandler(mockCtrl)
clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
serverSess := NewMockPacketHandler(mockCtrl)
serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
serverSess.EXPECT().shutdown()
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess)
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess)
handler.CloseServer()
})
It("stops handling packets with unknown connection IDs after the server is closed", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
// don't EXPECT any calls to server.handlePacket
handler.SetServer(server)
handler.CloseServer()
handler.handlePacket(&receivedPacket{data: p})
})
})
Context("generating", func() {
Context("stateless resets", func() {
BeforeEach(func() {
key := make([]byte, 32)
rand.Read(key)
statelessResetKey = key
connIDLen = 5
})
It("generates stateless reset tokens", func() {
connID1 := []byte{0xde, 0xad, 0xbe, 0xef}
connID2 := []byte{0xde, 0xca, 0xfb, 0xad}
Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2)))
})
It("sends stateless resets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, 100)...)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: p,
Context("handling", func() {
It("handles stateless resets", func() {
packetHandler := NewMockPacketHandler(mockCtrl)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, packetHandler)
destroyed := make(chan struct{})
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
defer close(destroyed)
Expect(err).To(HaveOccurred())
var resetErr statelessResetErr
Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(resetErr.token).To(Equal(token))
})
packetChan <- packetToRead{data: packet}
Eventually(destroyed).Should(BeClosed())
time.Sleep(time.Second)
})
It("handles stateless resets for 0-length connection IDs", func() {
handler.connIDLen = 0
packetHandler := NewMockPacketHandler(mockCtrl)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, packetHandler)
destroyed := make(chan struct{})
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred())
var resetErr statelessResetErr
Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(resetErr.token).To(Equal(token))
close(destroyed)
})
packetChan <- packetToRead{data: packet}
Eventually(destroyed).Should(BeClosed())
})
It("retires reset tokens", func() {
handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond)
connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42}
packetHandler := NewMockPacketHandler(mockCtrl)
handler.Add(connID, packetHandler)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, NewMockPacketHandler(mockCtrl))
handler.RetireResetToken(token)
packetHandler.EXPECT().handlePacket(gomock.Any())
p := append([]byte{0x40} /* short header packet */, connID.Bytes()...)
p = append(p, make([]byte, 50)...)
p = append(p, token[:]...)
time.Sleep(scaleDuration(30 * time.Millisecond))
handler.handlePacket(&receivedPacket{data: p})
})
It("ignores packets too small to contain a stateless reset", func() {
handler.connIDLen = 0
packetHandler := NewMockPacketHandler(mockCtrl)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, packetHandler)
done := make(chan struct{})
// don't EXPECT any calls here, but register the closing of the done channel
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) {
close(done)
}).AnyTimes()
packetChan <- packetToRead{data: append([]byte{0x40} /* short header packet */, token[:15]...)}
Consistently(done).ShouldNot(BeClosed())
})
var reset mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&reset))
Expect(reset.to).To(Equal(addr))
Expect(reset.data[0] & 0x80).To(BeZero()) // short header packet
Expect(reset.data).To(HaveLen(protocol.MinStatelessResetSize))
})
It("doesn't send stateless resets for small packets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: p,
Context("generating", func() {
BeforeEach(func() {
key := make([]byte, 32)
rand.Read(key)
statelessResetKey = key
})
Consistently(conn.dataWritten).ShouldNot(Receive())
})
})
Context("if no key is configured", func() {
It("doesn't send stateless resets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, 100)...)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: p,
It("generates stateless reset tokens", func() {
connID1 := []byte{0xde, 0xad, 0xbe, 0xef}
connID2 := []byte{0xde, 0xca, 0xfb, 0xad}
Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2)))
})
It("sends stateless resets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, 100)...)
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) {
defer close(done)
Expect(b[0] & 0x80).To(BeZero()) // short header packet
Expect(b).To(HaveLen(protocol.MinStatelessResetSize))
})
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: p,
})
Eventually(done).Should(BeClosed())
})
It("doesn't send stateless resets for small packets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: p,
})
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
})
})
Context("if no key is configured", func() {
It("doesn't send stateless resets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, 100)...)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: p,
})
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
})
Consistently(conn.dataWritten).ShouldNot(Receive())
})
})
})

View file

@ -1,90 +1,28 @@
package quic
import (
"errors"
"net"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockPacketConnWrite struct {
data []byte
to net.Addr
}
type mockPacketConn struct {
addr net.Addr
dataToRead chan []byte
dataReadFrom net.Addr
readErr error
dataWritten chan mockPacketConnWrite
closed bool
}
func newMockPacketConn() *mockPacketConn {
return &mockPacketConn{
addr: &net.UDPAddr{IP: net.IPv6zero, Port: 0x42},
dataToRead: make(chan []byte, 1000),
dataWritten: make(chan mockPacketConnWrite, 1000),
}
}
func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
if c.readErr != nil {
return 0, nil, c.readErr
}
data, ok := <-c.dataToRead
if !ok {
return 0, nil, errors.New("connection closed")
}
n := copy(b, data)
return n, c.dataReadFrom, nil
}
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
select {
case c.dataWritten <- mockPacketConnWrite{to: addr, data: b}:
return len(b), nil
default:
panic("channel full")
}
}
func (c *mockPacketConn) Close() error {
if !c.closed {
close(c.dataToRead)
}
c.closed = true
return nil
}
func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr }
func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") }
func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
func (c *mockPacketConn) SetWriteDeadline(t time.Time) error { panic("not implemented") }
var _ net.PacketConn = &mockPacketConn{}
var _ = Describe("Send-Connection", func() {
var c sendConn
var packetConn *mockPacketConn
var _ = Describe("Connection (for sending packets)", func() {
var (
c sendConn
packetConn *MockPacketConn
addr net.Addr
)
BeforeEach(func() {
addr := &net.UDPAddr{
IP: net.IPv4(192, 168, 100, 200),
Port: 1337,
}
packetConn = newMockPacketConn()
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = NewMockPacketConn(mockCtrl)
c = newSendConn(packetConn, addr)
})
It("writes", func() {
packetConn.EXPECT().WriteTo([]byte("foobar"), addr)
Expect(c.Write([]byte("foobar"))).To(Succeed())
var write mockPacketConnWrite
Expect(packetConn.dataWritten).To(Receive(&write))
Expect(write.to.String()).To(Equal("192.168.100.200:1337"))
Expect(write.data).To(Equal([]byte("foobar")))
})
It("gets the remote address", func() {
@ -96,13 +34,12 @@ var _ = Describe("Send-Connection", func() {
IP: net.IPv4(192, 168, 0, 1),
Port: 1234,
}
packetConn.addr = addr
packetConn.EXPECT().LocalAddr().Return(addr)
Expect(c.LocalAddr()).To(Equal(addr))
})
It("closes", func() {
err := c.Close()
Expect(err).ToNot(HaveOccurred())
Expect(packetConn.closed).To(BeTrue())
packetConn.EXPECT().Close()
Expect(c.Close()).To(Succeed())
})
})

View file

@ -38,7 +38,7 @@ func areServersRunning() bool {
var _ = Describe("Server", func() {
var (
conn *mockPacketConn
conn *MockPacketConn
tlsConf *tls.Config
)
@ -97,8 +97,9 @@ var _ = Describe("Server", func() {
}
BeforeEach(func() {
conn = newMockPacketConn()
conn.addr = &net.UDPAddr{}
conn = NewMockPacketConn(mockCtrl)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1)
tlsConf = testdata.GetTLSConfig()
tlsConf.NextProtos = []string{"proto1"}
})
@ -212,7 +213,8 @@ var _ = Describe("Server", func() {
}, nil)
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive())
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
})
It("drops too small Initial", func() {
@ -225,7 +227,8 @@ var _ = Describe("Server", func() {
)
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive())
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
})
It("drops non-Initial packets", func() {
@ -236,7 +239,8 @@ var _ = Describe("Server", func() {
}, []byte("invalid"))
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket)
serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive())
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
})
It("decodes the token from the Token field", func() {
@ -260,6 +264,7 @@ var _ = Describe("Server", func() {
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
@ -284,6 +289,7 @@ var _ = Describe("Server", func() {
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
@ -360,8 +366,9 @@ var _ = Describe("Server", func() {
go func() {
defer GinkgoRecover()
serv.handlePacket(p)
// the Handshake packet is written by the session
Consistently(conn.dataWritten).ShouldNot(Receive())
// the Handshake packet is written by the session.
// Make sure there are no Write calls on the packet conn.
time.Sleep(50 * time.Millisecond)
close(done)
}()
// make sure we're using a server-generated connection ID
@ -379,23 +386,27 @@ var _ = Describe("Server", func() {
DestConnectionID: destConnID,
Version: 0x42,
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
Expect(replyHdr.IsLongHeader).To(BeTrue())
Expect(replyHdr.Version).To(BeZero())
Expect(replyHdr.SrcConnectionID).To(Equal(destConnID))
Expect(replyHdr.DestConnectionID).To(Equal(srcConnID))
})
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue())
hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(b))
Expect(err).ToNot(HaveOccurred())
Expect(hdr.DestConnectionID).To(Equal(srcConnID))
Expect(hdr.SrcConnectionID).To(Equal(destConnID))
Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42)))
return len(b), nil
})
serv.handlePacket(packet)
var write mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
Expect(wire.IsVersionNegotiationPacket(write.data)).To(BeTrue())
hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(write.data))
Expect(err).ToNot(HaveOccurred())
Expect(hdr.DestConnectionID).To(Equal(srcConnID))
Expect(hdr.SrcConnectionID).To(Equal(destConnID))
Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42)))
Eventually(done).Should(BeClosed())
})
It("replies with a Retry packet, if a Token is required", func() {
@ -408,23 +419,27 @@ var _ = Describe("Server", func() {
Version: protocol.VersionTLS,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(replyHdr.Token).ToNot(BeEmpty())
})
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
replyHdr := parseHeader(b)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(replyHdr.Token).ToNot(BeEmpty())
Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID)[:]))
return len(b), nil
})
serv.handlePacket(packet)
var write mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
replyHdr := parseHeader(write.data)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(replyHdr.Token).ToNot(BeEmpty())
Expect(write.data[len(write.data)-16:]).To(Equal(handshake.GetRetryIntegrityTag(write.data[:len(write.data)-16], hdr.DestConnectionID)[:]))
Eventually(done).Should(BeClosed())
})
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
@ -441,7 +456,8 @@ var _ = Describe("Server", func() {
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
@ -452,25 +468,28 @@ var _ = Describe("Server", func() {
Expect(ccf.IsApplicationError).To(BeFalse())
Expect(ccf.ErrorCode).To(Equal(qerr.InvalidToken))
})
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
replyHdr := parseHeader(b)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
_, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient)
extHdr, err := unpackHeader(opener, replyHdr, b, hdr.Version)
Expect(err).ToNot(HaveOccurred())
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
Expect(err).ToNot(HaveOccurred())
f, err := wire.NewFrameParser(hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial)
Expect(err).ToNot(HaveOccurred())
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := f.(*wire.ConnectionCloseFrame)
Expect(ccf.ErrorCode).To(Equal(qerr.InvalidToken))
Expect(ccf.ReasonPhrase).To(BeEmpty())
return len(b), nil
})
serv.handlePacket(packet)
var write mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
replyHdr := parseHeader(write.data)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
_, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient)
extHdr, err := unpackHeader(opener, replyHdr, write.data, hdr.Version)
Expect(err).ToNot(HaveOccurred())
data, err := opener.Open(nil, write.data[extHdr.ParsedLen():], extHdr.PacketNumber, write.data[:extHdr.ParsedLen()])
Expect(err).ToNot(HaveOccurred())
f, err := wire.NewFrameParser(hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial)
Expect(err).ToNot(HaveOccurred())
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := f.(*wire.ConnectionCloseFrame)
Expect(ccf.ErrorCode).To(Equal(qerr.InvalidToken))
Expect(ccf.ReasonPhrase).To(BeEmpty())
Eventually(done).Should(BeClosed())
})
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
@ -490,7 +509,8 @@ var _ = Describe("Server", func() {
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError)
serv.handlePacket(packet)
Consistently(conn.dataWritten).ShouldNot(Receive())
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
})
It("creates a session, if no Token is required", func() {
@ -559,7 +579,8 @@ var _ = Describe("Server", func() {
defer GinkgoRecover()
serv.handlePacket(p)
// the Handshake packet is written by the session
Consistently(conn.dataWritten).ShouldNot(Receive())
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
close(done)
}()
// make sure we're using a server-generated connection ID
@ -756,7 +777,8 @@ var _ = Describe("Server", func() {
defer GinkgoRecover()
defer wg.Done()
serv.handlePacket(getInitialWithRandomDestConnID())
Consistently(conn.dataWritten).ShouldNot(Receive())
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
}()
}
wg.Wait()
@ -764,15 +786,18 @@ var _ = Describe("Server", func() {
hdr, _, _, err := wire.ParsePacket(p.data, 0)
Expect(err).ToNot(HaveOccurred())
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), p.remoteAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
rejectHdr := parseHeader(b)
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(rejectHdr.Version).To(Equal(hdr.Version))
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
return len(b), nil
})
serv.handlePacket(p)
var reject mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&reject))
Expect(reject.to).To(Equal(p.remoteAddr))
rejectHdr := parseHeader(reject.data)
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(rejectHdr.Version).To(Equal(hdr.Version))
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Eventually(done).Should(BeClosed())
})
It("doesn't accept new sessions if they were closed in the mean time", func() {
@ -817,7 +842,8 @@ var _ = Describe("Server", func() {
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any())
serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive())
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
Eventually(sessionCreated).Should(BeClosed())
cancel()
time.Sleep(scaleDuration(200 * time.Millisecond))
@ -1034,19 +1060,23 @@ var _ = Describe("Server", func() {
}
Eventually(func() int32 { return atomic.LoadInt32(&serv.sessionQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize))
Consistently(conn.dataWritten).ShouldNot(Receive())
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
p := getInitialWithRandomDestConnID()
hdr := parseHeader(p.data)
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), senderAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
rejectHdr := parseHeader(b)
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(rejectHdr.Version).To(Equal(hdr.Version))
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
return len(b), nil
})
serv.handlePacket(p)
var reject mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&reject))
Expect(reject.to).To(Equal(senderAddr))
rejectHdr := parseHeader(reject.data)
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(rejectHdr.Version).To(Equal(hdr.Version))
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Eventually(done).Should(BeClosed())
})
It("doesn't accept new sessions if they were closed in the mean time", func() {
@ -1087,7 +1117,8 @@ var _ = Describe("Server", func() {
return true
})
serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive())
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
Eventually(sessionCreated).Should(BeClosed())
cancel()
time.Sleep(scaleDuration(200 * time.Millisecond))