From 97e734e97324b8f254e09cb2078b8d292958d19e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 21 May 2018 12:01:55 +0800 Subject: [PATCH] refactor packet handling functions in the client --- client.go | 86 +++++++++++++++++++----------------- client_test.go | 116 +++++++++++++++++++++++++++---------------------- 2 files changed, 109 insertions(+), 93 deletions(-) diff --git a/client.go b/client.go index 9dbf05d0..b3b596d0 100644 --- a/client.go +++ b/client.go @@ -324,70 +324,84 @@ func (c *client) listen() { } break } - if err := c.handlePacket(addr, data[:n]); err != nil { - c.logger.Errorf("error handling packet: %s", err.Error()) - } + c.handleRead(addr, data[:n]) } } -func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { +func (c *client) handleRead(remoteAddr net.Addr, packet []byte) { rcvTime := time.Now() r := bytes.NewReader(packet) hdr, err := wire.ParseHeaderSentByServer(r) // drop the packet if we can't parse the header if err != nil { - return fmt.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) - } - // reject packets with truncated connection id if we didn't request truncation - if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission { - return errors.New("received packet with truncated connection ID, but didn't request truncation") + c.logger.Errorf("error handling packet: %s", err) + return } hdr.Raw = packet[:len(packet)-r.Len()] packetData := packet[len(packet)-r.Len():] + c.handlePacket(&receivedPacket{ + remoteAddr: remoteAddr, + header: hdr, + data: packetData, + rcvTime: rcvTime, + }) +} + +func (c *client) handlePacket(p *receivedPacket) { + if err := c.handlePacketImpl(p); err != nil { + c.logger.Errorf("error handling packet: %s", err) + } +} + +func (c *client) handlePacketImpl(p *receivedPacket) error { + // reject packets with truncated connection id if we didn't request truncation + if p.header.OmitConnectionID && !c.config.RequestConnectionIDOmission { + return errors.New("received packet with truncated connection ID, but didn't request truncation") + } c.mutex.Lock() defer c.mutex.Unlock() // handle Version Negotiation Packets - if hdr.IsVersionNegotiation { + if p.header.IsVersionNegotiation { // ignore delayed / duplicated version negotiation packets if c.receivedVersionNegotiationPacket || c.versionNegotiated { return errors.New("received a delayed Version Negotiation Packet") } // version negotiation packets have no payload - if err := c.handleVersionNegotiationPacket(hdr); err != nil { + if err := c.handleVersionNegotiationPacket(p.header); err != nil { c.session.Close(err) } return nil } - if hdr.IsPublicHeader { - return c.handleGQUICPacket(hdr, r, packetData, remoteAddr, rcvTime) + if p.header.IsPublicHeader { + return c.handleGQUICPacket(p) } - return c.handleIETFQUICPacket(hdr, packetData, remoteAddr, rcvTime) + return c.handleIETFQUICPacket(p) } -func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error { +func (c *client) handleIETFQUICPacket(p *receivedPacket) error { // reject packets with the wrong connection ID - if !hdr.DestConnectionID.Equal(c.srcConnID) { - return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID) + if !p.header.DestConnectionID.Equal(c.srcConnID) { + return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID) } - if hdr.IsLongHeader { - switch hdr.Type { + if p.header.IsLongHeader { + switch p.header.Type { case protocol.PacketTypeRetry: if c.receivedRetry { return nil } case protocol.PacketTypeHandshake: default: - return fmt.Errorf("Received unsupported packet type: %s", hdr.Type) + return fmt.Errorf("Received unsupported packet type: %s", p.header.Type) } - if protocol.ByteCount(len(packetData)) < hdr.PayloadLen { - return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen) + if protocol.ByteCount(len(p.data)) < p.header.PayloadLen { + return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(p.data), p.header.PayloadLen) } - packetData = packetData[:int(hdr.PayloadLen)] + p.data = p.data[:int(p.header.PayloadLen)] // TODO(#1312): implement parsing of compound packets } @@ -397,29 +411,24 @@ func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot c.versionNegotiated = true } - c.session.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - header: hdr, - data: packetData, - rcvTime: rcvTime, - }) + c.session.handlePacket(p) return nil } -func (c *client) handleGQUICPacket(hdr *wire.Header, r *bytes.Reader, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error { +func (c *client) handleGQUICPacket(p *receivedPacket) error { // reject packets with the wrong connection ID - if !hdr.OmitConnectionID && !hdr.DestConnectionID.Equal(c.srcConnID) { - return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID) + if !p.header.OmitConnectionID && !p.header.DestConnectionID.Equal(c.srcConnID) { + return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID) } - if hdr.ResetFlag { + if p.header.ResetFlag { cr := c.conn.RemoteAddr() // check if the remote address and the connection ID match // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection - if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || !hdr.DestConnectionID.Equal(c.srcConnID) { + if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) { return errors.New("Received a spoofed Public Reset") } - pr, err := wire.ParsePublicReset(r) + pr, err := wire.ParsePublicReset(bytes.NewReader(p.data)) if err != nil { return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err) } @@ -434,12 +443,7 @@ func (c *client) handleGQUICPacket(hdr *wire.Header, r *bytes.Reader, packetData c.versionNegotiated = true } - c.session.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - header: hdr, - data: packetData, - rcvTime: rcvTime, - }) + c.session.handlePacket(p) return nil } diff --git a/client_test.go b/client_test.go index e3968518..37b76496 100644 --- a/client_test.go +++ b/client_test.go @@ -388,16 +388,13 @@ var _ = Describe("Client", func() { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().handlePacket(gomock.Any()) cl.session = sess - ph := wire.Header{ + ph := &wire.Header{ PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, DestConnectionID: connID, SrcConnectionID: connID, } - b := &bytes.Buffer{} - err := ph.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - err = cl.handlePacket(nil, b.Bytes()) + err := cl.handlePacketImpl(&receivedPacket{header: ph}) Expect(err).ToNot(HaveOccurred()) Expect(cl.versionNegotiated).To(BeTrue()) }) @@ -439,8 +436,7 @@ var _ = Describe("Client", func() { close(dialed) }() Eventually(sessionChan).Should(HaveLen(1)) - err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2})) - Expect(err).ToNot(HaveOccurred()) + cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2})) Eventually(sessionChan).Should(BeEmpty()) }) @@ -474,7 +470,7 @@ var _ = Describe("Client", func() { return <-sessionChan, nil } - cl.config = &Config{Versions: []protocol.VersionNumber{version1, version2}} + cl.config = &Config{Versions: []protocol.VersionNumber{version1, version2, version3}} dialed := make(chan struct{}) go func() { defer GinkgoRecover() @@ -483,12 +479,12 @@ var _ = Describe("Client", func() { close(dialed) }() Eventually(sessionChan).Should(HaveLen(1)) - err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2})) - Expect(err).ToNot(HaveOccurred()) + cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2})) Eventually(sessionChan).Should(BeEmpty()) - err = cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version3})) - Expect(err).To(MatchError("received a delayed Version Negotiation Packet")) + Expect(cl.version).To(Equal(version2)) + cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version3})) Eventually(dialed).Should(BeClosed()) + Expect(cl.version).To(Equal(version2)) }) It("errors if no matching version is found", func() { @@ -496,8 +492,7 @@ var _ = Describe("Client", func() { sess.EXPECT().Close(gomock.Any()) cl.session = sess cl.config = &Config{Versions: protocol.SupportedVersions} - err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1})) - Expect(err).ToNot(HaveOccurred()) + cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1})) }) It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() { @@ -507,8 +502,7 @@ var _ = Describe("Client", func() { v := protocol.VersionNumber(1234) Expect(v).ToNot(Equal(cl.version)) cl.config = &Config{Versions: protocol.SupportedVersions} - err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{v})) - Expect(err).ToNot(HaveOccurred()) + cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{v})) }) It("changes to the version preferred by the quic.Config", func() { @@ -517,15 +511,13 @@ var _ = Describe("Client", func() { cl.session = sess config := &Config{Versions: []protocol.VersionNumber{1234, 4321}} cl.config = config - err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{4321, 1234})) - Expect(err).ToNot(HaveOccurred()) + cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{4321, 1234})) Expect(cl.version).To(Equal(protocol.VersionNumber(1234))) }) It("drops version negotiation packets that contain the offered version", func() { ver := cl.version - err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver})) - Expect(err).ToNot(HaveOccurred()) + cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver})) Expect(cl.version).To(Equal(ver)) }) }) @@ -533,14 +525,11 @@ var _ = Describe("Client", func() { It("ignores packets with an invalid public header", func() { cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls - err := cl.handlePacket(addr, []byte("invalid packet")) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("error parsing packet from")) + cl.handleRead(addr, []byte("invalid packet")) }) It("errors on packets that are smaller than the Payload Length in the packet header", func() { cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls - b := &bytes.Buffer{} hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -550,8 +539,12 @@ var _ = Describe("Client", func() { PacketNumberLen: protocol.PacketNumberLen1, Version: versionIETFFrames, } - Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) - cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...)) + err := cl.handlePacketImpl(&receivedPacket{ + remoteAddr: addr, + header: hdr, + data: make([]byte, 456), + }) + Expect(err).To(MatchError("received a packet with an unexpected connection ID (0x0102030405060708, expected 0x0000000000001337)")) }) It("cuts packets at the payload length", func() { @@ -560,7 +553,6 @@ var _ = Describe("Client", func() { Expect(packet.data).To(HaveLen(123)) }) cl.session = sess - b := &bytes.Buffer{} hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -570,13 +562,15 @@ var _ = Describe("Client", func() { PacketNumberLen: protocol.PacketNumberLen1, Version: versionIETFFrames, } - Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) - err := cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...)) + err := cl.handlePacketImpl(&receivedPacket{ + remoteAddr: addr, + header: hdr, + data: make([]byte, 456), + }) Expect(err).ToNot(HaveOccurred()) }) It("ignores packets with the wrong Long Header Type", func() { - b := &bytes.Buffer{} hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, @@ -586,43 +580,48 @@ var _ = Describe("Client", func() { PacketNumberLen: protocol.PacketNumberLen1, Version: versionIETFFrames, } - Expect(hdr.Write(b, protocol.PerspectiveServer, versionIETFFrames)).To(Succeed()) - err := cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...)) + err := cl.handlePacketImpl(&receivedPacket{ + remoteAddr: addr, + header: hdr, + data: make([]byte, 456), + }) Expect(err).To(MatchError("Received unsupported packet type: Initial")) }) It("ignores packets without connection id, if it didn't request connection id trunctation", func() { cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls cl.config = &Config{RequestConnectionIDOmission: false} - buf := &bytes.Buffer{} - err := (&wire.Header{ + hdr := &wire.Header{ OmitConnectionID: true, SrcConnectionID: connID, DestConnectionID: connID, PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen1, - }).Write(buf, protocol.PerspectiveServer, versionGQUICFrames) - Expect(err).ToNot(HaveOccurred()) - err = cl.handlePacket(addr, buf.Bytes()) + PacketNumberLen: 1, + } + err := cl.handlePacketImpl(&receivedPacket{ + remoteAddr: addr, + header: hdr, + }) Expect(err).To(MatchError("received packet with truncated connection ID, but didn't request truncation")) }) It("ignores packets with the wrong destination connection ID", func() { cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls - buf := &bytes.Buffer{} cl.version = versionIETFFrames cl.config = &Config{RequestConnectionIDOmission: false} connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} Expect(connID).ToNot(Equal(connID2)) - err := (&wire.Header{ + hdr := &wire.Header{ DestConnectionID: connID2, SrcConnectionID: connID, PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen1, Version: versionIETFFrames, - }).Write(buf, protocol.PerspectiveServer, versionIETFFrames) - Expect(err).ToNot(HaveOccurred()) - err = cl.handlePacket(addr, buf.Bytes()) + } + err := cl.handlePacketImpl(&receivedPacket{ + remoteAddr: addr, + header: hdr, + }) Expect(err).To(MatchError(fmt.Sprintf("received a packet with an unexpected connection ID (0x0807060504030201, expected %s)", connID))) }) @@ -695,13 +694,13 @@ var _ = Describe("Client", func() { It("only accepts one Retry packet", func() { config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} - sess1 := NewMockPacketHandler(mockCtrl) + sess1 := NewMockQuicSession(mockCtrl) sess1.EXPECT().run().Return(handshake.ErrCloseSessionForRetry) // don't EXPECT any call to handlePacket() - sess2 := NewMockPacketHandler(mockCtrl) + sess2 := NewMockQuicSession(mockCtrl) run := make(chan struct{}) sess2.EXPECT().run().Do(func() { <-run }) - sessions := make(chan *MockPacketHandler, 2) + sessions := make(chan *MockQuicSession, 2) sessions <- sess1 sessions <- sess2 newTLSClientSession = func( @@ -716,7 +715,7 @@ var _ = Describe("Client", func() { paramsChan <-chan handshake.TransportParameters, _ protocol.PacketNumber, _ utils.Logger, - ) (packetHandler, error) { + ) (quicSession, error) { return <-sessions, nil } @@ -795,22 +794,35 @@ var _ = Describe("Client", func() { Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset)) }) cl.session = sess - err := cl.handlePacket(addr, wire.WritePublicReset(cl.destConnID, 1, 0)) - Expect(err).ToNot(HaveOccurred()) + cl.handleRead(addr, wire.WritePublicReset(cl.destConnID, 1, 0)) }) It("ignores Public Resets from the wrong remote address", func() { cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any calls spoofedAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678} - err := cl.handlePacket(spoofedAddr, wire.WritePublicReset(cl.destConnID, 1, 0)) + pr := wire.WritePublicReset(cl.destConnID, 1, 0) + r := bytes.NewReader(pr) + hdr, err := wire.ParseHeaderSentByServer(r) + Expect(err).ToNot(HaveOccurred()) + err = cl.handlePacketImpl(&receivedPacket{ + remoteAddr: spoofedAddr, + header: hdr, + data: pr[len(pr)-r.Len():], + }) Expect(err).To(MatchError("Received a spoofed Public Reset")) }) It("ignores unparseable Public Resets", func() { cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any calls pr := wire.WritePublicReset(cl.destConnID, 1, 0) - err := cl.handlePacket(addr, pr[:len(pr)-5]) - Expect(err).To(HaveOccurred()) + r := bytes.NewReader(pr) + hdr, err := wire.ParseHeaderSentByServer(r) + Expect(err).ToNot(HaveOccurred()) + err = cl.handlePacketImpl(&receivedPacket{ + remoteAddr: addr, + header: hdr, + data: pr[len(pr)-r.Len() : len(pr)-5], // cut off the last 5 bytes + }) Expect(err.Error()).To(ContainSubstring("Received a Public Reset. An error occurred parsing the packet")) }) })