mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-03-31 10:47:35 +03:00
444 lines
13 KiB
Go
444 lines
13 KiB
Go
package quic
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"errors"
|
|
"net"
|
|
"syscall"
|
|
"time"
|
|
|
|
tls "github.com/refraction-networking/utls"
|
|
|
|
mocklogging "github.com/refraction-networking/uquic/internal/mocks/logging"
|
|
"github.com/refraction-networking/uquic/internal/protocol"
|
|
"github.com/refraction-networking/uquic/internal/wire"
|
|
"github.com/refraction-networking/uquic/logging"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
var _ = Describe("Transport", func() {
|
|
// packetToRead is one scripted result for the mocked ReadFrom call: the
// bytes copied into the caller's buffer, the address reported as the sender,
// and the error returned alongside them.
type packetToRead struct {
	data []byte
	addr net.Addr
	err  error
}
|
|
|
|
// getPacketWithPacketType serializes a wire.ExtendedHeader of packet type t
// addressed to connID, using protocol.Version1, the given Length field and a
// 2-byte packet number length. A serialization error fails the running spec.
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
	hdr := &wire.ExtendedHeader{
		Header: wire.Header{
			Type:             t,
			DestConnectionID: connID,
			Length:           length,
			Version:          protocol.Version1,
		},
		PacketNumberLen: protocol.PacketNumberLen2,
	}
	raw, err := hdr.Append(nil, protocol.Version1)
	Expect(err).ToNot(HaveOccurred())
	return raw
}
|
|
|
|
getPacket := func(connID protocol.ConnectionID) []byte {
|
|
return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
|
|
}
|
|
|
|
// newMockPacketConn builds a MockPacketConn whose ReadFrom calls are fed from
// packetChan: every value received from the channel becomes one ReadFrom
// result, and once the channel is closed ReadFrom reports a "closed" error.
// LocalAddr and SetReadDeadline may be called any number of times.
newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn {
	readFromChan := func(buf []byte) (int, net.Addr, error) {
		pkt, open := <-packetChan
		if open {
			return copy(buf, pkt.data), pkt.addr, pkt.err
		}
		return 0, nil, errors.New("closed")
	}

	mockConn := NewMockPacketConn(mockCtrl)
	mockConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
	mockConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(readFromChan).AnyTimes()
	// for shutdown
	mockConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
	return mockConn
}
|
|
|
|
It("handles packets for different packet handlers on the same packet conn", func() {
	packets := make(chan packetToRead)
	tr := &Transport{Conn: newMockPacketConn(packets)}
	tr.init(true)
	handlers := NewMockPacketHandlerManager(mockCtrl)
	tr.handlerMap = handlers
	connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
	connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})

	handled := make(chan struct{}, 2)
	// expectHandlerFor registers one lookup of want on the handler manager.
	// The lookup yields a fresh mock handler that checks the packet it is
	// handed starts with that same connection ID, then signals on handled.
	expectHandlerFor := func(want protocol.ConnectionID) {
		handlers.EXPECT().Get(want).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
			handler := NewMockPacketHandler(mockCtrl)
			handler.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
				defer GinkgoRecover()
				got, err := wire.ParseConnectionID(p.data, 0)
				Expect(err).ToNot(HaveOccurred())
				Expect(got).To(Equal(want))
				handled <- struct{}{}
			})
			return handler, true
		})
	}
	expectHandlerFor(connID1)
	expectHandlerFor(connID2)

	packets <- packetToRead{data: getPacket(connID1)}
	packets <- packetToRead{data: getPacket(connID2)}

	Eventually(handled).Should(Receive())
	Eventually(handled).Should(Receive())

	// shutdown
	handlers.EXPECT().Close(gomock.Any())
	close(packets)
	tr.Close()
})
|
|
|
|
It("closes listeners", func() {
	packets := make(chan packetToRead)
	tr := &Transport{Conn: newMockPacketConn(packets)}
	defer tr.Close()
	listener, err := tr.Listen(&tls.Config{}, nil)
	Expect(err).ToNot(HaveOccurred())
	// The mock handler manager is installed only after Listen has returned.
	handlers := NewMockPacketHandlerManager(mockCtrl)
	tr.handlerMap = handlers

	Expect(listener.Close()).To(Succeed())

	// shutdown
	handlers.EXPECT().Close(gomock.Any())
	close(packets)
	tr.Close()
})
|
|
|
|
It("closes transport concurrently with listener", func() {
	// try 10 times to trigger race conditions
	for i := 0; i < 10; i++ {
		packetChan := make(chan packetToRead)
		tr := &Transport{Conn: newMockPacketConn(packetChan)}
		ln, err := tr.Listen(&tls.Config{}, nil)
		Expect(err).ToNot(HaveOccurred())
		ch := make(chan bool)
		// Close transport and listener concurrently.
		go func() {
			// Assertions made off the spec goroutine must be recovered,
			// otherwise a failure panics the whole test binary.
			defer GinkgoRecover()
			ch <- true // signal that the goroutine is running
			Expect(ln.Close()).To(Succeed())
			ch <- true // signal that the listener has been closed
		}()
		<-ch
		close(packetChan)
		Expect(tr.Close()).To(Succeed())
		// Use Eventually rather than a bare receive: if the assertion in the
		// goroutine failed, the second send never happens and a plain <-ch
		// would block this spec forever.
		Eventually(ch).Should(Receive())
	}
})
|
|
|
|
It("drops unparseable QUIC packets", func() {
	remote := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
	packets := make(chan packetToRead)
	tracerHandle, mockTracer := mocklogging.NewMockTracer(mockCtrl)
	tr := &Transport{
		Conn:               newMockPacketConn(packets),
		ConnectionIDLength: 10,
		Tracer:             tracerHandle,
	}
	tr.init(true)

	// The 4-byte datagram below must be reported to the tracer as dropped
	// with a header parse error.
	dropped := make(chan struct{})
	onDrop := func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }
	mockTracer.EXPECT().
		DroppedPacket(remote, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).
		Do(onDrop)
	packets <- packetToRead{
		addr: remote,
		data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3},
	}
	Eventually(dropped).Should(BeClosed())

	// shutdown
	mockTracer.EXPECT().Close()
	close(packets)
	tr.Close()
})
|
|
|
|
It("closes when reading from the conn fails", func() {
	packets := make(chan packetToRead)
	tr := Transport{Conn: newMockPacketConn(packets)}
	defer tr.Close()
	handlers := NewMockPacketHandlerManager(mockCtrl)
	tr.init(true)
	tr.handlerMap = handlers

	// Feeding a read error through the mocked conn must lead to the handler
	// manager being closed.
	closed := make(chan struct{})
	handlers.EXPECT().Close(gomock.Any()).Do(func(error) { close(closed) })
	packets <- packetToRead{err: errors.New("read failed")}
	Eventually(closed).Should(BeClosed())

	// shutdown
	close(packets)
	tr.Close()
})
|
|
|
|
It("continues listening after temporary errors", func() {
	packetChan := make(chan packetToRead)
	tr := Transport{Conn: newMockPacketConn(packetChan)}
	defer tr.Close()
	phm := NewMockPacketHandlerManager(mockCtrl)
	tr.init(true)
	tr.handlerMap = phm

	tempErr := deadlineError{}
	Expect(tempErr.Temporary()).To(BeTrue())
	packetChan <- packetToRead{err: tempErr}
	// don't expect any calls to phm.Close
	// The wait goes through scaleDuration, like the other timing-based wait
	// in this file, instead of using a fixed 50ms.
	time.Sleep(scaleDuration(50 * time.Millisecond))

	// shutdown
	phm.EXPECT().Close(gomock.Any())
	close(packetChan)
	tr.Close()
})
|
|
|
|
It("handles short header packets resets", func() {
	connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
	packetChan := make(chan packetToRead)
	tr := Transport{
		Conn:               newMockPacketConn(packetChan),
		ConnectionIDLength: connID.Len(),
	}
	tr.init(true)
	defer tr.Close()
	phm := NewMockPacketHandlerManager(mockCtrl)
	tr.handlerMap = phm

	var token protocol.StatelessResetToken
	rand.Read(token[:])

	// Short header packet for connID with the random token appended.
	var b []byte
	b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
	Expect(err).ToNot(HaveOccurred())
	b = append(b, token[:]...)
	conn := NewMockPacketHandler(mockCtrl)
	gomock.InOrder(
		phm.EXPECT().Get(connID).Return(conn, true),
		conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
			// This callback is invoked by the transport rather than by the
			// spec body, so a failed assertion must be recovered here.
			defer GinkgoRecover()
			Expect(p.data).To(Equal(b))
			Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second))
		}),
	)
	packetChan <- packetToRead{data: b}

	// shutdown
	phm.EXPECT().Close(gomock.Any())
	close(packetChan)
	tr.Close()
})
|
|
|
|
It("handles stateless resets", func() {
	connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
	packetChan := make(chan packetToRead)
	tr := Transport{
		Conn:               newMockPacketConn(packetChan),
		ConnectionIDLength: connID.Len(),
	}
	tr.init(true)
	defer tr.Close()
	phm := NewMockPacketHandlerManager(mockCtrl)
	tr.handlerMap = phm

	var token protocol.StatelessResetToken
	rand.Read(token[:])

	// Short header packet for connID with the reset token appended.
	var b []byte
	b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
	Expect(err).ToNot(HaveOccurred())
	b = append(b, token[:]...)
	conn := NewMockPacketHandler(mockCtrl)
	destroyed := make(chan struct{})
	gomock.InOrder(
		// The connection ID lookup yields the mock's zero values; the
		// handler is then returned by the lookup keyed on the reset token.
		phm.EXPECT().Get(connID),
		phm.EXPECT().GetByResetToken(token).Return(conn, true),
		conn.EXPECT().destroy(gomock.Any()).Do(func(err error) {
			// This callback is invoked by the transport rather than by the
			// spec body, so a failed assertion must be recovered here. On
			// failure destroyed stays open and the Eventually below times out.
			defer GinkgoRecover()
			Expect(err).To(MatchError(&StatelessResetError{Token: token}))
			close(destroyed)
		}),
	)
	packetChan <- packetToRead{data: b}
	Eventually(destroyed).Should(BeClosed())

	// shutdown
	phm.EXPECT().Close(gomock.Any())
	close(packetChan)
	tr.Close()
})
|
|
|
|
It("sends stateless resets", func() {
	connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
	packetChan := make(chan packetToRead)
	conn := newMockPacketConn(packetChan)
	tr := Transport{
		Conn:               conn,
		StatelessResetKey:  &StatelessResetKey{1, 2, 3, 4},
		ConnectionIDLength: connID.Len(),
	}
	tr.init(true)
	defer tr.Close()
	phm := NewMockPacketHandlerManager(mockCtrl)
	tr.handlerMap = phm

	// Short header packet padded to one byte more than MinStatelessResetSize.
	var b []byte
	b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
	Expect(err).ToNot(HaveOccurred())
	b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...)

	var token protocol.StatelessResetToken
	rand.Read(token[:])
	written := make(chan struct{})
	gomock.InOrder(
		// Both lookups yield the mock's zero values.
		phm.EXPECT().Get(connID),
		phm.EXPECT().GetByResetToken(gomock.Any()),
		phm.EXPECT().GetStatelessResetToken(connID).Return(token),
		// DoAndReturn (not Do) so that the values returned by the callback
		// are actually handed back to the caller of WriteTo.
		conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
			// This callback is invoked by the transport rather than by the
			// spec body, so a failed assertion must be recovered here.
			defer GinkgoRecover()
			defer close(written)
			Expect(bytes.Contains(b, token[:])).To(BeTrue())
			return len(b), nil
		}),
	)
	packetChan <- packetToRead{data: b}
	Eventually(written).Should(BeClosed())

	// shutdown
	phm.EXPECT().Close(gomock.Any())
	close(packetChan)
	tr.Close()
})
|
|
|
|
It("closes uninitialized Transport and closes underlying PacketConn", func() {
	packets := make(chan packetToRead)
	underlying := newMockPacketConn(packets)

	// init is deliberately never called on this transport.
	tr := &Transport{Conn: underlying, createdConn: true /* owns pconn */}

	// shutdown
	close(packets)
	underlying.EXPECT().Close()
	Expect(tr.Close()).To(Succeed())
})
|
|
|
|
It("doesn't add the PacketConn to the multiplexer if (*Transport).init fails", func() {
	packets := make(chan packetToRead)
	// Wrap the mock in mockSyscallConn, whose SyscallConn always errors.
	tr := &Transport{
		Conn: &mockSyscallConn{newMockPacketConn(packets)},
	}

	initErr := tr.init(false)
	Expect(initErr).To(HaveOccurred())

	registered := getMultiplexer().(*connMultiplexer).conns
	Expect(len(registered)).To(BeZero())
})
|
|
|
|
It("allows receiving non-QUIC packets", func() {
	sender := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
	packets := make(chan packetToRead)
	tr := &Transport{
		Conn:               newMockPacketConn(packets),
		ConnectionIDLength: 10,
	}
	tr.init(true)

	received := make(chan []byte)
	go func() {
		defer GinkgoRecover()
		buf := make([]byte, 100)
		n, from, err := tr.ReadNonQUICPacket(context.Background(), buf)
		Expect(err).ToNot(HaveOccurred())
		Expect(from).To(Equal(sender))
		received <- buf[:n]
	}()
	// Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called.
	// Give the Go routine some time to spin up.
	time.Sleep(scaleDuration(50 * time.Millisecond))
	packets <- packetToRead{
		addr: sender,
		data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
	}

	want := []byte{0, 1, 2, 3}
	Eventually(received).Should(Receive(Equal(want)))

	// shutdown
	close(packets)
	tr.Close()
})
|
|
|
|
It("drops non-QUIC packet if the application doesn't process them quickly enough", func() {
	sender := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
	packets := make(chan packetToRead)
	tracerHandle, mockTracer := mocklogging.NewMockTracer(mockCtrl)
	tr := &Transport{
		Conn:               newMockPacketConn(packets),
		ConnectionIDLength: 10,
		Tracer:             tracerHandle,
	}
	tr.init(true)

	// Call ReadNonQUICPacket once with an already-canceled context; it is
	// expected to return context.Canceled.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	_, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10))
	Expect(err).To(MatchError(context.Canceled))

	// The same 4-byte datagram is sent every time; the mocked ReadFrom only
	// copies out of data, so sharing one value across sends is safe.
	nonQUIC := packetToRead{
		addr: sender,
		data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
	}
	for sent := 0; sent < maxQueuedNonQUICPackets; sent++ {
		packets <- nonQUIC
	}

	// One more datagram on top of maxQueuedNonQUICPackets must be reported
	// to the tracer as dropped for DoS prevention.
	done := make(chan struct{})
	onDrop := func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
		close(done)
	}
	mockTracer.EXPECT().
		DroppedPacket(sender, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).
		Do(onDrop)
	packets <- nonQUIC
	Eventually(done).Should(BeClosed())

	// shutdown
	mockTracer.EXPECT().Close()
	close(packets)
	tr.Close()
})
|
|
|
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234}
// checkServerName applies setTLSConfigServerName to conf and verifies the
// ServerName the config ends up with.
checkServerName := func(expected string, conf *tls.Config, addr net.Addr, host string) {
	setTLSConfigServerName(conf, addr, host)
	Expect(conf.ServerName).To(Equal(expected))
}
DescribeTable("setting the tls.Config.ServerName",
	checkServerName,
	Entry("uses the value from the config", "foo.bar", &tls.Config{ServerName: "foo.bar"}, remoteAddr, "baz.foo"),
	Entry("uses the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org"),
	Entry("removes the port from the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org:1234"),
	Entry("uses the IP", "1.3.5.7", &tls.Config{}, remoteAddr, ""),
)
|
|
})
|
|
|
|
// mockSyscallConn wraps a net.PacketConn and adds a SyscallConn method that
// always fails.
type mockSyscallConn struct {
	net.PacketConn
}

// SyscallConn never yields a raw connection: it returns a nil
// syscall.RawConn together with a "mocked" error.
func (*mockSyscallConn) SyscallConn() (syscall.RawConn, error) {
	var none syscall.RawConn
	return none, errors.New("mocked")
}
|