diff --git a/connection.go b/connection.go index 27a7f986..41cab3dd 100644 --- a/connection.go +++ b/connection.go @@ -25,7 +25,8 @@ import ( ) type unpacker interface { - Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) + UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) + UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) } type streamGetter interface { @@ -362,7 +363,7 @@ var newConnection = func( s.perspective, s.version, ) - s.unpacker = newPacketUnpacker(cs, s.version) + s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen, s.version) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) return s } @@ -474,7 +475,7 @@ var newClientConnection = func( s.clientHelloWritten = clientHelloWritten s.cryptoStreamHandler = cs s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) - s.unpacker = newPacketUnpacker(cs, s.version) + s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen, s.version) s.packer = newPacketPacker( srcConnID, s.connIDManager.Get, @@ -858,58 +859,113 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool { if counter > 0 { p = p.Clone() p.data = data - } - hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen) - if err != nil { - if s.tracer != nil { - dropReason := logging.PacketDropHeaderParseError - if err == wire.ErrUnsupportedVersion { - dropReason = logging.PacketDropUnsupportedVersion + destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen) + if err != nil { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) } - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason) + s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err) + break } - s.logger.Debugf("error parsing packet: %s", err) + if destConnID != lastConnID { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) + } + s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID) + break + } + } + + if wire.IsLongHeaderPacket(p.data[0]) { + hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen) + if err != nil { + if s.tracer != nil { + dropReason := logging.PacketDropHeaderParseError + if err == wire.ErrUnsupportedVersion { + dropReason = logging.PacketDropUnsupportedVersion + } + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason) + } + s.logger.Debugf("error parsing packet: %s", err) + break + } + lastConnID = hdr.DestConnectionID + + if hdr.Version != s.version { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) + } + s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) + break + } + + if counter > 0 { + p.buffer.Split() + } + counter++ + + // only log if this actually a coalesced packet + if s.logger.Debug() && (counter > 1 || len(rest) > 0) { + s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest)) + } + + p.data = packetData + + if wasProcessed := s.handleLongHeaderPacket(p, hdr); wasProcessed { + processed = true + } + data = rest + } else { + if counter > 0 { + p.buffer.Split() + } + processed = s.handleShortHeaderPacket(p) break } - - if hdr.IsLongHeader && hdr.Version != s.version { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) - } - s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) - break - } - - if counter > 0 && hdr.DestConnectionID != lastConnID { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) - } - s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) - break - } - lastConnID = hdr.DestConnectionID - - if counter > 0 { - p.buffer.Split() - } - counter++ - - // only log if this actually a coalesced packet - if s.logger.Debug() && (counter > 1 || len(rest) > 0) { - s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest)) - } - p.data = packetData - if wasProcessed := s.handleSinglePacket(p, hdr); wasProcessed { - processed = true - } - data = rest } + p.buffer.MaybeRelease() return processed } -func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { +func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool { + var wasQueued bool + + defer func() { + // Put back the packet buffer if the packet wasn't queued for later decryption. + if !wasQueued { + p.buffer.Decrement() + } + }() + + hdr, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data) + if err != nil { + wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT) + return false + } + + if s.logger.Debug() { + s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", hdr.PacketNumber, p.Size(), hdr.DestConnectionID) + hdr.Log(s.logger) + } + + if s.receivedPacketHandler.IsPotentiallyDuplicate(hdr.PacketNumber, protocol.Encryption1RTT) { + s.logger.Debugf("Dropping (potentially) duplicate packet.") + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate) + } + return false + } + + if err := s.handleUnpackedShortHeaderPacket(hdr, data, p.ecn, p.rcvTime, p.Size()); err != nil { + s.closeLocal(err) + return false + } + return true +} + +func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { var wasQueued bool defer func() { @@ -925,7 +981,7 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. - if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { + if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) } @@ -940,44 +996,9 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo return false } - packet, err := s.unpacker.Unpack(hdr, p.rcvTime, p.data) + packet, err := s.unpacker.UnpackLongHeader(hdr, p.rcvTime, p.data) if err != nil { - switch err { - case handshake.ErrKeysDropped: - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropKeyUnavailable) - } - s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", hdr.PacketType(), p.Size()) - case handshake.ErrKeysNotYetAvailable: - // Sealer for this encryption level not yet available. - // Try again later. - wasQueued = true - s.tryQueueingUndecryptablePacket(p, logging.PacketTypeFromHeader(hdr)) - case wire.ErrInvalidReservedBits: - s.closeLocal(&qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: err.Error(), - }) - case handshake.ErrDecryptionFailed: - // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropPayloadDecryptError) - } - s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", hdr.PacketType(), p.Size(), err) - default: - var headerErr *headerParseError - if errors.As(err, &headerErr) { - // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropHeaderParseError) - } - s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", hdr.PacketType(), p.Size(), err) - } else { - // This is an error returned by the AEAD (other than ErrDecryptionFailed). - // For example, a PROTOCOL_VIOLATION due to key updates. - s.closeLocal(err) - } - } + wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr)) return false } @@ -1001,6 +1022,46 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo return true } +func (s *connection) handleUnpackError(err error, p *receivedPacket, pt logging.PacketType) (wasQueued bool) { + switch err { + case handshake.ErrKeysDropped: + if s.tracer != nil { + s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable) + } + s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size()) + case handshake.ErrKeysNotYetAvailable: + // Sealer for this encryption level not yet available. + // Try again later. + s.tryQueueingUndecryptablePacket(p, pt) + return true + case wire.ErrInvalidReservedBits: + s.closeLocal(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: err.Error(), + }) + case handshake.ErrDecryptionFailed: + // This might be a packet injected by an attacker. Drop it. + if s.tracer != nil { + s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError) + } + s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err) + default: + var headerErr *headerParseError + if errors.As(err, &headerErr) { + // This might be a packet injected by an attacker. Drop it. + if s.tracer != nil { + s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError) + } + s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err) + } else { + // This is an error returned by the AEAD (other than ErrDecryptionFailed). + // For example, a PROTOCOL_VIOLATION due to key updates. + s.closeLocal(err) + } + } + return false +} + func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ { if s.perspective == protocol.PerspectiveServer { if s.tracer != nil { @@ -1167,15 +1228,51 @@ func (s *connection) handleUnpackedPacket( s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} s.keepAlivePingSent = false + var log func([]logging.Frame) + if s.tracer != nil { + log = func(frames []logging.Frame) { + s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, frames) + } + } + isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log) + if err != nil { + return err + } + return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting) +} + +func (s *connection) handleUnpackedShortHeaderPacket(hdr *wire.ShortHeader, data []byte, ecn protocol.ECN, rcvTime time.Time, packetSize protocol.ByteCount) error { + s.lastPacketReceivedTime = rcvTime + s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} + s.keepAlivePingSent = false + + var log func([]logging.Frame) + if s.tracer != nil { + log = func(frames []logging.Frame) { + s.tracer.ReceivedShortHeaderPacket(hdr, packetSize, frames) + } + } + isAckEliciting, err := s.handleFrames(data, hdr.DestConnectionID, protocol.Encryption1RTT, log) + if err != nil { + return err + } + return s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting) +} + +func (s *connection) handleFrames( + data []byte, + destConnID protocol.ConnectionID, + encLevel protocol.EncryptionLevel, + log func([]logging.Frame), +) (isAckEliciting bool, _ error) { // Only used for tracing. // If we're not tracing, this slice will always remain empty. var frames []wire.Frame - r := bytes.NewReader(packet.data) - var isAckEliciting bool + r := bytes.NewReader(data) for { - frame, err := s.frameParser.ParseNext(r, packet.encryptionLevel) + frame, err := s.frameParser.ParseNext(r, encLevel) if err != nil { - return err + return false, err } if frame == nil { break @@ -1185,38 +1282,28 @@ func (s *connection) handleUnpackedPacket( } // Only process frames now if we're not logging. // If we're logging, we need to make sure that the packet_received event is logged first. - if s.tracer == nil { - if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { - return err + if log == nil { + if err := s.handleFrame(frame, encLevel, destConnID); err != nil { + return false, err } } else { frames = append(frames, frame) } } - if s.tracer != nil { + if log != nil { fs := make([]logging.Frame, len(frames)) for i, frame := range frames { fs[i] = logutils.ConvertFrame(frame) } - if packet.hdr.IsLongHeader { - s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, fs) - } else { - s.tracer.ReceivedShortHeaderPacket(&wire.ShortHeader{ - DestConnectionID: packet.hdr.DestConnectionID, - PacketNumber: packet.hdr.PacketNumber, - PacketNumberLen: packet.hdr.PacketNumberLen, - KeyPhase: packet.hdr.KeyPhase, - }, packetSize, fs) - } + log(fs) for _, frame := range frames { - if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { - return err + if err := s.handleFrame(frame, encLevel, destConnID); err != nil { + return false, err } } } - - return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting) + return } func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error { diff --git a/connection_test.go b/connection_test.go index 3846347e..af559b98 100644 --- a/connection_test.go +++ b/connection_test.go @@ -561,17 +561,13 @@ var _ = Describe("Connection", func() { PacketNumberLen: protocol.PacketNumberLen2, } Expect(hdr.Write(buf, conn.version)).To(Succeed()) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) { + + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (*wire.ShortHeader, []byte, error) { b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version) Expect(err).ToNot(HaveOccurred()) - return &unpackedPacket{ - hdr: hdr, - data: b, - encryptionLevel: protocol.Encryption1RTT, - }, nil + return &wire.ShortHeader{PacketNumber: 3}, b, nil }) gomock.InOrder( - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()), tracer.EXPECT().ClosedConnection(gomock.Any()), tracer.EXPECT().Close(), @@ -727,7 +723,13 @@ var _ = Describe("Connection", func() { It("informs the ReceivedPacketHandler about non-ack-eliciting packets", func() { hdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: srcConnID, + Version: protocol.Version1, + Length: 1, + }, PacketNumber: 0x37, PacketNumberLen: protocol.PacketNumberLen1, } @@ -736,7 +738,7 @@ var _ = Describe("Connection", func() { packet := getPacket(hdr, nil) packet.ecn = protocol.ECNCE rcvTime := time.Now().Add(-10 * time.Second) - unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ encryptionLevel: protocol.EncryptionInitial, hdr: &unpackedHdr, data: []byte{0}, // one PADDING frame @@ -749,7 +751,7 @@ var _ = Describe("Connection", func() { conn.receivedPacketHandler = rph packet.rcvTime = rcvTime tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ReceivedShortHeaderPacket(&wire.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 1, DestConnectionID: srcConnID}, protocol.ByteCount(len(packet.data)), []logging.Frame{}) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), []logging.Frame{}) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) @@ -759,18 +761,12 @@ var _ = Describe("Connection", func() { PacketNumber: 0x37, PacketNumberLen: protocol.PacketNumberLen1, } - unpackedHdr := *hdr - unpackedHdr.PacketNumber = 0x1337 rcvTime := time.Now().Add(-10 * time.Second) b, err := (&wire.PingFrame{}).Append(nil, conn.version) Expect(err).ToNot(HaveOccurred()) packet := getPacket(hdr, nil) packet.ecn = protocol.ECT1 - unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ - encryptionLevel: protocol.Encryption1RTT, - hdr: &unpackedHdr, - data: b, - }, nil) + unpacker.EXPECT().UnpackShortHeader(rcvTime, gomock.Any()).Return(&wire.ShortHeader{PacketNumber: 0x1337}, b, nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) gomock.InOrder( rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT), @@ -778,8 +774,7 @@ var _ = Describe("Connection", func() { ) conn.receivedPacketHandler = rph packet.rcvTime = rcvTime - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ReceivedShortHeaderPacket(&wire.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 1, DestConnectionID: srcConnID}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}}) + tracer.EXPECT().ReceivedShortHeaderPacket(&wire.ShortHeader{PacketNumber: 0x1337}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}}) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) @@ -790,15 +785,7 @@ var _ = Describe("Connection", func() { PacketNumberLen: protocol.PacketNumberLen1, } packet := getPacket(hdr, nil) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) { - h := *hdr - h.PacketNumber = 0x1337 - return &unpackedPacket{ - encryptionLevel: protocol.Encryption1RTT, - hdr: &h, - data: []byte("foobar"), - }, nil - }) + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(&wire.ShortHeader{PacketNumber: 0x1337}, []byte("foobar"), nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true) conn.receivedPacketHandler = rph @@ -807,7 +794,7 @@ var _ = Describe("Connection", func() { }) It("drops a packet when unpacking fails", func() { - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) @@ -842,18 +829,10 @@ var _ = Describe("Connection", func() { It("processes multiple received packets before sending one", func() { conn.creationTime = time.Now() var pn protocol.PacketNumber - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { pn++ - return &unpackedPacket{ - data: []byte{0}, // PADDING frame - encryptionLevel: protocol.Encryption1RTT, - hdr: &wire.ExtendedHeader{ - Header: *hdr, - PacketNumber: pn, - }, - }, nil + return &wire.ShortHeader{PacketNumber: pn}, []byte{0} /* PADDING frame */, nil }).Times(3) - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { }).Times(3) packer.EXPECT().PackCoalescedPacket() // only expect a single call @@ -889,18 +868,10 @@ var _ = Describe("Connection", func() { conn.handshakeComplete = false conn.creationTime = time.Now() var pn protocol.PacketNumber - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { pn++ - return &unpackedPacket{ - data: []byte{0}, // PADDING frame - encryptionLevel: protocol.Encryption1RTT, - hdr: &wire.ExtendedHeader{ - Header: *hdr, - PacketNumber: pn, - }, - }, nil + return &wire.ShortHeader{PacketNumber: pn}, []byte{0} /* PADDING frame */, nil }).Times(3) - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { }).Times(3) packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call @@ -933,7 +904,7 @@ var _ = Describe("Connection", func() { }) It("closes the connection when unpacking fails because the reserved bits were incorrect", func() { - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, wire.ErrInvalidReservedBits) + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, wire.ErrInvalidReservedBits) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) @@ -961,7 +932,7 @@ var _ = Describe("Connection", func() { It("ignores packets when unpacking the header fails", func() { testErr := &headerParseError{errors.New("test error")} - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, testErr) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() runErr := make(chan error) @@ -987,7 +958,7 @@ var _ = Describe("Connection", func() { }) It("closes the connection when unpacking fails because of an error other than a decryption error", func() { - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}) + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) @@ -1041,7 +1012,7 @@ var _ = Describe("Connection", func() { Expect(srcConnID).ToNot(Equal(hdr2.SrcConnectionID)) // Send one packet, which might change the connection ID. // only EXPECT one call to the unpacker - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ encryptionLevel: protocol.Encryption1RTT, hdr: hdr1, data: []byte{0}, // one PADDING frame @@ -1070,7 +1041,7 @@ var _ = Describe("Connection", func() { PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 1, } - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable) + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable) packet := getPacket(hdr, nil) tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake) Expect(conn.handlePacketImpl(packet)).To(BeFalse()) @@ -1079,17 +1050,13 @@ var _ = Describe("Connection", func() { Context("updating the remote address", func() { It("doesn't support connection migration", func() { - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - encryptionLevel: protocol.Encryption1RTT, - hdr: &wire.ExtendedHeader{}, - data: []byte{0}, // one PADDING frame - }, nil) + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(&wire.ShortHeader{}, + []byte{0} /* one PADDING frame */, nil) packet := getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, }, nil) packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) @@ -1120,7 +1087,7 @@ var _ = Describe("Connection", func() { It("cuts packets to the right length", func() { hdrLen, packet := getPacketWithLength(srcConnID, 456) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen + 456 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1134,7 +1101,7 @@ var _ = Describe("Connection", func() { It("handles coalesced packets", func() { hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1146,7 +1113,7 @@ var _ = Describe("Connection", func() { }, nil }) hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1170,8 +1137,8 @@ var _ = Describe("Connection", func() { hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) gomock.InOrder( - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable), - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable), + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1195,7 +1162,7 @@ var _ = Describe("Connection", func() { wrongConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) Expect(srcConnID).ToNot(Equal(wrongConnID)) hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1204,7 +1171,7 @@ var _ = Describe("Connection", func() { }, nil }) _, packet2 := getPacketWithLength(wrongConnID, 123) - // don't EXPECT any more calls to unpacker.Unpack() + // don't EXPECT any more calls to unpacker.UnpackLongHeader() gomock.InOrder( tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()), tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID), @@ -2456,7 +2423,7 @@ var _ = Describe("Client Connection", func() { It("changes the connection ID when receiving the first packet from the server", func() { unpacker := NewMockUnpacker(mockCtrl) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { return &unpackedPacket{ encryptionLevel: protocol.Encryption1RTT, hdr: &wire.ExtendedHeader{Header: *hdr}, @@ -2505,7 +2472,7 @@ var _ = Describe("Client Connection", func() { }) Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}))) // now receive a packet with the original source connection ID - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte) (*unpackedPacket, error) { return &unpackedPacket{ hdr: &wire.ExtendedHeader{Header: *hdr}, data: []byte{0}, @@ -2519,7 +2486,7 @@ var _ = Describe("Client Connection", func() { SrcConnectionID: destConnID, } tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(conn.handleSinglePacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) + Expect(conn.handleLongHeaderPacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) }) It("handles HANDSHAKE_DONE frames", func() { @@ -2946,7 +2913,7 @@ var _ = Describe("Client Connection", func() { Expect(hdr2.SrcConnectionID).ToNot(Equal(srcConnID)) // Send one packet, which might change the connection ID. // only EXPECT one call to the unpacker - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ encryptionLevel: protocol.EncryptionInitial, hdr: hdr1, data: []byte{0}, // one PADDING frame diff --git a/mock_unpacker_test.go b/mock_unpacker_test.go index 22da001b..1410f8ec 100644 --- a/mock_unpacker_test.go +++ b/mock_unpacker_test.go @@ -35,17 +35,33 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder { return m.recorder } -// Unpack mocks base method. -func (m *MockUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { +// UnpackLongHeader mocks base method. +func (m *MockUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Unpack", hdr, rcvTime, data) + ret := m.ctrl.Call(m, "UnpackLongHeader", hdr, rcvTime, data) ret0, _ := ret[0].(*unpackedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } -// Unpack indicates an expected call of Unpack. -func (mr *MockUnpackerMockRecorder) Unpack(hdr, rcvTime, data interface{}) *gomock.Call { +// UnpackLongHeader indicates an expected call of UnpackLongHeader. +func (mr *MockUnpackerMockRecorder) UnpackLongHeader(hdr, rcvTime, data interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockUnpacker)(nil).Unpack), hdr, rcvTime, data) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), hdr, rcvTime, data) +} + +// UnpackShortHeader mocks base method. +func (m *MockUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnpackShortHeader", rcvTime, data) + ret0, _ := ret[0].(*wire.ShortHeader) + ret1, _ := ret[1].([]byte) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// UnpackShortHeader indicates an expected call of UnpackShortHeader. +func (mr *MockUnpackerMockRecorder) UnpackShortHeader(rcvTime, data interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackShortHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackShortHeader), rcvTime, data) } diff --git a/packet_unpacker.go b/packet_unpacker.go index a74b3047..688724f8 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -37,22 +37,25 @@ type unpackedPacket struct { type packetUnpacker struct { cs handshake.CryptoSetup - version protocol.VersionNumber + shortHdrConnIDLen int + version protocol.VersionNumber } var _ unpacker = &packetUnpacker{} -func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) unpacker { +func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int, version protocol.VersionNumber) unpacker { return &packetUnpacker{ - cs: cs, - version: version, + cs: cs, + shortHdrConnIDLen: shortHdrConnIDLen, + version: version, } } +// UnpackLongHeader unpacks a Long Header packet. // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. // If any other error occurred when parsing the header, the error is of type headerParseError. // If decrypting the payload fails for any reason, the error is the error returned by the AEAD. -func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { +func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { var encLevel protocol.EncryptionLevel var extHdr *wire.ExtendedHeader var decrypted []byte @@ -89,18 +92,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte return nil, err } default: - if hdr.IsLongHeader { - return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) - } - encLevel = protocol.Encryption1RTT - opener, err := u.cs.Get1RTTOpener() - if err != nil { - return nil, err - } - extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, rcvTime, data) - if err != nil { - return nil, err - } + return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) } if len(decrypted) == 0 { @@ -117,8 +109,26 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte }, nil } +func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { + opener, err := u.cs.Get1RTTOpener() + if err != nil { + return nil, nil, err + } + hdr, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data) + if err != nil { + return nil, nil, err + } + if len(decrypted) == 0 { + return nil, nil, &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "empty packet", + } + } + return hdr, decrypted, nil +} + func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { - extHdr, parseErr := u.unpackHeader(opener, hdr, data) + extHdr, parseErr := u.unpackLongHeader(opener, hdr, data) // If the reserved bits are set incorrectly, we still need to continue unpacking. // This avoids a timing side-channel, which otherwise might allow an attacker // to gain information about the header encryption. @@ -137,41 +147,58 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene return extHdr, decrypted, nil } -func (u *packetUnpacker) unpackShortHeaderPacket( - opener handshake.ShortHeaderOpener, - hdr *wire.Header, - rcvTime time.Time, - data []byte, -) (*wire.ExtendedHeader, []byte, error) { - extHdr, parseErr := u.unpackHeader(opener, hdr, data) +func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { + hdr, parseErr := u.unpackShortHeader(opener, data) // If the reserved bits are set incorrectly, we still need to continue unpacking. // This avoids a timing side-channel, which otherwise might allow an attacker // to gain information about the header encryption. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { - return nil, nil, parseErr + return nil, nil, &headerParseError{parseErr} } - extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) - extHdrLen := extHdr.ParsedLen() - decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen]) + hdr.PacketNumber = opener.DecodePacketNumber(hdr.PacketNumber, hdr.PacketNumberLen) + l := hdr.Len() + decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, hdr.PacketNumber, hdr.KeyPhase, data[:l]) if err != nil { return nil, nil, err } - if parseErr != nil { - return nil, nil, parseErr + return hdr, decrypted, parseErr +} + +func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (*wire.ShortHeader, error) { + hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen + if len(data) < hdrLen+4+16 { + return nil, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen) } - return extHdr, decrypted, nil + origPNBytes := make([]byte, 4) + copy(origPNBytes, data[hdrLen:hdrLen+4]) + // 2. decrypt the header, assuming a 4 byte packet number + hd.DecryptHeader( + data[hdrLen+4:hdrLen+4+16], + &data[0], + data[hdrLen:hdrLen+4], + ) + // 3. parse the header (and learn the actual length of the packet number) + hdr, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen) + if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { + return nil, parseErr + } + // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier + if hdr.PacketNumberLen != protocol.PacketNumberLen4 { + copy(data[hdrLen+int(hdr.PacketNumberLen):hdrLen+4], origPNBytes[int(hdr.PacketNumberLen):]) + } + return hdr, parseErr } // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. -func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { - extHdr, err := unpackHeader(hd, hdr, data, u.version) +func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { + extHdr, err := unpackLongHeader(hd, hdr, data, u.version) if err != nil && err != wire.ErrInvalidReservedBits { return nil, &headerParseError{err: err} } return extHdr, err } -func unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) { +func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) { r := bytes.NewReader(data) hdrLen := hdr.ParsedLen() diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 239311be..3857f70a 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -41,7 +41,7 @@ var _ = Describe("Packet Unpacker", func() { BeforeEach(func() { cs = mocks.NewMockCryptoSetup(mockCtrl) - unpacker = newPacketUnpacker(cs, version).(*packetUnpacker) + unpacker = newPacketUnpacker(cs, 4, version).(*packetUnpacker) }) It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() { @@ -59,7 +59,7 @@ var _ = Describe("Packet Unpacker", func() { data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) opener := mocks.NewMockLongHeaderOpener(mockCtrl) cs.EXPECT().GetHandshakeOpener().Return(opener, nil) - _, err := unpacker.Unpack(hdr, time.Now(), data) + _, err := unpacker.UnpackLongHeader(hdr, time.Now(), data) Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) var headerErr *headerParseError Expect(errors.As(err, &headerErr)).To(BeTrue()) @@ -67,18 +67,17 @@ var _ = Describe("Packet Unpacker", func() { }) It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() { - extHdr := &wire.ExtendedHeader{ + _, hdrRaw := getHeader(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: connID}, PacketNumber: 1337, PacketNumberLen: protocol.PacketNumberLen2, - } - hdr, hdrRaw := getHeader(extHdr) + }) data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) opener := mocks.NewMockShortHeaderOpener(mockCtrl) cs.EXPECT().Get1RTTOpener().Return(opener, nil) - _, err := unpacker.Unpack(hdr, time.Now(), data) + _, _, err := unpacker.UnpackShortHeader(time.Now(), data) Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) - Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) + Expect(err).To(MatchError("packet too small, expected at least 20 bytes after the header, got 19")) }) It("opens Initial packets", func() { @@ -101,7 +100,7 @@ var _ = Describe("Packet Unpacker", func() { opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)), opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil), ) - packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...)) Expect(err).ToNot(HaveOccurred()) Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) Expect(packet.data).To(Equal([]byte("decrypted"))) @@ -127,7 +126,7 @@ var _ = Describe("Packet Unpacker", func() { opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)), opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil), ) - packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...)) Expect(err).ToNot(HaveOccurred()) Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT)) Expect(packet.data).To(Equal([]byte("decrypted"))) @@ -140,7 +139,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumber: 99, PacketNumberLen: protocol.PacketNumberLen4, } - hdr, hdrRaw := getHeader(extHdr) + _, hdrRaw := getHeader(extHdr) opener := mocks.NewMockShortHeaderOpener(mockCtrl) now := time.Now() gomock.InOrder( @@ -149,10 +148,11 @@ var _ = Describe("Packet Unpacker", func() { opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)), opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil), ) - packet, err := unpacker.Unpack(hdr, now, append(hdrRaw, payload...)) + hdr, data, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...)) Expect(err).ToNot(HaveOccurred()) - Expect(packet.encryptionLevel).To(Equal(protocol.Encryption1RTT)) - Expect(packet.data).To(Equal([]byte("decrypted"))) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(321))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(data).To(Equal([]byte("decrypted"))) }) It("returns the error when getting the opener fails", func() { @@ -161,19 +161,45 @@ var _ = Describe("Packet Unpacker", func() { PacketNumber: 0x1337, PacketNumberLen: 2, } - hdr, hdrRaw := getHeader(extHdr) + _, hdrRaw := getHeader(extHdr) cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable) - _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable)) }) - It("errors on empty packets", func() { + It("errors on empty packets, for long header packets", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: connID, + Version: Version1, + }, + KeyPhase: protocol.KeyPhaseOne, + PacketNumberLen: protocol.PacketNumberLen4, + } + hdr, hdrRaw := getHeader(extHdr) + opener := mocks.NewMockLongHeaderOpener(mockCtrl) + gomock.InOrder( + cs.EXPECT().GetHandshakeOpener().Return(opener, nil), + opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()), + opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)), + opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte(""), nil), + ) + _, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...)) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "empty packet", + })) + }) + + It("errors on empty packets, for short header packets", func() { extHdr := &wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: connID}, KeyPhase: protocol.KeyPhaseOne, PacketNumberLen: protocol.PacketNumberLen4, } - hdr, hdrRaw := getHeader(extHdr) + _, hdrRaw := getHeader(extHdr) opener := mocks.NewMockShortHeaderOpener(mockCtrl) now := time.Now() gomock.InOrder( @@ -182,7 +208,7 @@ var _ = Describe("Packet Unpacker", func() { opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)), opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte(""), nil), ) - _, err := unpacker.Unpack(hdr, now, append(hdrRaw, payload...)) + _, _, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...)) Expect(err).To(MatchError(&qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "empty packet", @@ -208,7 +234,7 @@ var _ = Describe("Packet Unpacker", func() { opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded} opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr) - _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + _, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(unpackErr)) }) @@ -230,7 +256,7 @@ var _ = Describe("Packet Unpacker", func() { cs.EXPECT().GetHandshakeOpener().Return(opener, nil) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) - _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + _, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) }) @@ -240,31 +266,53 @@ var _ = Describe("Packet Unpacker", func() { PacketNumber: 0x1337, PacketNumberLen: 2, } - hdr, hdrRaw := getHeader(extHdr) + _, hdrRaw := getHeader(extHdr) hdrRaw[0] |= 0x18 opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) cs.EXPECT().Get1RTTOpener().Return(opener, nil) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) - _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) }) - It("returns the decryption error, when unpacking a packet with wrong reserved bits fails", func() { + It("returns the decryption error, when unpacking a packet with wrong reserved bits fails, for long headers", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: connID, + Version: version, + }, + PacketNumber: 0x1337, + PacketNumberLen: 2, + } + hdr, hdrRaw := getHeader(extHdr) + hdrRaw[0] |= 0x18 + opener := mocks.NewMockLongHeaderOpener(mockCtrl) + opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) + cs.EXPECT().GetHandshakeOpener().Return(opener, nil) + opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) + _, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...)) + Expect(err).To(MatchError(handshake.ErrDecryptionFailed)) + }) + + It("returns the decryption error, when unpacking a packet with wrong reserved bits fails, for short headers", func() { extHdr := &wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: connID}, PacketNumber: 0x1337, PacketNumberLen: 2, } - hdr, hdrRaw := getHeader(extHdr) + _, hdrRaw := getHeader(extHdr) hdrRaw[0] |= 0x18 opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) cs.EXPECT().Get1RTTOpener().Return(opener, nil) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) - _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(handshake.ErrDecryptionFailed)) }) @@ -307,7 +355,7 @@ var _ = Describe("Packet Unpacker", func() { for i := 1; i <= 100; i++ { data = append(data, uint8(i)) } - packet, err := unpacker.Unpack(hdr, time.Now(), data) + packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), data) Expect(err).ToNot(HaveOccurred()) Expect(packet.hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x7331))) }) diff --git a/server.go b/server.go index ae29ff9b..16d4d818 100644 --- a/server.go +++ b/server.go @@ -601,7 +601,7 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) // This makes sure that we won't send it for packets that were corrupted. sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) data := p.data[:hdr.ParsedLen()+hdr.Length] - extHdr, err := unpackHeader(opener, hdr, data, hdr.Version) + extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version) if err != nil { if s.config.Tracer != nil { s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) diff --git a/server_test.go b/server_test.go index 1cc68a97..66db72f7 100644 --- a/server_test.go +++ b/server_test.go @@ -758,7 +758,7 @@ var _ = Describe("Server", func() { Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID)) _, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) - extHdr, err := unpackHeader(opener, replyHdr, b, origHdr.Version) + extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version) Expect(err).ToNot(HaveOccurred()) data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) Expect(err).ToNot(HaveOccurred())