diff --git a/mock_unpacker_test.go b/mock_unpacker_test.go index 0966d588..7a36bebf 100644 --- a/mock_unpacker_test.go +++ b/mock_unpacker_test.go @@ -35,7 +35,7 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder { } // Unpack mocks base method -func (m *MockUnpacker) Unpack(arg0 *wire.ExtendedHeader, arg1 []byte) (*unpackedPacket, error) { +func (m *MockUnpacker) Unpack(arg0 *wire.Header, arg1 []byte) (*unpackedPacket, error) { ret := m.ctrl.Call(m, "Unpack", arg0, arg1) ret0, _ := ret[0].(*unpackedPacket) ret1, _ := ret[1].(error) diff --git a/packet_unpacker.go b/packet_unpacker.go index 0f956d57..c8887cdd 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -11,7 +11,8 @@ import ( ) type unpackedPacket struct { - packetNumber protocol.PacketNumber + packetNumber protocol.PacketNumber // the decoded packet number + hdr *wire.ExtendedHeader encryptionLevel protocol.EncryptionLevel frames []wire.Frame } @@ -40,11 +41,30 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker { } } -func (u *packetUnpacker) Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) { +func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) { + r := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(r, u.version) + if err != nil { + return nil, fmt.Errorf("error parsing extended header: %s", err) + } + extHdr.Raw = data[:len(data)-r.Len()] + data = data[len(data)-r.Len():] + + if hdr.IsLongHeader { + if hdr.Length < protocol.ByteCount(extHdr.PacketNumberLen) { + return nil, fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", extHdr.Length, extHdr.PacketNumberLen) + } + if protocol.ByteCount(len(data))+protocol.ByteCount(extHdr.PacketNumberLen) < extHdr.Length { + return nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(extHdr.PacketNumberLen), extHdr.Length) + } + data = data[:int(extHdr.Length)-int(extHdr.PacketNumberLen)] + // TODO(#1312): implement parsing of compound packets + } + pn := protocol.DecodePacketNumber( - hdr.PacketNumberLen, + extHdr.PacketNumberLen, u.largestRcvdPacketNumber, - hdr.PacketNumber, + extHdr.PacketNumber, ) buf := *getPacketBuffer() @@ -53,19 +73,18 @@ func (u *packetUnpacker) Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpacke var decrypted []byte var encryptionLevel protocol.EncryptionLevel - var err error switch hdr.Type { case protocol.PacketTypeInitial: - decrypted, err = u.aead.OpenInitial(buf, data, pn, hdr.Raw) + decrypted, err = u.aead.OpenInitial(buf, data, pn, extHdr.Raw) encryptionLevel = protocol.EncryptionInitial case protocol.PacketTypeHandshake: - decrypted, err = u.aead.OpenHandshake(buf, data, pn, hdr.Raw) + decrypted, err = u.aead.OpenHandshake(buf, data, pn, extHdr.Raw) encryptionLevel = protocol.EncryptionHandshake default: if hdr.IsLongHeader { return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) } - decrypted, err = u.aead.Open1RTT(buf, data, pn, hdr.Raw) + decrypted, err = u.aead.Open1RTT(buf, data, pn, extHdr.Raw) encryptionLevel = protocol.Encryption1RTT } if err != nil { @@ -81,6 +100,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpacke } return &unpackedPacket{ + hdr: extHdr, packetNumber: pn, encryptionLevel: encryptionLevel, frames: fs, diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index d64891e7..ec67a5ce 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -14,82 +14,187 @@ import ( ) var _ = Describe("Packet Unpacker", func() { + const version = protocol.VersionTLS var ( unpacker *packetUnpacker - hdr *wire.ExtendedHeader aead *MockQuicAEAD + connID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} ) + getHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) { + buf := &bytes.Buffer{} + Expect(extHdr.Write(buf, protocol.VersionWhatever)).To(Succeed()) + hdr, err := wire.ParseHeader(bytes.NewReader(buf.Bytes()), connID.Len()) + Expect(err).ToNot(HaveOccurred()) + return hdr, buf.Bytes() + } + BeforeEach(func() { aead = NewMockQuicAEAD(mockCtrl) - hdr = &wire.ExtendedHeader{ - PacketNumber: 10, - PacketNumberLen: 1, - Raw: []byte{0x04, 0x4c, 0x01}, - } - unpacker = newPacketUnpacker(aead, protocol.VersionWhatever).(*packetUnpacker) + unpacker = newPacketUnpacker(aead, version).(*packetUnpacker) }) It("errors if the packet doesn't contain any payload", func() { - data := []byte("foobar") - aead.EXPECT().Open1RTT(gomock.Any(), []byte("foobar"), hdr.PacketNumber, hdr.Raw).Return([]byte{}, nil) + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: connID}, + PacketNumber: 42, + PacketNumberLen: protocol.PacketNumberLen2, + } + hdr, hdrRaw := getHeader(extHdr) + data := append(hdrRaw, []byte("foobar")...) // add some payload + // return an empty (unencrypted) payload + aead.EXPECT().Open1RTT(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{}, nil) _, err := unpacker.Unpack(hdr, data) Expect(err).To(MatchError(qerr.MissingPayload)) }) It("opens Initial packets", func() { - hdr.IsLongHeader = true - hdr.Type = protocol.PacketTypeInitial - aead.EXPECT().OpenInitial(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil) - packet, err := unpacker.Unpack(hdr, nil) + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Length: 3 + 6, // packet number len + payload + DestConnectionID: connID, + Version: version, + }, + PacketNumber: 2, + PacketNumberLen: 3, + } + hdr, hdrRaw := getHeader(extHdr) + aead.EXPECT().OpenInitial(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{0}, nil) + packet, err := unpacker.Unpack(hdr, append(hdrRaw, []byte("foobar")...)) Expect(err).ToNot(HaveOccurred()) Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) }) It("opens Handshake packets", func() { - hdr.IsLongHeader = true - hdr.Type = protocol.PacketTypeHandshake - aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil) - packet, err := unpacker.Unpack(hdr, nil) + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Length: 3 + 6, // packet number len + payload + DestConnectionID: connID, + Version: version, + }, + PacketNumber: 2, + PacketNumberLen: 3, + } + hdr, hdrRaw := getHeader(extHdr) + aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).Return([]byte{0}, nil) + packet, err := unpacker.Unpack(hdr, append(hdrRaw, []byte("foobar")...)) Expect(err).ToNot(HaveOccurred()) Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionHandshake)) }) + It("errors on packets that are smaller than the length in the packet header", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Length: 1000, + DestConnectionID: connID, + Version: version, + }, + PacketNumberLen: protocol.PacketNumberLen2, + } + hdr, hdrRaw := getHeader(extHdr) + data := append(hdrRaw, make([]byte, 500-2 /* for packet number length */)...) + _, err := unpacker.Unpack(hdr, data) + Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) + }) + + It("errors when receiving a packet that has a length smaller than the packet number length", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + DestConnectionID: connID, + Type: protocol.PacketTypeHandshake, + Length: 3, + Version: protocol.VersionTLS, + }, + PacketNumberLen: protocol.PacketNumberLen4, + } + hdr, hdrRaw := getHeader(extHdr) + _, err := unpacker.Unpack(hdr, hdrRaw) + Expect(err).To(MatchError("packet length (3 bytes) shorter than packet number (4 bytes)")) + }) + + It("cuts packets to the right length", func() { + pnLen := protocol.PacketNumberLen2 + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + DestConnectionID: connID, + Type: protocol.PacketTypeHandshake, + Length: 456, + Version: protocol.VersionTLS, + }, + PacketNumberLen: pnLen, + } + payloadLen := 456 - int(pnLen) + hdr, hdrRaw := getHeader(extHdr) + data := append(hdrRaw, make([]byte, payloadLen)...) + aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).DoAndReturn(func(_, payload []byte, _ protocol.PacketNumber, _ []byte) ([]byte, error) { + Expect(payload).To(HaveLen(payloadLen)) + return []byte{0}, nil + }) + _, err := unpacker.Unpack(hdr, data) + Expect(err).ToNot(HaveOccurred()) + }) + It("returns the error when unpacking fails", func() { - hdr.IsLongHeader = true - hdr.Type = protocol.PacketTypeHandshake + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Length: 3, // packet number len + DestConnectionID: connID, + Version: version, + }, + PacketNumber: 2, + PacketNumberLen: 3, + } + hdr, hdrRaw := getHeader(extHdr) aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err")) - _, err := unpacker.Unpack(hdr, nil) + _, err := unpacker.Unpack(hdr, hdrRaw) Expect(err).To(MatchError(qerr.Error(qerr.DecryptionFailure, "test err"))) }) It("decodes the packet number", func() { firstHdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: connID}, PacketNumber: 0x1337, PacketNumberLen: 2, } aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil) - packet, err := unpacker.Unpack(firstHdr, nil) + packet, err := unpacker.Unpack(getHeader(firstHdr)) Expect(err).ToNot(HaveOccurred()) Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337))) // the real packet number is 0x1338, but only the last byte is sent secondHdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: connID}, PacketNumber: 0x38, PacketNumberLen: 1, } // expect the call with the decoded packet number aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil) - packet, err = unpacker.Unpack(secondHdr, nil) + packet, err = unpacker.Unpack(getHeader(secondHdr)) Expect(err).ToNot(HaveOccurred()) Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338))) }) It("unpacks the frames", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: connID}, + PacketNumber: 0x1337, + PacketNumberLen: 2, + } buf := &bytes.Buffer{} (&wire.PingFrame{}).Write(buf, protocol.VersionWhatever) (&wire.DataBlockedFrame{}).Write(buf, protocol.VersionWhatever) - aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return(buf.Bytes(), nil) - packet, err := unpacker.Unpack(hdr, nil) + hdr, hdrRaw := getHeader(extHdr) + aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).Return(buf.Bytes(), nil) + packet, err := unpacker.Unpack(hdr, append(hdrRaw, buf.Bytes()...)) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(Equal([]wire.Frame{&wire.PingFrame{}, &wire.DataBlockedFrame{}})) }) diff --git a/session.go b/session.go index 876709af..3ad449ba 100644 --- a/session.go +++ b/session.go @@ -1,7 +1,6 @@ package quic import ( - "bytes" "context" "crypto/tls" "errors" @@ -22,7 +21,7 @@ import ( ) type unpacker interface { - Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) + Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) } type streamGetter interface { @@ -483,27 +482,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { return nil } - data := p.data - r := bytes.NewReader(data) - hdr, err := p.hdr.ParseExtended(r, s.version) - if err != nil { - return fmt.Errorf("error parsing extended header: %s", err) - } - hdr.Raw = data[:len(data)-r.Len()] - data = data[len(data)-r.Len():] - - if hdr.IsLongHeader { - if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) { - return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen) - } - if protocol.ByteCount(len(data))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length { - return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(hdr.PacketNumberLen), hdr.Length) - } - data = data[:int(hdr.Length)-int(hdr.PacketNumberLen)] - // TODO(#1312): implement parsing of compound packets - } - - packet, err := s.unpacker.Unpack(hdr, data) + packet, err := s.unpacker.Unpack(p.hdr, p.data) // if the decryption failed, this might be a packet sent by an attacker if err != nil { return err @@ -511,13 +490,13 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { if s.logger.Debug() { s.logger.Debugf("<- Reading packet %#x (%d bytes) for connection %s, %s", packet.packetNumber, len(p.data), p.hdr.DestConnectionID, packet.encryptionLevel) - hdr.Log(s.logger) + packet.hdr.Log(s.logger) } // The server can change the source connection ID with the first Handshake packet. - if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) { - s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", hdr.SrcConnectionID) - s.destConnID = hdr.SrcConnectionID + if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) { + s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", p.hdr.SrcConnectionID) + s.destConnID = p.hdr.SrcConnectionID s.packer.ChangeDestConnectionID(s.destConnID) } @@ -536,7 +515,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { // If this is a Retry packet, there's no need to send an ACK. // The session will be closed and recreated as soon as the crypto setup processed the HRR. - if hdr.Type != protocol.PacketTypeRetry { + if p.hdr.Type != protocol.PacketTypeRetry { isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames) if err := s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, p.rcvTime, isRetransmittable); err != nil { return err diff --git a/session_test.go b/session_test.go index fd61154c..9e074cf2 100644 --- a/session_test.go +++ b/session_test.go @@ -533,62 +533,6 @@ var _ = Describe("Session", func() { })).To(Succeed()) }) - It("errors on packets that are smaller than the length in the packet header", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - hdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Length: 1000, - DestConnectionID: connID, - Version: protocol.VersionTLS, - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - data := getData(hdr) - data = append(data, make([]byte, 500-2 /* for packet number length */)...) - Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) - }) - - It("errors when receiving a packet that has a length smaller than the packet number length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - hdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - DestConnectionID: connID, - Type: protocol.PacketTypeHandshake, - Length: 3, - Version: protocol.VersionTLS, - }, - PacketNumberLen: protocol.PacketNumberLen4, - } - data := getData(hdr) - Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(MatchError("packet length (3 bytes) shorter than packet number (4 bytes)")) - }) - - It("cuts packets to the right length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - pnLen := protocol.PacketNumberLen2 - hdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - DestConnectionID: connID, - Type: protocol.PacketTypeHandshake, - Length: 456, - Version: protocol.VersionTLS, - }, - PacketNumberLen: pnLen, - } - payloadLen := 456 - int(pnLen) - data := getData(hdr) - data = append(data, make([]byte, payloadLen)...) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) { - Expect(data).To(HaveLen(payloadLen)) - return &unpackedPacket{}, nil - }) - Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(Succeed()) - }) - Context("updating the remote address", func() { It("doesn't support connection migration", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)