uquic/transport_test.go
Gaukas Wang 4973374ea5
sync: quic-go 0.42.0
Signed-off-by: Gaukas Wang <i@gaukas.wang>
2024-04-23 22:34:55 -06:00

444 lines
13 KiB
Go

package quic
import (
"bytes"
"context"
"crypto/rand"
"errors"
"net"
"syscall"
"time"
tls "github.com/refraction-networking/utls"
mocklogging "github.com/refraction-networking/uquic/internal/mocks/logging"
"github.com/refraction-networking/uquic/internal/protocol"
"github.com/refraction-networking/uquic/internal/wire"
"github.com/refraction-networking/uquic/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"go.uber.org/mock/gomock"
)
// Specs for Transport: demultiplexing packets to handlers, closing listeners and
// the transport, dropping unparseable or excess non-QUIC packets, and receiving
// and sending stateless resets.
//
// Mock callbacks that are invoked from the Transport's read loop (not the spec's
// own goroutine) defer GinkgoRecover, so a failed assertion is reported as a spec
// failure instead of panicking on a foreign goroutine.
var _ = Describe("Transport", func() {
	// packetToRead is one result to be returned by the mock conn's ReadFrom.
	type packetToRead struct {
		addr net.Addr
		data []byte
		err  error
	}

	// getPacketWithPacketType serializes a QUIC v1 long header of type t for
	// connID, carrying the given Length field and a 2-byte packet number.
	getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
		b, err := (&wire.ExtendedHeader{
			Header: wire.Header{
				Type:             t,
				DestConnectionID: connID,
				Length:           length,
				Version:          protocol.Version1,
			},
			PacketNumberLen: protocol.PacketNumberLen2,
		}).Append(nil, protocol.Version1)
		Expect(err).ToNot(HaveOccurred())
		return b
	}

	// getPacket returns a Handshake packet header for connID with Length 2.
	getPacket := func(connID protocol.ConnectionID) []byte {
		return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
	}

	// newMockPacketConn returns a MockPacketConn whose ReadFrom hands out the
	// items sent on packetChan, one per call. After packetChan is closed,
	// ReadFrom returns a "closed" error.
	newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn {
		conn := NewMockPacketConn(mockCtrl)
		conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
		conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
			p, ok := <-packetChan
			if !ok {
				return 0, nil, errors.New("closed")
			}
			return copy(b, p.data), p.addr, p.err
		}).AnyTimes()
		// for shutdown
		conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
		return conn
	}

	It("handles packets for different packet handlers on the same packet conn", func() {
		packetChan := make(chan packetToRead)
		tr := &Transport{Conn: newMockPacketConn(packetChan)}
		tr.init(true)
		phm := NewMockPacketHandlerManager(mockCtrl)
		tr.handlerMap = phm
		connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
		connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})

		handled := make(chan struct{}, 2)
		phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
			h := NewMockPacketHandler(mockCtrl)
			h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
				defer GinkgoRecover()
				connID, err := wire.ParseConnectionID(p.data, 0)
				Expect(err).ToNot(HaveOccurred())
				Expect(connID).To(Equal(connID1))
				handled <- struct{}{}
			})
			return h, true
		})
		phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
			h := NewMockPacketHandler(mockCtrl)
			h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
				defer GinkgoRecover()
				connID, err := wire.ParseConnectionID(p.data, 0)
				Expect(err).ToNot(HaveOccurred())
				Expect(connID).To(Equal(connID2))
				handled <- struct{}{}
			})
			return h, true
		})

		packetChan <- packetToRead{data: getPacket(connID1)}
		packetChan <- packetToRead{data: getPacket(connID2)}

		Eventually(handled).Should(Receive())
		Eventually(handled).Should(Receive())

		// shutdown
		phm.EXPECT().Close(gomock.Any())
		close(packetChan)
		tr.Close()
	})

	It("closes listeners", func() {
		packetChan := make(chan packetToRead)
		tr := &Transport{Conn: newMockPacketConn(packetChan)}
		defer tr.Close()
		ln, err := tr.Listen(&tls.Config{}, nil)
		Expect(err).ToNot(HaveOccurred())
		phm := NewMockPacketHandlerManager(mockCtrl)
		tr.handlerMap = phm

		Expect(ln.Close()).To(Succeed())

		// shutdown
		phm.EXPECT().Close(gomock.Any())
		close(packetChan)
		tr.Close()
	})

	It("closes transport concurrently with listener", func() {
		// try 10 times to trigger race conditions
		for i := 0; i < 10; i++ {
			packetChan := make(chan packetToRead)
			tr := &Transport{Conn: newMockPacketConn(packetChan)}
			ln, err := tr.Listen(&tls.Config{}, nil)
			Expect(err).ToNot(HaveOccurred())

			started := make(chan struct{})
			lnClosed := make(chan struct{})
			// Close transport and listener concurrently.
			go func() {
				defer GinkgoRecover()
				// Deferred so the spec never blocks forever, even if ln.Close fails.
				defer close(lnClosed)
				close(started)
				Expect(ln.Close()).To(Succeed())
			}()
			<-started
			close(packetChan)
			Expect(tr.Close()).To(Succeed())
			Eventually(lnClosed).Should(BeClosed())
		}
	})

	It("drops unparseable QUIC packets", func() {
		addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
		packetChan := make(chan packetToRead)
		t, tracer := mocklogging.NewMockTracer(mockCtrl)
		tr := &Transport{
			Conn:               newMockPacketConn(packetChan),
			ConnectionIDLength: 10,
			Tracer:             t,
		}
		tr.init(true)
		dropped := make(chan struct{})
		tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
		packetChan <- packetToRead{
			addr: addr,
			data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3},
		}
		Eventually(dropped).Should(BeClosed())

		// shutdown
		tracer.EXPECT().Close()
		close(packetChan)
		tr.Close()
	})

	It("closes when reading from the conn fails", func() {
		packetChan := make(chan packetToRead)
		tr := Transport{Conn: newMockPacketConn(packetChan)}
		defer tr.Close()
		phm := NewMockPacketHandlerManager(mockCtrl)
		tr.init(true)
		tr.handlerMap = phm

		done := make(chan struct{})
		phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) })
		packetChan <- packetToRead{err: errors.New("read failed")}
		Eventually(done).Should(BeClosed())

		// shutdown
		close(packetChan)
		tr.Close()
	})

	It("continues listening after temporary errors", func() {
		packetChan := make(chan packetToRead)
		tr := Transport{Conn: newMockPacketConn(packetChan)}
		defer tr.Close()
		phm := NewMockPacketHandlerManager(mockCtrl)
		tr.init(true)
		tr.handlerMap = phm

		tempErr := deadlineError{}
		Expect(tempErr.Temporary()).To(BeTrue())
		packetChan <- packetToRead{err: tempErr}
		// Don't expect any calls to phm.Close: no expectation is registered yet,
		// so gomock fails the spec if Close is called during this window.
		time.Sleep(scaleDuration(50 * time.Millisecond))

		// shutdown
		phm.EXPECT().Close(gomock.Any())
		close(packetChan)
		tr.Close()
	})

	It("handles short header packets resets", func() {
		connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
		packetChan := make(chan packetToRead)
		tr := Transport{
			Conn:               newMockPacketConn(packetChan),
			ConnectionIDLength: connID.Len(),
		}
		tr.init(true)
		defer tr.Close()
		phm := NewMockPacketHandlerManager(mockCtrl)
		tr.handlerMap = phm

		var token protocol.StatelessResetToken
		rand.Read(token[:])

		var b []byte
		b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
		Expect(err).ToNot(HaveOccurred())
		b = append(b, token[:]...)
		conn := NewMockPacketHandler(mockCtrl)
		handled := make(chan struct{})
		gomock.InOrder(
			phm.EXPECT().Get(connID).Return(conn, true),
			conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
				// Runs on the Transport's read loop.
				defer GinkgoRecover()
				defer close(handled)
				Expect(p.data).To(Equal(b))
				Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second))
			}),
		)
		packetChan <- packetToRead{data: b}
		Eventually(handled).Should(BeClosed())

		// shutdown
		phm.EXPECT().Close(gomock.Any())
		close(packetChan)
		tr.Close()
	})

	It("handles stateless resets", func() {
		connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
		packetChan := make(chan packetToRead)
		tr := Transport{
			Conn:               newMockPacketConn(packetChan),
			ConnectionIDLength: connID.Len(),
		}
		tr.init(true)
		defer tr.Close()
		phm := NewMockPacketHandlerManager(mockCtrl)
		tr.handlerMap = phm

		var token protocol.StatelessResetToken
		rand.Read(token[:])

		var b []byte
		b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
		Expect(err).ToNot(HaveOccurred())
		b = append(b, token[:]...)
		conn := NewMockPacketHandler(mockCtrl)
		destroyed := make(chan struct{})
		// The connection ID lookup misses, so the transport falls back to a
		// lookup by the trailing reset token and destroys the matching handler.
		gomock.InOrder(
			phm.EXPECT().Get(connID),
			phm.EXPECT().GetByResetToken(token).Return(conn, true),
			conn.EXPECT().destroy(gomock.Any()).Do(func(err error) {
				// Runs on the Transport's read loop.
				defer GinkgoRecover()
				defer close(destroyed)
				Expect(err).To(MatchError(&StatelessResetError{Token: token}))
			}),
		)
		packetChan <- packetToRead{data: b}
		Eventually(destroyed).Should(BeClosed())

		// shutdown
		phm.EXPECT().Close(gomock.Any())
		close(packetChan)
		tr.Close()
	})

	It("sends stateless resets", func() {
		connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
		packetChan := make(chan packetToRead)
		conn := newMockPacketConn(packetChan)
		tr := Transport{
			Conn:               conn,
			StatelessResetKey:  &StatelessResetKey{1, 2, 3, 4},
			ConnectionIDLength: connID.Len(),
		}
		tr.init(true)
		defer tr.Close()
		phm := NewMockPacketHandlerManager(mockCtrl)
		tr.handlerMap = phm

		var b []byte
		b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
		Expect(err).ToNot(HaveOccurred())
		// Pad the packet to protocol.MinStatelessResetSize+1 bytes in total.
		// Presumably the transport only answers packets larger than the minimum
		// reset size with a reset.
		b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...)

		var token protocol.StatelessResetToken
		rand.Read(token[:])
		written := make(chan struct{})
		gomock.InOrder(
			phm.EXPECT().Get(connID),
			phm.EXPECT().GetByResetToken(gomock.Any()),
			phm.EXPECT().GetStatelessResetToken(connID).Return(token),
			// DoAndReturn (not Do) so the mock really reports len(b) bytes written.
			conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
				// Runs on the Transport's read loop.
				defer GinkgoRecover()
				defer close(written)
				Expect(bytes.Contains(b, token[:])).To(BeTrue())
				return len(b), nil
			}),
		)
		packetChan <- packetToRead{data: b}
		Eventually(written).Should(BeClosed())

		// shutdown
		phm.EXPECT().Close(gomock.Any())
		close(packetChan)
		tr.Close()
	})

	It("closes uninitialized Transport and closes underlying PacketConn", func() {
		packetChan := make(chan packetToRead)
		pconn := newMockPacketConn(packetChan)
		tr := &Transport{
			Conn:        pconn,
			createdConn: true, // owns pconn
		}
		// NO init

		// shutdown
		close(packetChan)
		pconn.EXPECT().Close()
		Expect(tr.Close()).To(Succeed())
	})

	It("doesn't add the PacketConn to the multiplexer if (*Transport).init fails", func() {
		packetChan := make(chan packetToRead)
		pconn := newMockPacketConn(packetChan)
		// mockSyscallConn's SyscallConn always errors, which is expected to make init fail.
		syscallconn := &mockSyscallConn{pconn}
		tr := &Transport{
			Conn: syscallconn,
		}

		err := tr.init(false)
		Expect(err).To(HaveOccurred())
		conns := getMultiplexer().(*connMultiplexer).conns
		Expect(len(conns)).To(BeZero())
	})

	It("allows receiving non-QUIC packets", func() {
		remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
		packetChan := make(chan packetToRead)
		tr := &Transport{
			Conn:               newMockPacketConn(packetChan),
			ConnectionIDLength: 10,
		}
		tr.init(true)
		receivedPacketChan := make(chan []byte)
		go func() {
			defer GinkgoRecover()
			b := make([]byte, 100)
			n, addr, err := tr.ReadNonQUICPacket(context.Background(), b)
			Expect(err).ToNot(HaveOccurred())
			Expect(addr).To(Equal(remoteAddr))
			receivedPacketChan <- b[:n]
		}()
		// Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called.
		// Give the Go routine some time to spin up.
		time.Sleep(scaleDuration(50 * time.Millisecond))
		packetChan <- packetToRead{
			addr: remoteAddr,
			data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
		}

		Eventually(receivedPacketChan).Should(Receive(Equal([]byte{0, 1, 2, 3})))

		// shutdown
		close(packetChan)
		tr.Close()
	})

	It("drops non-QUIC packet if the application doesn't process them quickly enough", func() {
		remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
		packetChan := make(chan packetToRead)
		t, tracer := mocklogging.NewMockTracer(mockCtrl)
		tr := &Transport{
			Conn:               newMockPacketConn(packetChan),
			ConnectionIDLength: 10,
			Tracer:             t,
		}
		tr.init(true)

		// Calling ReadNonQUICPacket enables non-QUIC reception. The context is
		// already cancelled, so the call returns at once and nothing drains the queue.
		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		_, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10))
		Expect(err).To(MatchError(context.Canceled))

		// Fill the queue completely...
		for i := 0; i < maxQueuedNonQUICPackets; i++ {
			packetChan <- packetToRead{
				addr: remoteAddr,
				data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
			}
		}

		// ...so that the next packet is dropped.
		done := make(chan struct{})
		tracer.EXPECT().DroppedPacket(remoteAddr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
			close(done)
		})
		packetChan <- packetToRead{
			addr: remoteAddr,
			data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
		}
		Eventually(done).Should(BeClosed())

		// shutdown
		tracer.EXPECT().Close()
		close(packetChan)
		tr.Close()
	})

	remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234}
	DescribeTable("setting the tls.Config.ServerName",
		func(expected string, conf *tls.Config, addr net.Addr, host string) {
			setTLSConfigServerName(conf, addr, host)
			Expect(conf.ServerName).To(Equal(expected))
		},
		Entry("uses the value from the config", "foo.bar", &tls.Config{ServerName: "foo.bar"}, remoteAddr, "baz.foo"),
		Entry("uses the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org"),
		Entry("removes the port from the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org:1234"),
		Entry("uses the IP", "1.3.5.7", &tls.Config{}, remoteAddr, ""),
	)
})
type mockSyscallConn struct {
net.PacketConn
}
func (c *mockSyscallConn) SyscallConn() (syscall.RawConn, error) {
return nil, errors.New("mocked")
}