introduce a separate code path for unpacking short header packets

This commit is contained in:
Marten Seemann 2022-08-27 22:30:09 +03:00
parent ed15a94703
commit 4f3d3b36ac
7 changed files with 390 additions and 245 deletions

View file

@ -25,7 +25,8 @@ import (
) )
type unpacker interface { type unpacker interface {
Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error)
UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error)
} }
type streamGetter interface { type streamGetter interface {
@ -362,7 +363,7 @@ var newConnection = func(
s.perspective, s.perspective,
s.version, s.version,
) )
s.unpacker = newPacketUnpacker(cs, s.version) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen, s.version)
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
return s return s
} }
@ -474,7 +475,7 @@ var newClientConnection = func(
s.clientHelloWritten = clientHelloWritten s.clientHelloWritten = clientHelloWritten
s.cryptoStreamHandler = cs s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
s.unpacker = newPacketUnpacker(cs, s.version) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen, s.version)
s.packer = newPacketPacker( s.packer = newPacketPacker(
srcConnID, srcConnID,
s.connIDManager.Get, s.connIDManager.Get,
@ -858,8 +859,25 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
if counter > 0 { if counter > 0 {
p = p.Clone() p = p.Clone()
p.data = data p.data = data
destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen)
if err != nil {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err)
break
}
if destConnID != lastConnID {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID)
}
s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID)
break
}
} }
if wire.IsLongHeaderPacket(p.data[0]) {
hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen) hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen)
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil {
@ -872,8 +890,9 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
s.logger.Debugf("error parsing packet: %s", err) s.logger.Debugf("error parsing packet: %s", err)
break break
} }
lastConnID = hdr.DestConnectionID
if hdr.IsLongHeader && hdr.Version != s.version { if hdr.Version != s.version {
if s.tracer != nil { if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion)
} }
@ -881,15 +900,6 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
break break
} }
if counter > 0 && hdr.DestConnectionID != lastConnID {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID)
}
s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID)
break
}
lastConnID = hdr.DestConnectionID
if counter > 0 { if counter > 0 {
p.buffer.Split() p.buffer.Split()
} }
@ -899,17 +909,63 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
if s.logger.Debug() && (counter > 1 || len(rest) > 0) { if s.logger.Debug() && (counter > 1 || len(rest) > 0) {
s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest)) s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest))
} }
p.data = packetData p.data = packetData
if wasProcessed := s.handleSinglePacket(p, hdr); wasProcessed {
if wasProcessed := s.handleLongHeaderPacket(p, hdr); wasProcessed {
processed = true processed = true
} }
data = rest data = rest
} else {
if counter > 0 {
p.buffer.Split()
} }
processed = s.handleShortHeaderPacket(p)
break
}
}
p.buffer.MaybeRelease() p.buffer.MaybeRelease()
return processed return processed
} }
func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool {
var wasQueued bool
defer func() {
// Put back the packet buffer if the packet wasn't queued for later decryption.
if !wasQueued {
p.buffer.Decrement()
}
}()
hdr, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
if err != nil {
wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT)
return false
}
if s.logger.Debug() {
s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", hdr.PacketNumber, p.Size(), hdr.DestConnectionID)
hdr.Log(s.logger)
}
if s.receivedPacketHandler.IsPotentiallyDuplicate(hdr.PacketNumber, protocol.Encryption1RTT) {
s.logger.Debugf("Dropping (potentially) duplicate packet.")
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate)
}
return false
}
if err := s.handleUnpackedShortHeaderPacket(hdr, data, p.ecn, p.rcvTime, p.Size()); err != nil {
s.closeLocal(err)
return false
}
return true
}
func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ {
var wasQueued bool var wasQueued bool
defer func() { defer func() {
@ -925,7 +981,7 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo
// The server can change the source connection ID with the first Handshake packet. // The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection have to be ignored. // After this, all packets with a different source connection have to be ignored.
if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID {
if s.tracer != nil { if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID)
} }
@ -940,44 +996,9 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo
return false return false
} }
packet, err := s.unpacker.Unpack(hdr, p.rcvTime, p.data) packet, err := s.unpacker.UnpackLongHeader(hdr, p.rcvTime, p.data)
if err != nil { if err != nil {
switch err { wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr))
case handshake.ErrKeysDropped:
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropKeyUnavailable)
}
s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", hdr.PacketType(), p.Size())
case handshake.ErrKeysNotYetAvailable:
// Sealer for this encryption level not yet available.
// Try again later.
wasQueued = true
s.tryQueueingUndecryptablePacket(p, logging.PacketTypeFromHeader(hdr))
case wire.ErrInvalidReservedBits:
s.closeLocal(&qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: err.Error(),
})
case handshake.ErrDecryptionFailed:
// This might be a packet injected by an attacker. Drop it.
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropPayloadDecryptError)
}
s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", hdr.PacketType(), p.Size(), err)
default:
var headerErr *headerParseError
if errors.As(err, &headerErr) {
// This might be a packet injected by an attacker. Drop it.
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", hdr.PacketType(), p.Size(), err)
} else {
// This is an error returned by the AEAD (other than ErrDecryptionFailed).
// For example, a PROTOCOL_VIOLATION due to key updates.
s.closeLocal(err)
}
}
return false return false
} }
@ -1001,6 +1022,46 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo
return true return true
} }
func (s *connection) handleUnpackError(err error, p *receivedPacket, pt logging.PacketType) (wasQueued bool) {
switch err {
case handshake.ErrKeysDropped:
if s.tracer != nil {
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable)
}
s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size())
case handshake.ErrKeysNotYetAvailable:
// Sealer for this encryption level not yet available.
// Try again later.
s.tryQueueingUndecryptablePacket(p, pt)
return true
case wire.ErrInvalidReservedBits:
s.closeLocal(&qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: err.Error(),
})
case handshake.ErrDecryptionFailed:
// This might be a packet injected by an attacker. Drop it.
if s.tracer != nil {
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError)
}
s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err)
default:
var headerErr *headerParseError
if errors.As(err, &headerErr) {
// This might be a packet injected by an attacker. Drop it.
if s.tracer != nil {
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err)
} else {
// This is an error returned by the AEAD (other than ErrDecryptionFailed).
// For example, a PROTOCOL_VIOLATION due to key updates.
s.closeLocal(err)
}
}
return false
}
func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ { func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ {
if s.perspective == protocol.PerspectiveServer { if s.perspective == protocol.PerspectiveServer {
if s.tracer != nil { if s.tracer != nil {
@ -1167,15 +1228,51 @@ func (s *connection) handleUnpackedPacket(
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
s.keepAlivePingSent = false s.keepAlivePingSent = false
var log func([]logging.Frame)
if s.tracer != nil {
log = func(frames []logging.Frame) {
s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, frames)
}
}
isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log)
if err != nil {
return err
}
return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting)
}
func (s *connection) handleUnpackedShortHeaderPacket(hdr *wire.ShortHeader, data []byte, ecn protocol.ECN, rcvTime time.Time, packetSize protocol.ByteCount) error {
s.lastPacketReceivedTime = rcvTime
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
s.keepAlivePingSent = false
var log func([]logging.Frame)
if s.tracer != nil {
log = func(frames []logging.Frame) {
s.tracer.ReceivedShortHeaderPacket(hdr, packetSize, frames)
}
}
isAckEliciting, err := s.handleFrames(data, hdr.DestConnectionID, protocol.Encryption1RTT, log)
if err != nil {
return err
}
return s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting)
}
func (s *connection) handleFrames(
data []byte,
destConnID protocol.ConnectionID,
encLevel protocol.EncryptionLevel,
log func([]logging.Frame),
) (isAckEliciting bool, _ error) {
// Only used for tracing. // Only used for tracing.
// If we're not tracing, this slice will always remain empty. // If we're not tracing, this slice will always remain empty.
var frames []wire.Frame var frames []wire.Frame
r := bytes.NewReader(packet.data) r := bytes.NewReader(data)
var isAckEliciting bool
for { for {
frame, err := s.frameParser.ParseNext(r, packet.encryptionLevel) frame, err := s.frameParser.ParseNext(r, encLevel)
if err != nil { if err != nil {
return err return false, err
} }
if frame == nil { if frame == nil {
break break
@ -1185,38 +1282,28 @@ func (s *connection) handleUnpackedPacket(
} }
// Only process frames now if we're not logging. // Only process frames now if we're not logging.
// If we're logging, we need to make sure that the packet_received event is logged first. // If we're logging, we need to make sure that the packet_received event is logged first.
if s.tracer == nil { if log == nil {
if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { if err := s.handleFrame(frame, encLevel, destConnID); err != nil {
return err return false, err
} }
} else { } else {
frames = append(frames, frame) frames = append(frames, frame)
} }
} }
if s.tracer != nil { if log != nil {
fs := make([]logging.Frame, len(frames)) fs := make([]logging.Frame, len(frames))
for i, frame := range frames { for i, frame := range frames {
fs[i] = logutils.ConvertFrame(frame) fs[i] = logutils.ConvertFrame(frame)
} }
if packet.hdr.IsLongHeader { log(fs)
s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, fs)
} else {
s.tracer.ReceivedShortHeaderPacket(&wire.ShortHeader{
DestConnectionID: packet.hdr.DestConnectionID,
PacketNumber: packet.hdr.PacketNumber,
PacketNumberLen: packet.hdr.PacketNumberLen,
KeyPhase: packet.hdr.KeyPhase,
}, packetSize, fs)
}
for _, frame := range frames { for _, frame := range frames {
if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { if err := s.handleFrame(frame, encLevel, destConnID); err != nil {
return err return false, err
} }
} }
} }
return
return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting)
} }
func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error { func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error {

View file

@ -561,17 +561,13 @@ var _ = Describe("Connection", func() {
PacketNumberLen: protocol.PacketNumberLen2, PacketNumberLen: protocol.PacketNumberLen2,
} }
Expect(hdr.Write(buf, conn.version)).To(Succeed()) Expect(hdr.Write(buf, conn.version)).To(Succeed())
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) {
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (*wire.ShortHeader, []byte, error) {
b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version) b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
return &unpackedPacket{ return &wire.ShortHeader{PacketNumber: 3}, b, nil
hdr: hdr,
data: b,
encryptionLevel: protocol.Encryption1RTT,
}, nil
}) })
gomock.InOrder( gomock.InOrder(
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()), tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()),
tracer.EXPECT().ClosedConnection(gomock.Any()), tracer.EXPECT().ClosedConnection(gomock.Any()),
tracer.EXPECT().Close(), tracer.EXPECT().Close(),
@ -727,7 +723,13 @@ var _ = Describe("Connection", func() {
It("informs the ReceivedPacketHandler about non-ack-eliciting packets", func() { It("informs the ReceivedPacketHandler about non-ack-eliciting packets", func() {
hdr := &wire.ExtendedHeader{ hdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID}, Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: srcConnID,
Version: protocol.Version1,
Length: 1,
},
PacketNumber: 0x37, PacketNumber: 0x37,
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
} }
@ -736,7 +738,7 @@ var _ = Describe("Connection", func() {
packet := getPacket(hdr, nil) packet := getPacket(hdr, nil)
packet.ecn = protocol.ECNCE packet.ecn = protocol.ECNCE
rcvTime := time.Now().Add(-10 * time.Second) rcvTime := time.Now().Add(-10 * time.Second)
unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().UnpackLongHeader(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial, encryptionLevel: protocol.EncryptionInitial,
hdr: &unpackedHdr, hdr: &unpackedHdr,
data: []byte{0}, // one PADDING frame data: []byte{0}, // one PADDING frame
@ -749,7 +751,7 @@ var _ = Describe("Connection", func() {
conn.receivedPacketHandler = rph conn.receivedPacketHandler = rph
packet.rcvTime = rcvTime packet.rcvTime = rcvTime
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedShortHeaderPacket(&wire.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 1, DestConnectionID: srcConnID}, protocol.ByteCount(len(packet.data)), []logging.Frame{}) tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), []logging.Frame{})
Expect(conn.handlePacketImpl(packet)).To(BeTrue()) Expect(conn.handlePacketImpl(packet)).To(BeTrue())
}) })
@ -759,18 +761,12 @@ var _ = Describe("Connection", func() {
PacketNumber: 0x37, PacketNumber: 0x37,
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
} }
unpackedHdr := *hdr
unpackedHdr.PacketNumber = 0x1337
rcvTime := time.Now().Add(-10 * time.Second) rcvTime := time.Now().Add(-10 * time.Second)
b, err := (&wire.PingFrame{}).Append(nil, conn.version) b, err := (&wire.PingFrame{}).Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
packet := getPacket(hdr, nil) packet := getPacket(hdr, nil)
packet.ecn = protocol.ECT1 packet.ecn = protocol.ECT1
unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().UnpackShortHeader(rcvTime, gomock.Any()).Return(&wire.ShortHeader{PacketNumber: 0x1337}, b, nil)
encryptionLevel: protocol.Encryption1RTT,
hdr: &unpackedHdr,
data: b,
}, nil)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
gomock.InOrder( gomock.InOrder(
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT), rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT),
@ -778,8 +774,7 @@ var _ = Describe("Connection", func() {
) )
conn.receivedPacketHandler = rph conn.receivedPacketHandler = rph
packet.rcvTime = rcvTime packet.rcvTime = rcvTime
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedShortHeaderPacket(&wire.ShortHeader{PacketNumber: 0x1337}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}})
tracer.EXPECT().ReceivedShortHeaderPacket(&wire.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 1, DestConnectionID: srcConnID}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}})
Expect(conn.handlePacketImpl(packet)).To(BeTrue()) Expect(conn.handlePacketImpl(packet)).To(BeTrue())
}) })
@ -790,15 +785,7 @@ var _ = Describe("Connection", func() {
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
} }
packet := getPacket(hdr, nil) packet := getPacket(hdr, nil)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(&wire.ShortHeader{PacketNumber: 0x1337}, []byte("foobar"), nil)
h := *hdr
h.PacketNumber = 0x1337
return &unpackedPacket{
encryptionLevel: protocol.Encryption1RTT,
hdr: &h,
data: []byte("foobar"),
}, nil
})
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true) rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true)
conn.receivedPacketHandler = rph conn.receivedPacketHandler = rph
@ -807,7 +794,7 @@ var _ = Describe("Connection", func() {
}) })
It("drops a packet when unpacking fails", func() { It("drops a packet when unpacking fails", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
@ -842,18 +829,10 @@ var _ = Describe("Connection", func() {
It("processes multiple received packets before sending one", func() { It("processes multiple received packets before sending one", func() {
conn.creationTime = time.Now() conn.creationTime = time.Now()
var pn protocol.PacketNumber var pn protocol.PacketNumber
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) {
pn++ pn++
return &unpackedPacket{ return &wire.ShortHeader{PacketNumber: pn}, []byte{0} /* PADDING frame */, nil
data: []byte{0}, // PADDING frame
encryptionLevel: protocol.Encryption1RTT,
hdr: &wire.ExtendedHeader{
Header: *hdr,
PacketNumber: pn,
},
}, nil
}).Times(3) }).Times(3)
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
}).Times(3) }).Times(3)
packer.EXPECT().PackCoalescedPacket() // only expect a single call packer.EXPECT().PackCoalescedPacket() // only expect a single call
@ -889,18 +868,10 @@ var _ = Describe("Connection", func() {
conn.handshakeComplete = false conn.handshakeComplete = false
conn.creationTime = time.Now() conn.creationTime = time.Now()
var pn protocol.PacketNumber var pn protocol.PacketNumber
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) {
pn++ pn++
return &unpackedPacket{ return &wire.ShortHeader{PacketNumber: pn}, []byte{0} /* PADDING frame */, nil
data: []byte{0}, // PADDING frame
encryptionLevel: protocol.Encryption1RTT,
hdr: &wire.ExtendedHeader{
Header: *hdr,
PacketNumber: pn,
},
}, nil
}).Times(3) }).Times(3)
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
}).Times(3) }).Times(3)
packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call
@ -933,7 +904,7 @@ var _ = Describe("Connection", func() {
}) })
It("closes the connection when unpacking fails because the reserved bits were incorrect", func() { It("closes the connection when unpacking fails because the reserved bits were incorrect", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, wire.ErrInvalidReservedBits) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, wire.ErrInvalidReservedBits)
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
@ -961,7 +932,7 @@ var _ = Describe("Connection", func() {
It("ignores packets when unpacking the header fails", func() { It("ignores packets when unpacking the header fails", func() {
testErr := &headerParseError{errors.New("test error")} testErr := &headerParseError{errors.New("test error")}
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, testErr)
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
runErr := make(chan error) runErr := make(chan error)
@ -987,7 +958,7 @@ var _ = Describe("Connection", func() {
}) })
It("closes the connection when unpacking fails because of an error other than a decryption error", func() { It("closes the connection when unpacking fails because of an error other than a decryption error", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError})
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
@ -1041,7 +1012,7 @@ var _ = Describe("Connection", func() {
Expect(srcConnID).ToNot(Equal(hdr2.SrcConnectionID)) Expect(srcConnID).ToNot(Equal(hdr2.SrcConnectionID))
// Send one packet, which might change the connection ID. // Send one packet, which might change the connection ID.
// only EXPECT one call to the unpacker // only EXPECT one call to the unpacker
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.Encryption1RTT, encryptionLevel: protocol.Encryption1RTT,
hdr: hdr1, hdr: hdr1,
data: []byte{0}, // one PADDING frame data: []byte{0}, // one PADDING frame
@ -1070,7 +1041,7 @@ var _ = Describe("Connection", func() {
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
PacketNumber: 1, PacketNumber: 1,
} }
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable) unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable)
packet := getPacket(hdr, nil) packet := getPacket(hdr, nil)
tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake) tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake)
Expect(conn.handlePacketImpl(packet)).To(BeFalse()) Expect(conn.handlePacketImpl(packet)).To(BeFalse())
@ -1079,17 +1050,13 @@ var _ = Describe("Connection", func() {
Context("updating the remote address", func() { Context("updating the remote address", func() {
It("doesn't support connection migration", func() { It("doesn't support connection migration", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(&wire.ShortHeader{},
encryptionLevel: protocol.Encryption1RTT, []byte{0} /* one PADDING frame */, nil)
hdr: &wire.ExtendedHeader{},
data: []byte{0}, // one PADDING frame
}, nil)
packet := getPacket(&wire.ExtendedHeader{ packet := getPacket(&wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID}, Header: wire.Header{DestConnectionID: srcConnID},
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
}, nil) }, nil)
packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any())
Expect(conn.handlePacketImpl(packet)).To(BeTrue()) Expect(conn.handlePacketImpl(packet)).To(BeTrue())
}) })
@ -1120,7 +1087,7 @@ var _ = Describe("Connection", func() {
It("cuts packets to the right length", func() { It("cuts packets to the right length", func() {
hdrLen, packet := getPacketWithLength(srcConnID, 456) hdrLen, packet := getPacketWithLength(srcConnID, 456)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen + 456 - 3)) Expect(data).To(HaveLen(hdrLen + 456 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -1134,7 +1101,7 @@ var _ = Describe("Connection", func() {
It("handles coalesced packets", func() { It("handles coalesced packets", func() {
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -1146,7 +1113,7 @@ var _ = Describe("Connection", func() {
}, nil }, nil
}) })
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -1170,8 +1137,8 @@ var _ = Describe("Connection", func() {
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
gomock.InOrder( gomock.InOrder(
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable), unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable),
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -1195,7 +1162,7 @@ var _ = Describe("Connection", func() {
wrongConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) wrongConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
Expect(srcConnID).ToNot(Equal(wrongConnID)) Expect(srcConnID).ToNot(Equal(wrongConnID))
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -1204,7 +1171,7 @@ var _ = Describe("Connection", func() {
}, nil }, nil
}) })
_, packet2 := getPacketWithLength(wrongConnID, 123) _, packet2 := getPacketWithLength(wrongConnID, 123)
// don't EXPECT any more calls to unpacker.Unpack() // don't EXPECT any more calls to unpacker.UnpackLongHeader()
gomock.InOrder( gomock.InOrder(
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()), tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()),
tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID), tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID),
@ -2456,7 +2423,7 @@ var _ = Describe("Client Connection", func() {
It("changes the connection ID when receiving the first packet from the server", func() { It("changes the connection ID when receiving the first packet from the server", func() {
unpacker := NewMockUnpacker(mockCtrl) unpacker := NewMockUnpacker(mockCtrl)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.Encryption1RTT, encryptionLevel: protocol.Encryption1RTT,
hdr: &wire.ExtendedHeader{Header: *hdr}, hdr: &wire.ExtendedHeader{Header: *hdr},
@ -2505,7 +2472,7 @@ var _ = Describe("Client Connection", func() {
}) })
Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}))) Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})))
// now receive a packet with the original source connection ID // now receive a packet with the original source connection ID
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte) (*unpackedPacket, error) {
return &unpackedPacket{ return &unpackedPacket{
hdr: &wire.ExtendedHeader{Header: *hdr}, hdr: &wire.ExtendedHeader{Header: *hdr},
data: []byte{0}, data: []byte{0},
@ -2519,7 +2486,7 @@ var _ = Describe("Client Connection", func() {
SrcConnectionID: destConnID, SrcConnectionID: destConnID,
} }
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any())
Expect(conn.handleSinglePacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) Expect(conn.handleLongHeaderPacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue())
}) })
It("handles HANDSHAKE_DONE frames", func() { It("handles HANDSHAKE_DONE frames", func() {
@ -2946,7 +2913,7 @@ var _ = Describe("Client Connection", func() {
Expect(hdr2.SrcConnectionID).ToNot(Equal(srcConnID)) Expect(hdr2.SrcConnectionID).ToNot(Equal(srcConnID))
// Send one packet, which might change the connection ID. // Send one packet, which might change the connection ID.
// only EXPECT one call to the unpacker // only EXPECT one call to the unpacker
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial, encryptionLevel: protocol.EncryptionInitial,
hdr: hdr1, hdr: hdr1,
data: []byte{0}, // one PADDING frame data: []byte{0}, // one PADDING frame

View file

@ -35,17 +35,33 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder {
return m.recorder return m.recorder
} }
// Unpack mocks base method. // UnpackLongHeader mocks base method.
func (m *MockUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { func (m *MockUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Unpack", hdr, rcvTime, data) ret := m.ctrl.Call(m, "UnpackLongHeader", hdr, rcvTime, data)
ret0, _ := ret[0].(*unpackedPacket) ret0, _ := ret[0].(*unpackedPacket)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// Unpack indicates an expected call of Unpack. // UnpackLongHeader indicates an expected call of UnpackLongHeader.
func (mr *MockUnpackerMockRecorder) Unpack(hdr, rcvTime, data interface{}) *gomock.Call { func (mr *MockUnpackerMockRecorder) UnpackLongHeader(hdr, rcvTime, data interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockUnpacker)(nil).Unpack), hdr, rcvTime, data) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), hdr, rcvTime, data)
}
// UnpackShortHeader mocks base method.
func (m *MockUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnpackShortHeader", rcvTime, data)
ret0, _ := ret[0].(*wire.ShortHeader)
ret1, _ := ret[1].([]byte)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// UnpackShortHeader indicates an expected call of UnpackShortHeader.
func (mr *MockUnpackerMockRecorder) UnpackShortHeader(rcvTime, data interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackShortHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackShortHeader), rcvTime, data)
} }

View file

@ -37,22 +37,25 @@ type unpackedPacket struct {
type packetUnpacker struct { type packetUnpacker struct {
cs handshake.CryptoSetup cs handshake.CryptoSetup
shortHdrConnIDLen int
version protocol.VersionNumber version protocol.VersionNumber
} }
var _ unpacker = &packetUnpacker{} var _ unpacker = &packetUnpacker{}
func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) unpacker { func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int, version protocol.VersionNumber) unpacker {
return &packetUnpacker{ return &packetUnpacker{
cs: cs, cs: cs,
shortHdrConnIDLen: shortHdrConnIDLen,
version: version, version: version,
} }
} }
// UnpackLongHeader unpacks a Long Header packet.
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
// If any other error occurred when parsing the header, the error is of type headerParseError. // If any other error occurred when parsing the header, the error is of type headerParseError.
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD. // If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) {
var encLevel protocol.EncryptionLevel var encLevel protocol.EncryptionLevel
var extHdr *wire.ExtendedHeader var extHdr *wire.ExtendedHeader
var decrypted []byte var decrypted []byte
@ -89,19 +92,8 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte
return nil, err return nil, err
} }
default: default:
if hdr.IsLongHeader {
return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
} }
encLevel = protocol.Encryption1RTT
opener, err := u.cs.Get1RTTOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, rcvTime, data)
if err != nil {
return nil, err
}
}
if len(decrypted) == 0 { if len(decrypted) == 0 {
return nil, &qerr.TransportError{ return nil, &qerr.TransportError{
@ -117,8 +109,26 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte
}, nil }, nil
} }
func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) {
opener, err := u.cs.Get1RTTOpener()
if err != nil {
return nil, nil, err
}
hdr, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data)
if err != nil {
return nil, nil, err
}
if len(decrypted) == 0 {
return nil, nil, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "empty packet",
}
}
return hdr, decrypted, nil
}
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpackHeader(opener, hdr, data) extHdr, parseErr := u.unpackLongHeader(opener, hdr, data)
// If the reserved bits are set incorrectly, we still need to continue unpacking. // If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker // This avoids a timing side-channel, which otherwise might allow an attacker
// to gain information about the header encryption. // to gain information about the header encryption.
@ -137,41 +147,58 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene
return extHdr, decrypted, nil return extHdr, decrypted, nil
} }
func (u *packetUnpacker) unpackShortHeaderPacket( func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) {
opener handshake.ShortHeaderOpener, hdr, parseErr := u.unpackShortHeader(opener, data)
hdr *wire.Header,
rcvTime time.Time,
data []byte,
) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpackHeader(opener, hdr, data)
// If the reserved bits are set incorrectly, we still need to continue unpacking. // If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker // This avoids a timing side-channel, which otherwise might allow an attacker
// to gain information about the header encryption. // to gain information about the header encryption.
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, nil, parseErr return nil, nil, &headerParseError{parseErr}
} }
extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) hdr.PacketNumber = opener.DecodePacketNumber(hdr.PacketNumber, hdr.PacketNumberLen)
extHdrLen := extHdr.ParsedLen() l := hdr.Len()
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen]) decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, hdr.PacketNumber, hdr.KeyPhase, data[:l])
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if parseErr != nil { return hdr, decrypted, parseErr
return nil, nil, parseErr
} }
return extHdr, decrypted, nil
func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (*wire.ShortHeader, error) {
hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen
if len(data) < hdrLen+4+16 {
return nil, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
}
origPNBytes := make([]byte, 4)
copy(origPNBytes, data[hdrLen:hdrLen+4])
// 2. decrypt the header, assuming a 4 byte packet number
hd.DecryptHeader(
data[hdrLen+4:hdrLen+4+16],
&data[0],
data[hdrLen:hdrLen+4],
)
// 3. parse the header (and learn the actual length of the packet number)
hdr, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen)
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, parseErr
}
// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
if hdr.PacketNumberLen != protocol.PacketNumberLen4 {
copy(data[hdrLen+int(hdr.PacketNumberLen):hdrLen+4], origPNBytes[int(hdr.PacketNumberLen):])
}
return hdr, parseErr
} }
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
extHdr, err := unpackHeader(hd, hdr, data, u.version) extHdr, err := unpackLongHeader(hd, hdr, data, u.version)
if err != nil && err != wire.ErrInvalidReservedBits { if err != nil && err != wire.ErrInvalidReservedBits {
return nil, &headerParseError{err: err} return nil, &headerParseError{err: err}
} }
return extHdr, err return extHdr, err
} }
func unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) { func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) {
r := bytes.NewReader(data) r := bytes.NewReader(data)
hdrLen := hdr.ParsedLen() hdrLen := hdr.ParsedLen()

View file

@ -41,7 +41,7 @@ var _ = Describe("Packet Unpacker", func() {
BeforeEach(func() { BeforeEach(func() {
cs = mocks.NewMockCryptoSetup(mockCtrl) cs = mocks.NewMockCryptoSetup(mockCtrl)
unpacker = newPacketUnpacker(cs, version).(*packetUnpacker) unpacker = newPacketUnpacker(cs, 4, version).(*packetUnpacker)
}) })
It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() { It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() {
@ -59,7 +59,7 @@ var _ = Describe("Packet Unpacker", func() {
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockLongHeaderOpener(mockCtrl) opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
_, err := unpacker.Unpack(hdr, time.Now(), data) _, err := unpacker.UnpackLongHeader(hdr, time.Now(), data)
Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
var headerErr *headerParseError var headerErr *headerParseError
Expect(errors.As(err, &headerErr)).To(BeTrue()) Expect(errors.As(err, &headerErr)).To(BeTrue())
@ -67,18 +67,17 @@ var _ = Describe("Packet Unpacker", func() {
}) })
It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() { It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() {
extHdr := &wire.ExtendedHeader{ _, hdrRaw := getHeader(&wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID}, Header: wire.Header{DestConnectionID: connID},
PacketNumber: 1337, PacketNumber: 1337,
PacketNumberLen: protocol.PacketNumberLen2, PacketNumberLen: protocol.PacketNumberLen2,
} })
hdr, hdrRaw := getHeader(extHdr)
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
_, err := unpacker.Unpack(hdr, time.Now(), data) _, _, err := unpacker.UnpackShortHeader(time.Now(), data)
Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) Expect(err).To(MatchError("packet too small, expected at least 20 bytes after the header, got 19"))
}) })
It("opens Initial packets", func() { It("opens Initial packets", func() {
@ -101,7 +100,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)), opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil), opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil),
) )
packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
Expect(packet.data).To(Equal([]byte("decrypted"))) Expect(packet.data).To(Equal([]byte("decrypted")))
@ -127,7 +126,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)), opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil), opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil),
) )
packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT)) Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT))
Expect(packet.data).To(Equal([]byte("decrypted"))) Expect(packet.data).To(Equal([]byte("decrypted")))
@ -140,7 +139,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumber: 99, PacketNumber: 99,
PacketNumberLen: protocol.PacketNumberLen4, PacketNumberLen: protocol.PacketNumberLen4,
} }
hdr, hdrRaw := getHeader(extHdr) _, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
now := time.Now() now := time.Now()
gomock.InOrder( gomock.InOrder(
@ -149,10 +148,11 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)), opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil), opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil),
) )
packet, err := unpacker.Unpack(hdr, now, append(hdrRaw, payload...)) hdr, data, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption1RTT)) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(321)))
Expect(packet.data).To(Equal([]byte("decrypted"))) Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
Expect(data).To(Equal([]byte("decrypted")))
}) })
It("returns the error when getting the opener fails", func() { It("returns the error when getting the opener fails", func() {
@ -161,19 +161,45 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumber: 0x1337, PacketNumber: 0x1337,
PacketNumberLen: 2, PacketNumberLen: 2,
} }
hdr, hdrRaw := getHeader(extHdr) _, hdrRaw := getHeader(extHdr)
cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable) cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable)) Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable))
}) })
It("errors on empty packets", func() { It("errors on empty packets, for long header packets", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: connID,
Version: Version1,
},
KeyPhase: protocol.KeyPhaseOne,
PacketNumberLen: protocol.PacketNumberLen4,
}
hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
gomock.InOrder(
cs.EXPECT().GetHandshakeOpener().Return(opener, nil),
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte(""), nil),
)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(&qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "empty packet",
}))
})
It("errors on empty packets, for short header packets", func() {
extHdr := &wire.ExtendedHeader{ extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID}, Header: wire.Header{DestConnectionID: connID},
KeyPhase: protocol.KeyPhaseOne, KeyPhase: protocol.KeyPhaseOne,
PacketNumberLen: protocol.PacketNumberLen4, PacketNumberLen: protocol.PacketNumberLen4,
} }
hdr, hdrRaw := getHeader(extHdr) _, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
now := time.Now() now := time.Now()
gomock.InOrder( gomock.InOrder(
@ -182,7 +208,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)), opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte(""), nil), opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte(""), nil),
) )
_, err := unpacker.Unpack(hdr, now, append(hdrRaw, payload...)) _, _, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...))
Expect(err).To(MatchError(&qerr.TransportError{ Expect(err).To(MatchError(&qerr.TransportError{
ErrorCode: qerr.ProtocolViolation, ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "empty packet", ErrorMessage: "empty packet",
@ -208,7 +234,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded} unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded}
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) _, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(unpackErr)) Expect(err).To(MatchError(unpackErr))
}) })
@ -230,7 +256,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().GetHandshakeOpener().Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) _, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
}) })
@ -240,31 +266,53 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumber: 0x1337, PacketNumber: 0x1337,
PacketNumberLen: 2, PacketNumberLen: 2,
} }
hdr, hdrRaw := getHeader(extHdr) _, hdrRaw := getHeader(extHdr)
hdrRaw[0] |= 0x18 hdrRaw[0] |= 0x18
opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
}) })
It("returns the decryption error, when unpacking a packet with wrong reserved bits fails", func() { It("returns the decryption error, when unpacking a packet with wrong reserved bits fails, for long headers", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: connID,
Version: version,
},
PacketNumber: 0x1337,
PacketNumberLen: 2,
}
hdr, hdrRaw := getHeader(extHdr)
hdrRaw[0] |= 0x18
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
})
It("returns the decryption error, when unpacking a packet with wrong reserved bits fails, for short headers", func() {
extHdr := &wire.ExtendedHeader{ extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID}, Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337, PacketNumber: 0x1337,
PacketNumberLen: 2, PacketNumberLen: 2,
} }
hdr, hdrRaw := getHeader(extHdr) _, hdrRaw := getHeader(extHdr)
hdrRaw[0] |= 0x18 hdrRaw[0] |= 0x18
opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrDecryptionFailed)) Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
}) })
@ -307,7 +355,7 @@ var _ = Describe("Packet Unpacker", func() {
for i := 1; i <= 100; i++ { for i := 1; i <= 100; i++ {
data = append(data, uint8(i)) data = append(data, uint8(i))
} }
packet, err := unpacker.Unpack(hdr, time.Now(), data) packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x7331))) Expect(packet.hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x7331)))
}) })

View file

@ -601,7 +601,7 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header)
// This makes sure that we won't send it for packets that were corrupted. // This makes sure that we won't send it for packets that were corrupted.
sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
data := p.data[:hdr.ParsedLen()+hdr.Length] data := p.data[:hdr.ParsedLen()+hdr.Length]
extHdr, err := unpackHeader(opener, hdr, data, hdr.Version) extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version)
if err != nil { if err != nil {
if s.config.Tracer != nil { if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)

View file

@ -758,7 +758,7 @@ var _ = Describe("Server", func() {
Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID)) Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) _, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
extHdr, err := unpackHeader(opener, replyHdr, b, origHdr.Version) extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())