mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
remove the wire.ShortHeader in favor of more return values (#3535)
This commit is contained in:
parent
5cd5341545
commit
656f3d2d7d
12 changed files with 125 additions and 131 deletions
|
@ -26,7 +26,7 @@ import (
|
|||
|
||||
type unpacker interface {
|
||||
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error)
|
||||
UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error)
|
||||
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
|
||||
}
|
||||
|
||||
type streamGetter interface {
|
||||
|
@ -856,11 +856,13 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
|
|||
data := rp.data
|
||||
p := rp
|
||||
for len(data) > 0 {
|
||||
var destConnID protocol.ConnectionID
|
||||
if counter > 0 {
|
||||
p = p.Clone()
|
||||
p.data = data
|
||||
|
||||
destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen)
|
||||
var err error
|
||||
destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen)
|
||||
if err != nil {
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError)
|
||||
|
@ -920,7 +922,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
|
|||
if counter > 0 {
|
||||
p.buffer.Split()
|
||||
}
|
||||
processed = s.handleShortHeaderPacket(p)
|
||||
processed = s.handleShortHeaderPacket(p, destConnID)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -929,7 +931,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
|
|||
return processed
|
||||
}
|
||||
|
||||
func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool {
|
||||
func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID protocol.ConnectionID) bool {
|
||||
var wasQueued bool
|
||||
|
||||
defer func() {
|
||||
|
@ -939,18 +941,18 @@ func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool {
|
|||
}
|
||||
}()
|
||||
|
||||
hdr, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
|
||||
pn, pnLen, keyPhase, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
|
||||
if err != nil {
|
||||
wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.logger.Debug() {
|
||||
s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", hdr.PacketNumber, p.Size(), hdr.DestConnectionID)
|
||||
hdr.Log(s.logger)
|
||||
s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", pn, p.Size(), destConnID)
|
||||
wire.LogShortHeader(s.logger, destConnID, pn, pnLen, keyPhase)
|
||||
}
|
||||
|
||||
if s.receivedPacketHandler.IsPotentiallyDuplicate(hdr.PacketNumber, protocol.Encryption1RTT) {
|
||||
if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) {
|
||||
s.logger.Debugf("Dropping (potentially) duplicate packet.")
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate)
|
||||
|
@ -958,7 +960,22 @@ func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
if err := s.handleUnpackedShortHeaderPacket(hdr, data, p.ecn, p.rcvTime, p.Size()); err != nil {
|
||||
var log func([]logging.Frame)
|
||||
if s.tracer != nil {
|
||||
log = func(frames []logging.Frame) {
|
||||
s.tracer.ReceivedShortHeaderPacket(
|
||||
&logging.ShortHeader{
|
||||
DestConnectionID: destConnID,
|
||||
PacketNumber: pn,
|
||||
PacketNumberLen: pnLen,
|
||||
KeyPhase: keyPhase,
|
||||
},
|
||||
p.Size(),
|
||||
frames,
|
||||
)
|
||||
}
|
||||
}
|
||||
if err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log); err != nil {
|
||||
s.closeLocal(err)
|
||||
return false
|
||||
}
|
||||
|
@ -1241,22 +1258,23 @@ func (s *connection) handleUnpackedPacket(
|
|||
return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting)
|
||||
}
|
||||
|
||||
func (s *connection) handleUnpackedShortHeaderPacket(hdr *wire.ShortHeader, data []byte, ecn protocol.ECN, rcvTime time.Time, packetSize protocol.ByteCount) error {
|
||||
func (s *connection) handleUnpackedShortHeaderPacket(
|
||||
destConnID protocol.ConnectionID,
|
||||
pn protocol.PacketNumber,
|
||||
data []byte,
|
||||
ecn protocol.ECN,
|
||||
rcvTime time.Time,
|
||||
log func([]logging.Frame),
|
||||
) error {
|
||||
s.lastPacketReceivedTime = rcvTime
|
||||
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
|
||||
s.keepAlivePingSent = false
|
||||
|
||||
var log func([]logging.Frame)
|
||||
if s.tracer != nil {
|
||||
log = func(frames []logging.Frame) {
|
||||
s.tracer.ReceivedShortHeaderPacket(hdr, packetSize, frames)
|
||||
}
|
||||
}
|
||||
isAckEliciting, err := s.handleFrames(data, hdr.DestConnectionID, protocol.Encryption1RTT, log)
|
||||
isAckEliciting, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting)
|
||||
return s.receivedPacketHandler.ReceivedPacket(pn, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting)
|
||||
}
|
||||
|
||||
func (s *connection) handleFrames(
|
||||
|
|
|
@ -562,10 +562,10 @@ var _ = Describe("Connection", func() {
|
|||
}
|
||||
Expect(hdr.Write(buf, conn.version)).To(Succeed())
|
||||
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (*wire.ShortHeader, []byte, error) {
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
|
||||
b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return &wire.ShortHeader{PacketNumber: 3}, b, nil
|
||||
return 3, protocol.PacketNumberLen2, protocol.KeyPhaseOne, b, nil
|
||||
})
|
||||
gomock.InOrder(
|
||||
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()),
|
||||
|
@ -766,7 +766,7 @@ var _ = Describe("Connection", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
packet := getPacket(hdr, nil)
|
||||
packet.ecn = protocol.ECT1
|
||||
unpacker.EXPECT().UnpackShortHeader(rcvTime, gomock.Any()).Return(&wire.ShortHeader{PacketNumber: 0x1337}, b, nil)
|
||||
unpacker.EXPECT().UnpackShortHeader(rcvTime, gomock.Any()).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, b, nil)
|
||||
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
||||
gomock.InOrder(
|
||||
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT),
|
||||
|
@ -774,7 +774,7 @@ var _ = Describe("Connection", func() {
|
|||
)
|
||||
conn.receivedPacketHandler = rph
|
||||
packet.rcvTime = rcvTime
|
||||
tracer.EXPECT().ReceivedShortHeaderPacket(&wire.ShortHeader{PacketNumber: 0x1337}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}})
|
||||
tracer.EXPECT().ReceivedShortHeaderPacket(&logging.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 2, KeyPhase: protocol.KeyPhaseZero}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}})
|
||||
Expect(conn.handlePacketImpl(packet)).To(BeTrue())
|
||||
})
|
||||
|
||||
|
@ -785,7 +785,7 @@ var _ = Describe("Connection", func() {
|
|||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}
|
||||
packet := getPacket(hdr, nil)
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(&wire.ShortHeader{PacketNumber: 0x1337}, []byte("foobar"), nil)
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseOne, []byte("foobar"), nil)
|
||||
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
||||
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true)
|
||||
conn.receivedPacketHandler = rph
|
||||
|
@ -829,11 +829,11 @@ var _ = Describe("Connection", func() {
|
|||
It("processes multiple received packets before sending one", func() {
|
||||
conn.creationTime = time.Now()
|
||||
var pn protocol.PacketNumber
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) {
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
|
||||
pn++
|
||||
return &wire.ShortHeader{PacketNumber: pn}, []byte{0} /* PADDING frame */, nil
|
||||
return pn, protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil
|
||||
}).Times(3)
|
||||
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
|
||||
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
|
||||
}).Times(3)
|
||||
packer.EXPECT().PackCoalescedPacket() // only expect a single call
|
||||
|
||||
|
@ -868,11 +868,11 @@ var _ = Describe("Connection", func() {
|
|||
conn.handshakeComplete = false
|
||||
conn.creationTime = time.Now()
|
||||
var pn protocol.PacketNumber
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) {
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
|
||||
pn++
|
||||
return &wire.ShortHeader{PacketNumber: pn}, []byte{0} /* PADDING frame */, nil
|
||||
return pn, protocol.PacketNumberLen4, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil
|
||||
}).Times(3)
|
||||
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
|
||||
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
|
||||
}).Times(3)
|
||||
packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call
|
||||
|
||||
|
@ -904,7 +904,7 @@ var _ = Describe("Connection", func() {
|
|||
})
|
||||
|
||||
It("closes the connection when unpacking fails because the reserved bits were incorrect", func() {
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, wire.ErrInvalidReservedBits)
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, wire.ErrInvalidReservedBits)
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
cryptoSetup.EXPECT().Close()
|
||||
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
|
||||
|
@ -932,7 +932,7 @@ var _ = Describe("Connection", func() {
|
|||
|
||||
It("ignores packets when unpacking the header fails", func() {
|
||||
testErr := &headerParseError{errors.New("test error")}
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, testErr)
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, testErr)
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
cryptoSetup.EXPECT().Close()
|
||||
runErr := make(chan error)
|
||||
|
@ -958,7 +958,7 @@ var _ = Describe("Connection", func() {
|
|||
})
|
||||
|
||||
It("closes the connection when unpacking fails because of an error other than a decryption error", func() {
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError})
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError})
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
cryptoSetup.EXPECT().Close()
|
||||
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
|
||||
|
@ -1050,8 +1050,7 @@ var _ = Describe("Connection", func() {
|
|||
|
||||
Context("updating the remote address", func() {
|
||||
It("doesn't support connection migration", func() {
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(&wire.ShortHeader{},
|
||||
[]byte{0} /* one PADDING frame */, nil)
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* one PADDING frame */, nil)
|
||||
packet := getPacket(&wire.ExtendedHeader{
|
||||
Header: wire.Header{DestConnectionID: srcConnID},
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
|
|
|
@ -208,7 +208,7 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gom
|
|||
}
|
||||
|
||||
// ReceivedShortHeaderPacket mocks base method.
|
||||
func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *wire.ShortHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) {
|
||||
func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2)
|
||||
}
|
||||
|
|
|
@ -9,28 +9,20 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type ShortHeader struct {
|
||||
DestConnectionID protocol.ConnectionID
|
||||
PacketNumber protocol.PacketNumber
|
||||
PacketNumberLen protocol.PacketNumberLen
|
||||
KeyPhase protocol.KeyPhaseBit
|
||||
}
|
||||
|
||||
func ParseShortHeader(data []byte, connIDLen int) (*ShortHeader, error) {
|
||||
func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.PacketNumber, _ protocol.PacketNumberLen, _ protocol.KeyPhaseBit, _ error) {
|
||||
if len(data) == 0 {
|
||||
return nil, io.EOF
|
||||
return 0, 0, 0, 0, io.EOF
|
||||
}
|
||||
if data[0]&0x80 > 0 {
|
||||
return nil, errors.New("not a short header packet")
|
||||
return 0, 0, 0, 0, errors.New("not a short header packet")
|
||||
}
|
||||
if data[0]&0x40 == 0 {
|
||||
return nil, errors.New("not a QUIC packet")
|
||||
return 0, 0, 0, 0, errors.New("not a QUIC packet")
|
||||
}
|
||||
pnLen := protocol.PacketNumberLen(data[0]&0b11) + 1
|
||||
if len(data) < 1+int(pnLen)+connIDLen {
|
||||
return nil, io.EOF
|
||||
return 0, 0, 0, 0, io.EOF
|
||||
}
|
||||
destConnID := protocol.ParseConnectionID(data[1 : 1+connIDLen])
|
||||
|
||||
pos := 1 + connIDLen
|
||||
var pn protocol.PacketNumber
|
||||
|
@ -44,7 +36,7 @@ func ParseShortHeader(data []byte, connIDLen int) (*ShortHeader, error) {
|
|||
case protocol.PacketNumberLen4:
|
||||
pn = protocol.PacketNumber(utils.BigEndian.Uint32(data[pos : pos+4]))
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid packet number length: %d", pnLen)
|
||||
return 0, 0, 0, 0, fmt.Errorf("invalid packet number length: %d", pnLen)
|
||||
}
|
||||
kp := protocol.KeyPhaseZero
|
||||
if data[0]&0b100 > 0 {
|
||||
|
@ -55,19 +47,9 @@ func ParseShortHeader(data []byte, connIDLen int) (*ShortHeader, error) {
|
|||
if data[0]&0x18 != 0 {
|
||||
err = ErrInvalidReservedBits
|
||||
}
|
||||
return &ShortHeader{
|
||||
DestConnectionID: destConnID,
|
||||
PacketNumber: pn,
|
||||
PacketNumberLen: pnLen,
|
||||
KeyPhase: kp,
|
||||
}, err
|
||||
return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err
|
||||
}
|
||||
|
||||
func (h *ShortHeader) Len() protocol.ByteCount {
|
||||
return 1 + protocol.ByteCount(h.DestConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen)
|
||||
}
|
||||
|
||||
// Log logs the Header
|
||||
func (h *ShortHeader) Log(logger utils.Logger) {
|
||||
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
|
||||
func LogShortHeader(logger utils.Logger, dest protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) {
|
||||
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", dest, pn, pnLen, kp)
|
||||
}
|
||||
|
|
|
@ -21,12 +21,12 @@ var _ = Describe("Short Header", func() {
|
|||
0xde, 0xad, 0xbe, 0xef,
|
||||
0x13, 0x37, 0x99,
|
||||
}
|
||||
hdr, err := ParseShortHeader(data, 4)
|
||||
l, pn, pnLen, kp, err := ParseShortHeader(data, 4)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})))
|
||||
Expect(hdr.KeyPhase).To(Equal(protocol.KeyPhaseOne))
|
||||
Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x133799)))
|
||||
Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen3))
|
||||
Expect(l).To(Equal(len(data)))
|
||||
Expect(kp).To(Equal(protocol.KeyPhaseOne))
|
||||
Expect(pn).To(Equal(protocol.PacketNumber(0x133799)))
|
||||
Expect(pnLen).To(Equal(protocol.PacketNumberLen3))
|
||||
})
|
||||
|
||||
It("errors when the QUIC bit is not set", func() {
|
||||
|
@ -35,7 +35,7 @@ var _ = Describe("Short Header", func() {
|
|||
0xde, 0xad, 0xbe, 0xef,
|
||||
0x13, 0x37,
|
||||
}
|
||||
_, err := ParseShortHeader(data, 4)
|
||||
_, _, _, _, err := ParseShortHeader(data, 4)
|
||||
Expect(err).To(MatchError("not a QUIC packet"))
|
||||
})
|
||||
|
||||
|
@ -45,14 +45,13 @@ var _ = Describe("Short Header", func() {
|
|||
0xde, 0xad, 0xbe, 0xef,
|
||||
0x13, 0x37,
|
||||
}
|
||||
hdr, err := ParseShortHeader(data, 4)
|
||||
_, pn, _, _, err := ParseShortHeader(data, 4)
|
||||
Expect(err).To(MatchError(ErrInvalidReservedBits))
|
||||
Expect(hdr).ToNot(BeNil())
|
||||
Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337)))
|
||||
Expect(pn).To(Equal(protocol.PacketNumber(0x1337)))
|
||||
})
|
||||
|
||||
It("errors when passed a long header packet", func() {
|
||||
_, err := ParseShortHeader([]byte{0x80}, 4)
|
||||
_, _, _, _, err := ParseShortHeader([]byte{0x80}, 4)
|
||||
Expect(err).To(MatchError("not a short header packet"))
|
||||
})
|
||||
|
||||
|
@ -62,10 +61,10 @@ var _ = Describe("Short Header", func() {
|
|||
0xde, 0xad, 0xbe, 0xef,
|
||||
0x13, 0x37, 0x99,
|
||||
}
|
||||
_, err := ParseShortHeader(data, 4)
|
||||
_, _, _, _, err := ParseShortHeader(data, 4)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := range data {
|
||||
_, err := ParseShortHeader(data[:i], 4)
|
||||
_, _, _, _, err := ParseShortHeader(data[:i], 4)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
}
|
||||
})
|
||||
|
@ -89,22 +88,9 @@ var _ = Describe("Short Header", func() {
|
|||
})
|
||||
|
||||
It("logs Short Headers containing a connection ID", func() {
|
||||
(&ShortHeader{
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}),
|
||||
KeyPhase: protocol.KeyPhaseOne,
|
||||
PacketNumber: 1337,
|
||||
PacketNumberLen: 4,
|
||||
}).Log(logger)
|
||||
connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})
|
||||
LogShortHeader(logger, connID, 1337, protocol.PacketNumberLen4, protocol.KeyPhaseOne)
|
||||
Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}"))
|
||||
})
|
||||
})
|
||||
|
||||
It("determines the length", func() {
|
||||
Expect((&ShortHeader{
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xaf}),
|
||||
PacketNumber: 0x1337,
|
||||
PacketNumberLen: protocol.PacketNumberLen3,
|
||||
KeyPhase: protocol.KeyPhaseOne,
|
||||
}).Len()).To(Equal(protocol.ByteCount(1 + 2 + 3)))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -46,8 +46,6 @@ type (
|
|||
Header = wire.Header
|
||||
// The ExtendedHeader is the QUIC Long Header packet header, after removing header protection.
|
||||
ExtendedHeader = wire.ExtendedHeader
|
||||
// The ShortHeader is the QUIC Short Header packet header, after removing header protection.
|
||||
ShortHeader = wire.ShortHeader
|
||||
// The TransportParameters are QUIC transport parameters.
|
||||
TransportParameters = wire.TransportParameters
|
||||
// The PreferredAddress is the preferred address sent in the transport parameters.
|
||||
|
@ -94,6 +92,14 @@ const (
|
|||
StreamTypeBidi = protocol.StreamTypeBidi
|
||||
)
|
||||
|
||||
// The ShortHeader is the QUIC Short Header packet header, after removing header protection.
|
||||
type ShortHeader struct {
|
||||
DestConnectionID ConnectionID
|
||||
PacketNumber PacketNumber
|
||||
PacketNumberLen protocol.PacketNumberLen
|
||||
KeyPhase KeyPhaseBit
|
||||
}
|
||||
|
||||
// A Tracer traces events.
|
||||
type Tracer interface {
|
||||
// TracerForConnection requests a new tracer for a connection.
|
||||
|
|
|
@ -207,7 +207,7 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gom
|
|||
}
|
||||
|
||||
// ReceivedShortHeaderPacket mocks base method.
|
||||
func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *wire.ShortHeader, arg1 protocol.ByteCount, arg2 []Frame) {
|
||||
func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 []Frame) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2)
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
wire "github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
|
@ -51,13 +52,15 @@ func (mr *MockUnpackerMockRecorder) UnpackLongHeader(hdr, rcvTime, data interfac
|
|||
}
|
||||
|
||||
// UnpackShortHeader mocks base method.
|
||||
func (m *MockUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) {
|
||||
func (m *MockUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnpackShortHeader", rcvTime, data)
|
||||
ret0, _ := ret[0].(*wire.ShortHeader)
|
||||
ret1, _ := ret[1].([]byte)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
ret0, _ := ret[0].(protocol.PacketNumber)
|
||||
ret1, _ := ret[1].(protocol.PacketNumberLen)
|
||||
ret2, _ := ret[2].(protocol.KeyPhaseBit)
|
||||
ret3, _ := ret[3].([]byte)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
}
|
||||
|
||||
// UnpackShortHeader indicates an expected call of UnpackShortHeader.
|
||||
|
|
|
@ -109,22 +109,22 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) {
|
||||
func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
|
||||
opener, err := u.cs.Get1RTTOpener()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, 0, 0, nil, err
|
||||
}
|
||||
hdr, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data)
|
||||
pn, pnLen, kp, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, 0, 0, nil, err
|
||||
}
|
||||
if len(decrypted) == 0 {
|
||||
return nil, nil, &qerr.TransportError{
|
||||
return 0, 0, 0, nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "empty packet",
|
||||
}
|
||||
}
|
||||
return hdr, decrypted, nil
|
||||
return pn, pnLen, kp, decrypted, nil
|
||||
}
|
||||
|
||||
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
|
||||
|
@ -147,27 +147,26 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene
|
|||
return extHdr, decrypted, nil
|
||||
}
|
||||
|
||||
func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) {
|
||||
hdr, parseErr := u.unpackShortHeader(opener, data)
|
||||
func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
|
||||
l, pn, pnLen, kp, parseErr := u.unpackShortHeader(opener, data)
|
||||
// If the reserved bits are set incorrectly, we still need to continue unpacking.
|
||||
// This avoids a timing side-channel, which otherwise might allow an attacker
|
||||
// to gain information about the header encryption.
|
||||
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
|
||||
return nil, nil, &headerParseError{parseErr}
|
||||
return 0, 0, 0, nil, &headerParseError{parseErr}
|
||||
}
|
||||
hdr.PacketNumber = opener.DecodePacketNumber(hdr.PacketNumber, hdr.PacketNumberLen)
|
||||
l := hdr.Len()
|
||||
decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, hdr.PacketNumber, hdr.KeyPhase, data[:l])
|
||||
pn = opener.DecodePacketNumber(pn, pnLen)
|
||||
decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, pn, kp, data[:l])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, 0, 0, nil, err
|
||||
}
|
||||
return hdr, decrypted, parseErr
|
||||
return pn, pnLen, kp, decrypted, parseErr
|
||||
}
|
||||
|
||||
func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (*wire.ShortHeader, error) {
|
||||
func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int, protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, error) {
|
||||
hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen
|
||||
if len(data) < hdrLen+4+16 {
|
||||
return nil, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
|
||||
return 0, 0, 0, 0, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
|
||||
}
|
||||
origPNBytes := make([]byte, 4)
|
||||
copy(origPNBytes, data[hdrLen:hdrLen+4])
|
||||
|
@ -178,15 +177,15 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (*wi
|
|||
data[hdrLen:hdrLen+4],
|
||||
)
|
||||
// 3. parse the header (and learn the actual length of the packet number)
|
||||
hdr, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen)
|
||||
l, pn, pnLen, kp, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen)
|
||||
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
|
||||
return nil, parseErr
|
||||
return l, pn, pnLen, kp, parseErr
|
||||
}
|
||||
// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
|
||||
if hdr.PacketNumberLen != protocol.PacketNumberLen4 {
|
||||
copy(data[hdrLen+int(hdr.PacketNumberLen):hdrLen+4], origPNBytes[int(hdr.PacketNumberLen):])
|
||||
if pnLen != protocol.PacketNumberLen4 {
|
||||
copy(data[hdrLen+int(pnLen):hdrLen+4], origPNBytes[int(pnLen):])
|
||||
}
|
||||
return hdr, parseErr
|
||||
return l, pn, pnLen, kp, parseErr
|
||||
}
|
||||
|
||||
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
|
||||
|
|
|
@ -75,7 +75,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
|
||||
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
|
||||
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
|
||||
_, _, err := unpacker.UnpackShortHeader(time.Now(), data)
|
||||
_, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), data)
|
||||
Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
|
||||
Expect(err).To(MatchError("packet too small, expected at least 20 bytes after the header, got 19"))
|
||||
})
|
||||
|
@ -148,10 +148,11 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)),
|
||||
opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil),
|
||||
)
|
||||
hdr, data, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...))
|
||||
pn, pnLen, kp, data, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(321)))
|
||||
Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
|
||||
Expect(pn).To(Equal(protocol.PacketNumber(321)))
|
||||
Expect(pnLen).To(Equal(protocol.PacketNumberLen4))
|
||||
Expect(kp).To(Equal(protocol.KeyPhaseOne))
|
||||
Expect(data).To(Equal([]byte("decrypted")))
|
||||
})
|
||||
|
||||
|
@ -163,7 +164,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
}
|
||||
_, hdrRaw := getHeader(extHdr)
|
||||
cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable)
|
||||
_, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
|
||||
_, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
|
||||
Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable))
|
||||
})
|
||||
|
||||
|
@ -208,7 +209,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)),
|
||||
opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte(""), nil),
|
||||
)
|
||||
_, _, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...))
|
||||
_, _, _, _, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...))
|
||||
Expect(err).To(MatchError(&qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "empty packet",
|
||||
|
@ -273,7 +274,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
|
||||
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
|
||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
|
||||
_, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
|
||||
_, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
|
||||
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
|
||||
})
|
||||
|
||||
|
@ -312,7 +313,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
|
||||
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
|
||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
|
||||
_, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
|
||||
_, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
|
||||
Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
|
||||
})
|
||||
|
||||
|
|
|
@ -319,11 +319,12 @@ func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, p
|
|||
fs[i] = frame{Frame: f}
|
||||
}
|
||||
header := *transformShortHeader(hdr)
|
||||
hdrLen := 1 + hdr.DestConnectionID.Len() + int(hdr.PacketNumberLen)
|
||||
t.mutex.Lock()
|
||||
t.recordEvent(time.Now(), &eventPacketReceived{
|
||||
Header: header,
|
||||
Length: packetSize,
|
||||
PayloadLength: packetSize - hdr.Len(),
|
||||
PayloadLength: packetSize - protocol.ByteCount(hdrLen),
|
||||
Frames: fs,
|
||||
})
|
||||
t.mutex.Unlock()
|
||||
|
|
|
@ -11,8 +11,6 @@ import (
|
|||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
|
@ -529,9 +527,10 @@ var _ = Describe("Tracing", func() {
|
|||
})
|
||||
|
||||
It("records a received Short Header packet", func() {
|
||||
shdr := &wire.ShortHeader{
|
||||
shdr := &logging.ShortHeader{
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
|
||||
PacketNumber: 1337,
|
||||
PacketNumberLen: protocol.PacketNumberLen3,
|
||||
KeyPhase: protocol.KeyPhaseZero,
|
||||
}
|
||||
tracer.ReceivedShortHeaderPacket(
|
||||
|
@ -549,7 +548,7 @@ var _ = Describe("Tracing", func() {
|
|||
Expect(ev).To(HaveKey("raw"))
|
||||
raw := ev["raw"].(map[string]interface{})
|
||||
Expect(raw).To(HaveKeyWithValue("length", float64(789)))
|
||||
Expect(raw).To(HaveKeyWithValue("payload_length", float64(789-shdr.Len())))
|
||||
Expect(raw).To(HaveKeyWithValue("payload_length", float64(789-(1+8+3))))
|
||||
Expect(ev).To(HaveKey("header"))
|
||||
hdr := ev["header"].(map[string]interface{})
|
||||
Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT"))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue