mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-06 21:57:36 +03:00
improve packet handling in the client
This commit is contained in:
parent
c1d8c8940e
commit
dd5d376d94
2 changed files with 54 additions and 51 deletions
34
client.go
34
client.go
|
@ -284,28 +284,29 @@ func (c *client) listen() {
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
c.handlePacket(addr, data[:n])
|
if err := c.handlePacket(addr, data[:n]); err != nil {
|
||||||
|
c.logger.Errorf("error handling packet: %s", err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
||||||
rcvTime := time.Now()
|
rcvTime := time.Now()
|
||||||
|
|
||||||
r := bytes.NewReader(packet)
|
r := bytes.NewReader(packet)
|
||||||
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
|
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
|
||||||
|
// drop the packet if we can't parse the header
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
return fmt.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
||||||
// drop this packet if we can't parse the header
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
// reject packets with truncated connection id if we didn't request truncation
|
// reject packets with truncated connection id if we didn't request truncation
|
||||||
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
||||||
return
|
return errors.New("received packet with truncated connection ID, but didn't request truncation")
|
||||||
}
|
}
|
||||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||||
|
|
||||||
if hdr.IsLongHeader && !hdr.DestConnectionID.Equal(hdr.SrcConnectionID) {
|
if hdr.IsLongHeader && !hdr.DestConnectionID.Equal(hdr.SrcConnectionID) {
|
||||||
c.logger.Errorf("receiving packets with different destination and source connection IDs not supported")
|
return fmt.Errorf("receiving packets with different destination and source connection IDs not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
|
@ -314,7 +315,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
||||||
// reject packets with the wrong connection ID
|
// reject packets with the wrong connection ID
|
||||||
// TODO(#1003): add support for server-chosen connection IDs
|
// TODO(#1003): add support for server-chosen connection IDs
|
||||||
if !hdr.OmitConnectionID && !hdr.DestConnectionID.Equal(c.connectionID) {
|
if !hdr.OmitConnectionID && !hdr.DestConnectionID.Equal(c.connectionID) {
|
||||||
return
|
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.connectionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if hdr.ResetFlag {
|
if hdr.ResetFlag {
|
||||||
|
@ -322,31 +323,29 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
||||||
// check if the remote address and the connection ID match
|
// check if the remote address and the connection ID match
|
||||||
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
||||||
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || !hdr.DestConnectionID.Equal(c.connectionID) {
|
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || !hdr.DestConnectionID.Equal(c.connectionID) {
|
||||||
c.logger.Infof("Received a spoofed Public Reset. Ignoring.")
|
return errors.New("Received a spoofed Public Reset")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
pr, err := wire.ParsePublicReset(r)
|
pr, err := wire.ParsePublicReset(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
c.logger.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
|
|
||||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||||
return
|
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle Version Negotiation Packets
|
// handle Version Negotiation Packets
|
||||||
if hdr.IsVersionNegotiation {
|
if hdr.IsVersionNegotiation {
|
||||||
// ignore delayed / duplicated version negotiation packets
|
// ignore delayed / duplicated version negotiation packets
|
||||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||||
return
|
return errors.New("received a delayed Version Negotiation Packet")
|
||||||
}
|
}
|
||||||
|
|
||||||
// version negotiation packets have no payload
|
// version negotiation packets have no payload
|
||||||
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
|
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
|
||||||
c.session.Close(err)
|
c.session.Close(err)
|
||||||
}
|
}
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is the first packet we are receiving
|
// this is the first packet we are receiving
|
||||||
|
@ -364,6 +363,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
||||||
data: packet[len(packet)-r.Len():],
|
data: packet[len(packet)-r.Len():],
|
||||||
rcvTime: rcvTime,
|
rcvTime: rcvTime,
|
||||||
})
|
})
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||||
|
@ -389,7 +389,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||||
c.initialVersion = c.version
|
c.initialVersion = c.version
|
||||||
c.version = newVersion
|
c.version = newVersion
|
||||||
var err error
|
var err error
|
||||||
c.connectionID, err = protocol.GenerateConnectionID()
|
c.connectionID, err = generateConnectionID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -302,7 +303,8 @@ var _ = Describe("Client", func() {
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
err := ph.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever)
|
err := ph.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
cl.handlePacket(nil, b.Bytes())
|
err = cl.handlePacket(nil, b.Bytes())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(cl.versionNegotiated).To(BeTrue())
|
Expect(cl.versionNegotiated).To(BeTrue())
|
||||||
Expect(cl.versionNegotiationChan).To(BeClosed())
|
Expect(cl.versionNegotiationChan).To(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -392,15 +394,18 @@ var _ = Describe("Client", func() {
|
||||||
go cl.dial()
|
go cl.dial()
|
||||||
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1))
|
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1))
|
||||||
cl.config = &Config{Versions: []protocol.VersionNumber{77, 78}}
|
cl.config = &Config{Versions: []protocol.VersionNumber{77, 78}}
|
||||||
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{77}))
|
err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{77}))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
|
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
|
||||||
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{78}))
|
err = cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{78}))
|
||||||
|
Expect(err).To(MatchError("received a delayed Version Negotiation Packet"))
|
||||||
Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
|
Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors if no matching version is found", func() {
|
It("errors if no matching version is found", func() {
|
||||||
cl.config = &Config{Versions: protocol.SupportedVersions}
|
cl.config = &Config{Versions: protocol.SupportedVersions}
|
||||||
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1}))
|
err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1}))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(cl.session.(*mockSession).closed).To(BeTrue())
|
Expect(cl.session.(*mockSession).closed).To(BeTrue())
|
||||||
Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion))
|
Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion))
|
||||||
})
|
})
|
||||||
|
@ -409,7 +414,8 @@ var _ = Describe("Client", func() {
|
||||||
v := protocol.VersionNumber(1234)
|
v := protocol.VersionNumber(1234)
|
||||||
Expect(v).ToNot(Equal(cl.version))
|
Expect(v).ToNot(Equal(cl.version))
|
||||||
cl.config = &Config{Versions: protocol.SupportedVersions}
|
cl.config = &Config{Versions: protocol.SupportedVersions}
|
||||||
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{v}))
|
err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{v}))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(cl.session.(*mockSession).closed).To(BeTrue())
|
Expect(cl.session.(*mockSession).closed).To(BeTrue())
|
||||||
Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion))
|
Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion))
|
||||||
})
|
})
|
||||||
|
@ -417,29 +423,24 @@ var _ = Describe("Client", func() {
|
||||||
It("changes to the version preferred by the quic.Config", func() {
|
It("changes to the version preferred by the quic.Config", func() {
|
||||||
config := &Config{Versions: []protocol.VersionNumber{1234, 4321}}
|
config := &Config{Versions: []protocol.VersionNumber{1234, 4321}}
|
||||||
cl.config = config
|
cl.config = config
|
||||||
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{4321, 1234}))
|
err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{4321, 1234}))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(cl.version).To(Equal(protocol.VersionNumber(1234)))
|
Expect(cl.version).To(Equal(protocol.VersionNumber(1234)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("ignores delayed version negotiation packets", func() {
|
|
||||||
// if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test
|
|
||||||
cl.versionNegotiated = true
|
|
||||||
Expect(sess.packetCount).To(BeZero())
|
|
||||||
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1}))
|
|
||||||
Expect(cl.versionNegotiated).To(BeTrue())
|
|
||||||
Expect(sess.packetCount).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("drops version negotiation packets that contain the offered version", func() {
|
It("drops version negotiation packets that contain the offered version", func() {
|
||||||
ver := cl.version
|
ver := cl.version
|
||||||
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver}))
|
err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver}))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(cl.version).To(Equal(ver))
|
Expect(cl.version).To(Equal(ver))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
It("ignores packets with an invalid public header", func() {
|
It("ignores packets with an invalid public header", func() {
|
||||||
cl.handlePacket(addr, []byte("invalid packet"))
|
err := cl.handlePacket(addr, []byte("invalid packet"))
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring("error parsing packet from"))
|
||||||
Expect(sess.packetCount).To(BeZero())
|
Expect(sess.packetCount).To(BeZero())
|
||||||
Expect(sess.closed).To(BeFalse())
|
Expect(sess.closed).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
@ -447,27 +448,33 @@ var _ = Describe("Client", func() {
|
||||||
It("ignores packets without connection id, if it didn't request connection id trunctation", func() {
|
It("ignores packets without connection id, if it didn't request connection id trunctation", func() {
|
||||||
cl.config = &Config{RequestConnectionIDOmission: false}
|
cl.config = &Config{RequestConnectionIDOmission: false}
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
(&wire.Header{
|
err := (&wire.Header{
|
||||||
OmitConnectionID: true,
|
OmitConnectionID: true,
|
||||||
|
SrcConnectionID: connID,
|
||||||
|
DestConnectionID: connID,
|
||||||
PacketNumber: 1,
|
PacketNumber: 1,
|
||||||
PacketNumberLen: 1,
|
PacketNumberLen: 1,
|
||||||
}).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever)
|
}).Write(buf, protocol.PerspectiveServer, versionGQUICFrames)
|
||||||
cl.handlePacket(addr, buf.Bytes())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
err = cl.handlePacket(addr, buf.Bytes())
|
||||||
|
Expect(err).To(MatchError("received packet with truncated connection ID, but didn't request truncation"))
|
||||||
Expect(sess.packetCount).To(BeZero())
|
Expect(sess.packetCount).To(BeZero())
|
||||||
Expect(sess.closed).To(BeFalse())
|
Expect(sess.closed).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("ignores packets with the wrong connection ID", func() {
|
It("ignores packets with the wrong connection ID", func() {
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
connID2 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7}
|
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||||
Expect(connID).ToNot(Equal(connID2))
|
Expect(connID).ToNot(Equal(connID2))
|
||||||
(&wire.Header{
|
err := (&wire.Header{
|
||||||
DestConnectionID: connID2,
|
DestConnectionID: connID2,
|
||||||
SrcConnectionID: connID2,
|
SrcConnectionID: connID2,
|
||||||
PacketNumber: 1,
|
PacketNumber: 1,
|
||||||
PacketNumberLen: 1,
|
PacketNumberLen: 1,
|
||||||
}).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever)
|
}).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever)
|
||||||
cl.handlePacket(addr, buf.Bytes())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
err = cl.handlePacket(addr, buf.Bytes())
|
||||||
|
Expect(err).To(MatchError(fmt.Sprintf("received a packet with an unexpected connection ID (0x0807060504030201, expected %s)", connID)))
|
||||||
Expect(sess.packetCount).To(BeZero())
|
Expect(sess.packetCount).To(BeZero())
|
||||||
Expect(sess.closed).To(BeFalse())
|
Expect(sess.closed).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
@ -626,30 +633,26 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
Context("Public Reset handling", func() {
|
Context("Public Reset handling", func() {
|
||||||
It("closes the session when receiving a Public Reset", func() {
|
It("closes the session when receiving a Public Reset", func() {
|
||||||
cl.handlePacket(addr, wire.WritePublicReset(cl.connectionID, 1, 0))
|
err := cl.handlePacket(addr, wire.WritePublicReset(cl.connectionID, 1, 0))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(cl.session.(*mockSession).closed).To(BeTrue())
|
Expect(cl.session.(*mockSession).closed).To(BeTrue())
|
||||||
Expect(cl.session.(*mockSession).closedRemote).To(BeTrue())
|
Expect(cl.session.(*mockSession).closedRemote).To(BeTrue())
|
||||||
Expect(cl.session.(*mockSession).closeReason.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset))
|
Expect(cl.session.(*mockSession).closeReason.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("ignores Public Resets with the wrong connection ID", func() {
|
|
||||||
connID2 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7}
|
|
||||||
Expect(connID).ToNot(Equal(connID2))
|
|
||||||
cl.handlePacket(addr, wire.WritePublicReset(connID2, 1, 0))
|
|
||||||
Expect(cl.session.(*mockSession).closed).To(BeFalse())
|
|
||||||
Expect(cl.session.(*mockSession).closedRemote).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ignores Public Resets from the wrong remote address", func() {
|
It("ignores Public Resets from the wrong remote address", func() {
|
||||||
spoofedAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678}
|
spoofedAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678}
|
||||||
cl.handlePacket(spoofedAddr, wire.WritePublicReset(cl.connectionID, 1, 0))
|
err := cl.handlePacket(spoofedAddr, wire.WritePublicReset(cl.connectionID, 1, 0))
|
||||||
|
Expect(err).To(MatchError("Received a spoofed Public Reset"))
|
||||||
Expect(cl.session.(*mockSession).closed).To(BeFalse())
|
Expect(cl.session.(*mockSession).closed).To(BeFalse())
|
||||||
Expect(cl.session.(*mockSession).closedRemote).To(BeFalse())
|
Expect(cl.session.(*mockSession).closedRemote).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("ignores unparseable Public Resets", func() {
|
It("ignores unparseable Public Resets", func() {
|
||||||
pr := wire.WritePublicReset(cl.connectionID, 1, 0)
|
pr := wire.WritePublicReset(cl.connectionID, 1, 0)
|
||||||
cl.handlePacket(addr, pr[:len(pr)-5])
|
err := cl.handlePacket(addr, pr[:len(pr)-5])
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring("Received a Public Reset. An error occurred parsing the packet"))
|
||||||
Expect(cl.session.(*mockSession).closed).To(BeFalse())
|
Expect(cl.session.(*mockSession).closed).To(BeFalse())
|
||||||
Expect(cl.session.(*mockSession).closedRemote).To(BeFalse())
|
Expect(cl.session.(*mockSession).closedRemote).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue