remove the wire.ShortHeader in favor of more return values (#3535)

This commit is contained in:
Marten Seemann 2022-08-30 14:37:36 +03:00 committed by GitHub
parent 5cd5341545
commit 656f3d2d7d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 125 additions and 131 deletions

View file

@ -26,7 +26,7 @@ import (
type unpacker interface { type unpacker interface {
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error)
UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
} }
type streamGetter interface { type streamGetter interface {
@ -856,11 +856,13 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
data := rp.data data := rp.data
p := rp p := rp
for len(data) > 0 { for len(data) > 0 {
var destConnID protocol.ConnectionID
if counter > 0 { if counter > 0 {
p = p.Clone() p = p.Clone()
p.data = data p.data = data
destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen) var err error
destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen)
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError)
@ -920,7 +922,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
if counter > 0 { if counter > 0 {
p.buffer.Split() p.buffer.Split()
} }
processed = s.handleShortHeaderPacket(p) processed = s.handleShortHeaderPacket(p, destConnID)
break break
} }
} }
@ -929,7 +931,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
return processed return processed
} }
func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool { func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID protocol.ConnectionID) bool {
var wasQueued bool var wasQueued bool
defer func() { defer func() {
@ -939,18 +941,18 @@ func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool {
} }
}() }()
hdr, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data) pn, pnLen, keyPhase, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
if err != nil { if err != nil {
wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT) wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT)
return false return false
} }
if s.logger.Debug() { if s.logger.Debug() {
s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", hdr.PacketNumber, p.Size(), hdr.DestConnectionID) s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", pn, p.Size(), destConnID)
hdr.Log(s.logger) wire.LogShortHeader(s.logger, destConnID, pn, pnLen, keyPhase)
} }
if s.receivedPacketHandler.IsPotentiallyDuplicate(hdr.PacketNumber, protocol.Encryption1RTT) { if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) {
s.logger.Debugf("Dropping (potentially) duplicate packet.") s.logger.Debugf("Dropping (potentially) duplicate packet.")
if s.tracer != nil { if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate) s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate)
@ -958,7 +960,22 @@ func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool {
return false return false
} }
if err := s.handleUnpackedShortHeaderPacket(hdr, data, p.ecn, p.rcvTime, p.Size()); err != nil { var log func([]logging.Frame)
if s.tracer != nil {
log = func(frames []logging.Frame) {
s.tracer.ReceivedShortHeaderPacket(
&logging.ShortHeader{
DestConnectionID: destConnID,
PacketNumber: pn,
PacketNumberLen: pnLen,
KeyPhase: keyPhase,
},
p.Size(),
frames,
)
}
}
if err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log); err != nil {
s.closeLocal(err) s.closeLocal(err)
return false return false
} }
@ -1241,22 +1258,23 @@ func (s *connection) handleUnpackedPacket(
return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting) return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting)
} }
func (s *connection) handleUnpackedShortHeaderPacket(hdr *wire.ShortHeader, data []byte, ecn protocol.ECN, rcvTime time.Time, packetSize protocol.ByteCount) error { func (s *connection) handleUnpackedShortHeaderPacket(
destConnID protocol.ConnectionID,
pn protocol.PacketNumber,
data []byte,
ecn protocol.ECN,
rcvTime time.Time,
log func([]logging.Frame),
) error {
s.lastPacketReceivedTime = rcvTime s.lastPacketReceivedTime = rcvTime
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
s.keepAlivePingSent = false s.keepAlivePingSent = false
var log func([]logging.Frame) isAckEliciting, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log)
if s.tracer != nil {
log = func(frames []logging.Frame) {
s.tracer.ReceivedShortHeaderPacket(hdr, packetSize, frames)
}
}
isAckEliciting, err := s.handleFrames(data, hdr.DestConnectionID, protocol.Encryption1RTT, log)
if err != nil { if err != nil {
return err return err
} }
return s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting) return s.receivedPacketHandler.ReceivedPacket(pn, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting)
} }
func (s *connection) handleFrames( func (s *connection) handleFrames(

View file

@ -562,10 +562,10 @@ var _ = Describe("Connection", func() {
} }
Expect(hdr.Write(buf, conn.version)).To(Succeed()) Expect(hdr.Write(buf, conn.version)).To(Succeed())
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (*wire.ShortHeader, []byte, error) { unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version) b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
return &wire.ShortHeader{PacketNumber: 3}, b, nil return 3, protocol.PacketNumberLen2, protocol.KeyPhaseOne, b, nil
}) })
gomock.InOrder( gomock.InOrder(
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()), tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()),
@ -766,7 +766,7 @@ var _ = Describe("Connection", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
packet := getPacket(hdr, nil) packet := getPacket(hdr, nil)
packet.ecn = protocol.ECT1 packet.ecn = protocol.ECT1
unpacker.EXPECT().UnpackShortHeader(rcvTime, gomock.Any()).Return(&wire.ShortHeader{PacketNumber: 0x1337}, b, nil) unpacker.EXPECT().UnpackShortHeader(rcvTime, gomock.Any()).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, b, nil)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
gomock.InOrder( gomock.InOrder(
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT), rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT),
@ -774,7 +774,7 @@ var _ = Describe("Connection", func() {
) )
conn.receivedPacketHandler = rph conn.receivedPacketHandler = rph
packet.rcvTime = rcvTime packet.rcvTime = rcvTime
tracer.EXPECT().ReceivedShortHeaderPacket(&wire.ShortHeader{PacketNumber: 0x1337}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}}) tracer.EXPECT().ReceivedShortHeaderPacket(&logging.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 2, KeyPhase: protocol.KeyPhaseZero}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}})
Expect(conn.handlePacketImpl(packet)).To(BeTrue()) Expect(conn.handlePacketImpl(packet)).To(BeTrue())
}) })
@ -785,7 +785,7 @@ var _ = Describe("Connection", func() {
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
} }
packet := getPacket(hdr, nil) packet := getPacket(hdr, nil)
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(&wire.ShortHeader{PacketNumber: 0x1337}, []byte("foobar"), nil) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseOne, []byte("foobar"), nil)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true) rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true)
conn.receivedPacketHandler = rph conn.receivedPacketHandler = rph
@ -829,11 +829,11 @@ var _ = Describe("Connection", func() {
It("processes multiple received packets before sending one", func() { It("processes multiple received packets before sending one", func() {
conn.creationTime = time.Now() conn.creationTime = time.Now()
var pn protocol.PacketNumber var pn protocol.PacketNumber
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
pn++ pn++
return &wire.ShortHeader{PacketNumber: pn}, []byte{0} /* PADDING frame */, nil return pn, protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil
}).Times(3) }).Times(3)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
}).Times(3) }).Times(3)
packer.EXPECT().PackCoalescedPacket() // only expect a single call packer.EXPECT().PackCoalescedPacket() // only expect a single call
@ -868,11 +868,11 @@ var _ = Describe("Connection", func() {
conn.handshakeComplete = false conn.handshakeComplete = false
conn.creationTime = time.Now() conn.creationTime = time.Now()
var pn protocol.PacketNumber var pn protocol.PacketNumber
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
pn++ pn++
return &wire.ShortHeader{PacketNumber: pn}, []byte{0} /* PADDING frame */, nil return pn, protocol.PacketNumberLen4, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil
}).Times(3) }).Times(3)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
}).Times(3) }).Times(3)
packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call
@ -904,7 +904,7 @@ var _ = Describe("Connection", func() {
}) })
It("closes the connection when unpacking fails because the reserved bits were incorrect", func() { It("closes the connection when unpacking fails because the reserved bits were incorrect", func() {
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, wire.ErrInvalidReservedBits) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, wire.ErrInvalidReservedBits)
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
@ -932,7 +932,7 @@ var _ = Describe("Connection", func() {
It("ignores packets when unpacking the header fails", func() { It("ignores packets when unpacking the header fails", func() {
testErr := &headerParseError{errors.New("test error")} testErr := &headerParseError{errors.New("test error")}
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, testErr) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, testErr)
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
runErr := make(chan error) runErr := make(chan error)
@ -958,7 +958,7 @@ var _ = Describe("Connection", func() {
}) })
It("closes the connection when unpacking fails because of an error other than a decryption error", func() { It("closes the connection when unpacking fails because of an error other than a decryption error", func() {
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError})
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
@ -1050,8 +1050,7 @@ var _ = Describe("Connection", func() {
Context("updating the remote address", func() { Context("updating the remote address", func() {
It("doesn't support connection migration", func() { It("doesn't support connection migration", func() {
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(&wire.ShortHeader{}, unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* one PADDING frame */, nil)
[]byte{0} /* one PADDING frame */, nil)
packet := getPacket(&wire.ExtendedHeader{ packet := getPacket(&wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID}, Header: wire.Header{DestConnectionID: srcConnID},
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,

View file

@ -208,7 +208,7 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gom
} }
// ReceivedShortHeaderPacket mocks base method. // ReceivedShortHeaderPacket mocks base method.
func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *wire.ShortHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) { func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2) m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2)
} }

View file

@ -9,28 +9,20 @@ import (
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
type ShortHeader struct { func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.PacketNumber, _ protocol.PacketNumberLen, _ protocol.KeyPhaseBit, _ error) {
DestConnectionID protocol.ConnectionID
PacketNumber protocol.PacketNumber
PacketNumberLen protocol.PacketNumberLen
KeyPhase protocol.KeyPhaseBit
}
func ParseShortHeader(data []byte, connIDLen int) (*ShortHeader, error) {
if len(data) == 0 { if len(data) == 0 {
return nil, io.EOF return 0, 0, 0, 0, io.EOF
} }
if data[0]&0x80 > 0 { if data[0]&0x80 > 0 {
return nil, errors.New("not a short header packet") return 0, 0, 0, 0, errors.New("not a short header packet")
} }
if data[0]&0x40 == 0 { if data[0]&0x40 == 0 {
return nil, errors.New("not a QUIC packet") return 0, 0, 0, 0, errors.New("not a QUIC packet")
} }
pnLen := protocol.PacketNumberLen(data[0]&0b11) + 1 pnLen := protocol.PacketNumberLen(data[0]&0b11) + 1
if len(data) < 1+int(pnLen)+connIDLen { if len(data) < 1+int(pnLen)+connIDLen {
return nil, io.EOF return 0, 0, 0, 0, io.EOF
} }
destConnID := protocol.ParseConnectionID(data[1 : 1+connIDLen])
pos := 1 + connIDLen pos := 1 + connIDLen
var pn protocol.PacketNumber var pn protocol.PacketNumber
@ -44,7 +36,7 @@ func ParseShortHeader(data []byte, connIDLen int) (*ShortHeader, error) {
case protocol.PacketNumberLen4: case protocol.PacketNumberLen4:
pn = protocol.PacketNumber(utils.BigEndian.Uint32(data[pos : pos+4])) pn = protocol.PacketNumber(utils.BigEndian.Uint32(data[pos : pos+4]))
default: default:
return nil, fmt.Errorf("invalid packet number length: %d", pnLen) return 0, 0, 0, 0, fmt.Errorf("invalid packet number length: %d", pnLen)
} }
kp := protocol.KeyPhaseZero kp := protocol.KeyPhaseZero
if data[0]&0b100 > 0 { if data[0]&0b100 > 0 {
@ -55,19 +47,9 @@ func ParseShortHeader(data []byte, connIDLen int) (*ShortHeader, error) {
if data[0]&0x18 != 0 { if data[0]&0x18 != 0 {
err = ErrInvalidReservedBits err = ErrInvalidReservedBits
} }
return &ShortHeader{ return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err
DestConnectionID: destConnID,
PacketNumber: pn,
PacketNumberLen: pnLen,
KeyPhase: kp,
}, err
} }
func (h *ShortHeader) Len() protocol.ByteCount { func LogShortHeader(logger utils.Logger, dest protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) {
return 1 + protocol.ByteCount(h.DestConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", dest, pn, pnLen, kp)
}
// Log logs the Header
func (h *ShortHeader) Log(logger utils.Logger) {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
} }

View file

@ -21,12 +21,12 @@ var _ = Describe("Short Header", func() {
0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef,
0x13, 0x37, 0x99, 0x13, 0x37, 0x99,
} }
hdr, err := ParseShortHeader(data, 4) l, pn, pnLen, kp, err := ParseShortHeader(data, 4)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}))) Expect(l).To(Equal(len(data)))
Expect(hdr.KeyPhase).To(Equal(protocol.KeyPhaseOne)) Expect(kp).To(Equal(protocol.KeyPhaseOne))
Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x133799))) Expect(pn).To(Equal(protocol.PacketNumber(0x133799)))
Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen3)) Expect(pnLen).To(Equal(protocol.PacketNumberLen3))
}) })
It("errors when the QUIC bit is not set", func() { It("errors when the QUIC bit is not set", func() {
@ -35,7 +35,7 @@ var _ = Describe("Short Header", func() {
0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef,
0x13, 0x37, 0x13, 0x37,
} }
_, err := ParseShortHeader(data, 4) _, _, _, _, err := ParseShortHeader(data, 4)
Expect(err).To(MatchError("not a QUIC packet")) Expect(err).To(MatchError("not a QUIC packet"))
}) })
@ -45,14 +45,13 @@ var _ = Describe("Short Header", func() {
0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef,
0x13, 0x37, 0x13, 0x37,
} }
hdr, err := ParseShortHeader(data, 4) _, pn, _, _, err := ParseShortHeader(data, 4)
Expect(err).To(MatchError(ErrInvalidReservedBits)) Expect(err).To(MatchError(ErrInvalidReservedBits))
Expect(hdr).ToNot(BeNil()) Expect(pn).To(Equal(protocol.PacketNumber(0x1337)))
Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337)))
}) })
It("errors when passed a long header packet", func() { It("errors when passed a long header packet", func() {
_, err := ParseShortHeader([]byte{0x80}, 4) _, _, _, _, err := ParseShortHeader([]byte{0x80}, 4)
Expect(err).To(MatchError("not a short header packet")) Expect(err).To(MatchError("not a short header packet"))
}) })
@ -62,10 +61,10 @@ var _ = Describe("Short Header", func() {
0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef,
0x13, 0x37, 0x99, 0x13, 0x37, 0x99,
} }
_, err := ParseShortHeader(data, 4) _, _, _, _, err := ParseShortHeader(data, 4)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for i := range data { for i := range data {
_, err := ParseShortHeader(data[:i], 4) _, _, _, _, err := ParseShortHeader(data[:i], 4)
Expect(err).To(MatchError(io.EOF)) Expect(err).To(MatchError(io.EOF))
} }
}) })
@ -89,22 +88,9 @@ var _ = Describe("Short Header", func() {
}) })
It("logs Short Headers containing a connection ID", func() { It("logs Short Headers containing a connection ID", func() {
(&ShortHeader{ connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}), LogShortHeader(logger, connID, 1337, protocol.PacketNumberLen4, protocol.KeyPhaseOne)
KeyPhase: protocol.KeyPhaseOne,
PacketNumber: 1337,
PacketNumberLen: 4,
}).Log(logger)
Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}")) Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}"))
}) })
}) })
It("determines the length", func() {
Expect((&ShortHeader{
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xaf}),
PacketNumber: 0x1337,
PacketNumberLen: protocol.PacketNumberLen3,
KeyPhase: protocol.KeyPhaseOne,
}).Len()).To(Equal(protocol.ByteCount(1 + 2 + 3)))
})
}) })

View file

@ -46,8 +46,6 @@ type (
Header = wire.Header Header = wire.Header
// The ExtendedHeader is the QUIC Long Header packet header, after removing header protection. // The ExtendedHeader is the QUIC Long Header packet header, after removing header protection.
ExtendedHeader = wire.ExtendedHeader ExtendedHeader = wire.ExtendedHeader
// The ShortHeader is the QUIC Short Header packet header, after removing header protection.
ShortHeader = wire.ShortHeader
// The TransportParameters are QUIC transport parameters. // The TransportParameters are QUIC transport parameters.
TransportParameters = wire.TransportParameters TransportParameters = wire.TransportParameters
// The PreferredAddress is the preferred address sent in the transport parameters. // The PreferredAddress is the preferred address sent in the transport parameters.
@ -94,6 +92,14 @@ const (
StreamTypeBidi = protocol.StreamTypeBidi StreamTypeBidi = protocol.StreamTypeBidi
) )
// The ShortHeader is the QUIC Short Header packet header, after removing header protection.
type ShortHeader struct {
DestConnectionID ConnectionID
PacketNumber PacketNumber
PacketNumberLen protocol.PacketNumberLen
KeyPhase KeyPhaseBit
}
// A Tracer traces events. // A Tracer traces events.
type Tracer interface { type Tracer interface {
// TracerForConnection requests a new tracer for a connection. // TracerForConnection requests a new tracer for a connection.

View file

@ -207,7 +207,7 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gom
} }
// ReceivedShortHeaderPacket mocks base method. // ReceivedShortHeaderPacket mocks base method.
func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *wire.ShortHeader, arg1 protocol.ByteCount, arg2 []Frame) { func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 []Frame) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2) m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2)
} }

View file

@ -9,6 +9,7 @@ import (
time "time" time "time"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire" wire "github.com/lucas-clemente/quic-go/internal/wire"
) )
@ -51,13 +52,15 @@ func (mr *MockUnpackerMockRecorder) UnpackLongHeader(hdr, rcvTime, data interfac
} }
// UnpackShortHeader mocks base method. // UnpackShortHeader mocks base method.
func (m *MockUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { func (m *MockUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnpackShortHeader", rcvTime, data) ret := m.ctrl.Call(m, "UnpackShortHeader", rcvTime, data)
ret0, _ := ret[0].(*wire.ShortHeader) ret0, _ := ret[0].(protocol.PacketNumber)
ret1, _ := ret[1].([]byte) ret1, _ := ret[1].(protocol.PacketNumberLen)
ret2, _ := ret[2].(error) ret2, _ := ret[2].(protocol.KeyPhaseBit)
return ret0, ret1, ret2 ret3, _ := ret[3].([]byte)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
} }
// UnpackShortHeader indicates an expected call of UnpackShortHeader. // UnpackShortHeader indicates an expected call of UnpackShortHeader.

View file

@ -109,22 +109,22 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
}, nil }, nil
} }
func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
opener, err := u.cs.Get1RTTOpener() opener, err := u.cs.Get1RTTOpener()
if err != nil { if err != nil {
return nil, nil, err return 0, 0, 0, nil, err
} }
hdr, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data) pn, pnLen, kp, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data)
if err != nil { if err != nil {
return nil, nil, err return 0, 0, 0, nil, err
} }
if len(decrypted) == 0 { if len(decrypted) == 0 {
return nil, nil, &qerr.TransportError{ return 0, 0, 0, nil, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation, ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "empty packet", ErrorMessage: "empty packet",
} }
} }
return hdr, decrypted, nil return pn, pnLen, kp, decrypted, nil
} }
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
@ -147,27 +147,26 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene
return extHdr, decrypted, nil return extHdr, decrypted, nil
} }
func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
hdr, parseErr := u.unpackShortHeader(opener, data) l, pn, pnLen, kp, parseErr := u.unpackShortHeader(opener, data)
// If the reserved bits are set incorrectly, we still need to continue unpacking. // If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker // This avoids a timing side-channel, which otherwise might allow an attacker
// to gain information about the header encryption. // to gain information about the header encryption.
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, nil, &headerParseError{parseErr} return 0, 0, 0, nil, &headerParseError{parseErr}
} }
hdr.PacketNumber = opener.DecodePacketNumber(hdr.PacketNumber, hdr.PacketNumberLen) pn = opener.DecodePacketNumber(pn, pnLen)
l := hdr.Len() decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, pn, kp, data[:l])
decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, hdr.PacketNumber, hdr.KeyPhase, data[:l])
if err != nil { if err != nil {
return nil, nil, err return 0, 0, 0, nil, err
} }
return hdr, decrypted, parseErr return pn, pnLen, kp, decrypted, parseErr
} }
func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (*wire.ShortHeader, error) { func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int, protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, error) {
hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen
if len(data) < hdrLen+4+16 { if len(data) < hdrLen+4+16 {
return nil, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen) return 0, 0, 0, 0, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
} }
origPNBytes := make([]byte, 4) origPNBytes := make([]byte, 4)
copy(origPNBytes, data[hdrLen:hdrLen+4]) copy(origPNBytes, data[hdrLen:hdrLen+4])
@ -178,15 +177,15 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (*wi
data[hdrLen:hdrLen+4], data[hdrLen:hdrLen+4],
) )
// 3. parse the header (and learn the actual length of the packet number) // 3. parse the header (and learn the actual length of the packet number)
hdr, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen) l, pn, pnLen, kp, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen)
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, parseErr return l, pn, pnLen, kp, parseErr
} }
// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
if hdr.PacketNumberLen != protocol.PacketNumberLen4 { if pnLen != protocol.PacketNumberLen4 {
copy(data[hdrLen+int(hdr.PacketNumberLen):hdrLen+4], origPNBytes[int(hdr.PacketNumberLen):]) copy(data[hdrLen+int(pnLen):hdrLen+4], origPNBytes[int(pnLen):])
} }
return hdr, parseErr return l, pn, pnLen, kp, parseErr
} }
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.

View file

@ -75,7 +75,7 @@ var _ = Describe("Packet Unpacker", func() {
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
_, _, err := unpacker.UnpackShortHeader(time.Now(), data) _, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), data)
Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
Expect(err).To(MatchError("packet too small, expected at least 20 bytes after the header, got 19")) Expect(err).To(MatchError("packet too small, expected at least 20 bytes after the header, got 19"))
}) })
@ -148,10 +148,11 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)), opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil), opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil),
) )
hdr, data, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...)) pn, pnLen, kp, data, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(321))) Expect(pn).To(Equal(protocol.PacketNumber(321)))
Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) Expect(pnLen).To(Equal(protocol.PacketNumberLen4))
Expect(kp).To(Equal(protocol.KeyPhaseOne))
Expect(data).To(Equal([]byte("decrypted"))) Expect(data).To(Equal([]byte("decrypted")))
}) })
@ -163,7 +164,7 @@ var _ = Describe("Packet Unpacker", func() {
} }
_, hdrRaw := getHeader(extHdr) _, hdrRaw := getHeader(extHdr)
cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable) cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable)
_, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) _, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable)) Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable))
}) })
@ -208,7 +209,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)), opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte(""), nil), opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte(""), nil),
) )
_, _, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...)) _, _, _, _, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...))
Expect(err).To(MatchError(&qerr.TransportError{ Expect(err).To(MatchError(&qerr.TransportError{
ErrorCode: qerr.ProtocolViolation, ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "empty packet", ErrorMessage: "empty packet",
@ -273,7 +274,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) _, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
}) })
@ -312,7 +313,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
_, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) _, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrDecryptionFailed)) Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
}) })

View file

@ -319,11 +319,12 @@ func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, p
fs[i] = frame{Frame: f} fs[i] = frame{Frame: f}
} }
header := *transformShortHeader(hdr) header := *transformShortHeader(hdr)
hdrLen := 1 + hdr.DestConnectionID.Len() + int(hdr.PacketNumberLen)
t.mutex.Lock() t.mutex.Lock()
t.recordEvent(time.Now(), &eventPacketReceived{ t.recordEvent(time.Now(), &eventPacketReceived{
Header: header, Header: header,
Length: packetSize, Length: packetSize,
PayloadLength: packetSize - hdr.Len(), PayloadLength: packetSize - protocol.ByteCount(hdrLen),
Frames: fs, Frames: fs,
}) })
t.mutex.Unlock() t.mutex.Unlock()

View file

@ -11,8 +11,6 @@ import (
"os" "os"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/qerr"
@ -529,9 +527,10 @@ var _ = Describe("Tracing", func() {
}) })
It("records a received Short Header packet", func() { It("records a received Short Header packet", func() {
shdr := &wire.ShortHeader{ shdr := &logging.ShortHeader{
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
PacketNumber: 1337, PacketNumber: 1337,
PacketNumberLen: protocol.PacketNumberLen3,
KeyPhase: protocol.KeyPhaseZero, KeyPhase: protocol.KeyPhaseZero,
} }
tracer.ReceivedShortHeaderPacket( tracer.ReceivedShortHeaderPacket(
@ -549,7 +548,7 @@ var _ = Describe("Tracing", func() {
Expect(ev).To(HaveKey("raw")) Expect(ev).To(HaveKey("raw"))
raw := ev["raw"].(map[string]interface{}) raw := ev["raw"].(map[string]interface{})
Expect(raw).To(HaveKeyWithValue("length", float64(789))) Expect(raw).To(HaveKeyWithValue("length", float64(789)))
Expect(raw).To(HaveKeyWithValue("payload_length", float64(789-shdr.Len()))) Expect(raw).To(HaveKeyWithValue("payload_length", float64(789-(1+8+3))))
Expect(ev).To(HaveKey("header")) Expect(ev).To(HaveKey("header"))
hdr := ev["header"].(map[string]interface{}) hdr := ev["header"].(map[string]interface{})
Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT")) Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT"))