correctly pack packets with sealers with different overheads

This commit is contained in:
Marten Seemann 2017-09-04 10:17:44 +08:00
parent 7cbef3cb1b
commit 8bcabe0ba9
9 changed files with 89 additions and 109 deletions

View file

@ -333,16 +333,16 @@ func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil { if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.sealForwardSecure return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
} else if h.secureAEAD != nil { } else if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.sealSecure return protocol.EncryptionSecure, h.secureAEAD
} else { } else {
return protocol.EncryptionUnencrypted, h.sealUnencrypted return protocol.EncryptionUnencrypted, h.nullAEAD
} }
} }
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) { func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.sealUnencrypted return protocol.EncryptionUnencrypted, h.nullAEAD
} }
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
@ -351,33 +351,21 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry
switch encLevel { switch encLevel {
case protocol.EncryptionUnencrypted: case protocol.EncryptionUnencrypted:
return h.sealUnencrypted, nil return h.nullAEAD, nil
case protocol.EncryptionSecure: case protocol.EncryptionSecure:
if h.secureAEAD == nil { if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no secureAEAD") return nil, errors.New("CryptoSetupClient: no secureAEAD")
} }
return h.sealSecure, nil return h.secureAEAD, nil
case protocol.EncryptionForwardSecure: case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil { if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD") return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
} }
return h.sealForwardSecure, nil return h.forwardSecureAEAD, nil
} }
return nil, errors.New("CryptoSetupClient: no encryption level specified") return nil, errors.New("CryptoSetupClient: no encryption level specified")
} }
func (h *cryptoSetupClient) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.nullAEAD.Seal(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupClient) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupClient) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupClient) DiversificationNonce() []byte { func (h *cryptoSetupClient) DiversificationNonce() []byte {
panic("not needed for cryptoSetupClient") panic("not needed for cryptoSetupClient")
} }

View file

@ -685,16 +685,16 @@ var _ = Describe("Client Crypto Setup", func() {
Context("null encryption", func() { Context("null encryption", func() {
It("is used initially", func() { It("is used initially", func() {
enc, seal := cs.GetSealer() enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted)) Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal(foobarFNVSigned)) Expect(d).To(Equal(foobarFNVSigned))
}) })
It("is used for the crypto stream", func() { It("is used for the crypto stream", func() {
enc, seal := cs.GetSealerForCryptoStream() enc, sealer := cs.GetSealerForCryptoStream()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted)) Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal(foobarFNVSigned)) Expect(d).To(Equal(foobarFNVSigned))
}) })
@ -735,9 +735,9 @@ var _ = Describe("Client Crypto Setup", func() {
It("is used immediately when available", func() { It("is used immediately when available", func() {
doCompleteREJ() doCompleteREJ()
cs.receivedSecurePacket = false cs.receivedSecurePacket = false
enc, seal := cs.GetSealer() enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionSecure)) Expect(enc).To(Equal(protocol.EncryptionSecure))
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal([]byte("foobar normal sec"))) Expect(d).To(Equal([]byte("foobar normal sec")))
}) })
@ -759,9 +759,9 @@ var _ = Describe("Client Crypto Setup", func() {
It("is not used for the crypto stream", func() { It("is not used for the crypto stream", func() {
doCompleteREJ() doCompleteREJ()
enc, seal := cs.GetSealerForCryptoStream() enc, sealer := cs.GetSealerForCryptoStream()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted)) Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal(foobarFNVSigned)) Expect(d).To(Equal(foobarFNVSigned))
}) })
}) })
@ -772,61 +772,61 @@ var _ = Describe("Client Crypto Setup", func() {
_, enc, err := cs.Open(nil, []byte("forward secure encrypted"), 0, []byte{}) _, enc, err := cs.Open(nil, []byte("forward secure encrypted"), 0, []byte{})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(enc).To(Equal(protocol.EncryptionForwardSecure)) Expect(enc).To(Equal(protocol.EncryptionForwardSecure))
enc, seal := cs.GetSealer() enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionForwardSecure)) Expect(enc).To(Equal(protocol.EncryptionForwardSecure))
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal([]byte("foobar forward sec"))) Expect(d).To(Equal([]byte("foobar forward sec")))
}) })
It("is not used for the crypto stream", func() { It("is not used for the crypto stream", func() {
doSHLO() doSHLO()
enc, seal := cs.GetSealerForCryptoStream() enc, sealer := cs.GetSealerForCryptoStream()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted)) Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal(foobarFNVSigned)) Expect(d).To(Equal(foobarFNVSigned))
}) })
}) })
Context("forcing encryption levels", func() { Context("forcing encryption levels", func() {
It("forces null encryption", func() { It("forces null encryption", func() {
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted) sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal(foobarFNVSigned)) Expect(d).To(Equal(foobarFNVSigned))
}) })
It("forces initial encryption", func() { It("forces initial encryption", func() {
doCompleteREJ() doCompleteREJ()
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure) sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal([]byte("foobar normal sec"))) Expect(d).To(Equal([]byte("foobar normal sec")))
}) })
It("errors of no AEAD for initial encryption is available", func() { It("errors of no AEAD for initial encryption is available", func() {
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure) sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
Expect(err).To(MatchError("CryptoSetupClient: no secureAEAD")) Expect(err).To(MatchError("CryptoSetupClient: no secureAEAD"))
Expect(seal).To(BeNil()) Expect(sealer).To(BeNil())
}) })
It("forces forward-secure encryption", func() { It("forces forward-secure encryption", func() {
doSHLO() doSHLO()
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure) sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal([]byte("foobar forward sec"))) Expect(d).To(Equal([]byte("foobar forward sec")))
}) })
It("errors of no AEAD for forward-secure encryption is available", func() { It("errors of no AEAD for forward-secure encryption is available", func() {
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure) sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
Expect(err).To(MatchError("CryptoSetupClient: no forwardSecureAEAD")) Expect(err).To(MatchError("CryptoSetupClient: no forwardSecureAEAD"))
Expect(seal).To(BeNil()) Expect(sealer).To(BeNil())
}) })
It("errors if no encryption level is specified", func() { It("errors if no encryption level is specified", func() {
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnspecified) sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnspecified)
Expect(err).To(MatchError("CryptoSetupClient: no encryption level specified")) Expect(err).To(MatchError("CryptoSetupClient: no encryption level specified"))
Expect(seal).To(BeNil()) Expect(sealer).To(BeNil())
}) })
}) })
}) })

View file

@ -226,18 +226,18 @@ func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil { if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.sealForwardSecure return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
} }
return protocol.EncryptionUnencrypted, h.sealUnencrypted return protocol.EncryptionUnencrypted, h.nullAEAD
} }
func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) { func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.secureAEAD != nil { if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.sealSecure return protocol.EncryptionSecure, h.secureAEAD
} }
return protocol.EncryptionUnencrypted, h.sealUnencrypted return protocol.EncryptionUnencrypted, h.nullAEAD
} }
func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
@ -246,33 +246,21 @@ func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.Encry
switch encLevel { switch encLevel {
case protocol.EncryptionUnencrypted: case protocol.EncryptionUnencrypted:
return h.sealUnencrypted, nil return h.nullAEAD, nil
case protocol.EncryptionSecure: case protocol.EncryptionSecure:
if h.secureAEAD == nil { if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupServer: no secureAEAD") return nil, errors.New("CryptoSetupServer: no secureAEAD")
} }
return h.sealSecure, nil return h.secureAEAD, nil
case protocol.EncryptionForwardSecure: case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil { if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD") return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD")
} }
return h.sealForwardSecure, nil return h.forwardSecureAEAD, nil
} }
return nil, errors.New("CryptoSetupServer: no encryption level specified") return nil, errors.New("CryptoSetupServer: no encryption level specified")
} }
func (h *cryptoSetupServer) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.nullAEAD.Seal(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupServer) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupServer) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool { func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool {
if _, ok := cryptoData[TagPUBS]; !ok { if _, ok := cryptoData[TagPUBS]; !ok {
return true return true

View file

@ -569,16 +569,16 @@ var _ = Describe("Server Crypto Setup", func() {
Context("null encryption", func() { Context("null encryption", func() {
It("is used initially", func() { It("is used initially", func() {
enc, seal := cs.GetSealer() enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted)) Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal(foobarServerFNVSigned)) Expect(d).To(Equal(foobarServerFNVSigned))
}) })
It("is used for crypto stream", func() { It("is used for crypto stream", func() {
enc, seal := cs.GetSealerForCryptoStream() enc, sealer := cs.GetSealerForCryptoStream()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted)) Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal(foobarServerFNVSigned)) Expect(d).To(Equal(foobarServerFNVSigned))
}) })
@ -618,9 +618,9 @@ var _ = Describe("Server Crypto Setup", func() {
It("is not used after CHLO", func() { It("is not used after CHLO", func() {
doCHLO() doCHLO()
enc, seal := cs.GetSealer() enc, sealer := cs.GetSealer()
Expect(enc).ToNot(Equal(protocol.EncryptionUnencrypted)) Expect(enc).ToNot(Equal(protocol.EncryptionUnencrypted))
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).ToNot(Equal(foobarServerFNVSigned)) Expect(d).ToNot(Equal(foobarServerFNVSigned))
}) })
}) })
@ -645,9 +645,9 @@ var _ = Describe("Server Crypto Setup", func() {
It("is used for crypto stream", func() { It("is used for crypto stream", func() {
doCHLO() doCHLO()
enc, seal := cs.GetSealerForCryptoStream() enc, sealer := cs.GetSealerForCryptoStream()
Expect(enc).To(Equal(protocol.EncryptionSecure)) Expect(enc).To(Equal(protocol.EncryptionSecure))
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal([]byte("foobar normal sec"))) Expect(d).To(Equal([]byte("foobar normal sec")))
}) })
}) })
@ -655,9 +655,9 @@ var _ = Describe("Server Crypto Setup", func() {
Context("forward secure encryption", func() { Context("forward secure encryption", func() {
It("is used after the CHLO", func() { It("is used after the CHLO", func() {
doCHLO() doCHLO()
enc, seal := cs.GetSealer() enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionForwardSecure)) Expect(enc).To(Equal(protocol.EncryptionForwardSecure))
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal([]byte("foobar forward sec"))) Expect(d).To(Equal([]byte("foobar forward sec")))
}) })
@ -671,31 +671,31 @@ var _ = Describe("Server Crypto Setup", func() {
Context("forcing encryption levels", func() { Context("forcing encryption levels", func() {
It("forces null encryption", func() { It("forces null encryption", func() {
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted) sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal(foobarServerFNVSigned)) Expect(d).To(Equal(foobarServerFNVSigned))
}) })
It("forces initial encryption", func() { It("forces initial encryption", func() {
doCHLO() doCHLO()
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure) sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal([]byte("foobar normal sec"))) Expect(d).To(Equal([]byte("foobar normal sec")))
}) })
It("errors of no AEAD for initial encryption is available", func() { It("errors of no AEAD for initial encryption is available", func() {
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure) sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
Expect(err).To(MatchError("CryptoSetupServer: no secureAEAD")) Expect(err).To(MatchError("CryptoSetupServer: no secureAEAD"))
Expect(seal).To(BeNil()) Expect(sealer).To(BeNil())
}) })
It("forces forward-secure encryption", func() { It("forces forward-secure encryption", func() {
doCHLO() doCHLO()
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure) sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
d := seal(nil, []byte("foobar"), 0, []byte{}) d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal([]byte("foobar forward sec"))) Expect(d).To(Equal([]byte("foobar forward sec")))
}) })

View file

@ -3,7 +3,10 @@ package handshake
import "github.com/lucas-clemente/quic-go/internal/protocol" import "github.com/lucas-clemente/quic-go/internal/protocol"
// Sealer seals a packet // Sealer seals a packet
type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte type Sealer interface {
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
Overhead() int
}
// CryptoSetup is a crypto setup // CryptoSetup is a crypto setup
type CryptoSetup interface { type CryptoSetup interface {

View file

@ -6,9 +6,6 @@ import "time"
// This is the value used by Chromium for a QUIC packet sent using IPv6 (for IPv4 it would be 1370) // This is the value used by Chromium for a QUIC packet sent using IPv6 (for IPv4 it would be 1370)
const MaxPacketSize ByteCount = 1350 const MaxPacketSize ByteCount = 1350
// MaxFrameAndPublicHeaderSize is the maximum size of a QUIC frame plus PublicHeader
const MaxFrameAndPublicHeaderSize = MaxPacketSize - 12 /*crypto signature*/
// NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet // NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet
// This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames // This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames
const NonForwardSecurePacketSizeReduction = 50 const NonForwardSecurePacketSizeReduction = 50

View file

@ -134,7 +134,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
p.stopWaiting.PacketNumberLen = publicHeader.PacketNumberLen p.stopWaiting.PacketNumberLen = publicHeader.PacketNumberLen
} }
maxSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength maxSize := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - publicHeaderLength
payloadFrames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel)) payloadFrames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel))
if err != nil { if err != nil {
return nil, err return nil, err
@ -170,7 +170,7 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
maxLen := protocol.MaxFrameAndPublicHeaderSize - protocol.NonForwardSecurePacketSizeReduction - publicHeaderLength maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - publicHeaderLength
frames := []wire.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)} frames := []wire.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)}
raw, err := p.writeAndSealPacket(publicHeader, frames, sealer) raw, err := p.writeAndSealPacket(publicHeader, frames, sealer)
if err != nil { if err != nil {
@ -303,13 +303,13 @@ func (p *packetPacker) writeAndSealPacket(
return nil, err return nil, err
} }
} }
if protocol.ByteCount(buffer.Len()+12) > protocol.MaxPacketSize { if protocol.ByteCount(buffer.Len()+sealer.Overhead()) > protocol.MaxPacketSize {
return nil, errors.New("PacketPacker BUG: packet too large") return nil, errors.New("PacketPacker BUG: packet too large")
} }
raw = raw[0:buffer.Len()] raw = raw[0:buffer.Len()]
_ = sealer(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], publicHeader.PacketNumber, raw[:payloadStartIndex]) _ = sealer.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], publicHeader.PacketNumber, raw[:payloadStartIndex])
raw = raw[0 : buffer.Len()+12] raw = raw[0 : buffer.Len()+sealer.Overhead()]
num := p.packetNumberGenerator.Pop() num := p.packetNumberGenerator.Pop()
if num != publicHeader.PacketNumber { if num != publicHeader.PacketNumber {

View file

@ -13,6 +13,16 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
type mockSealer struct{}
func (s *mockSealer) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return append(src, bytes.Repeat([]byte{0}, 12)...)
}
func (s *mockSealer) Overhead() int { return 12 }
var _ handshake.Sealer = &mockSealer{}
type mockCryptoSetup struct { type mockCryptoSetup struct {
handleErr error handleErr error
divNonce []byte divNonce []byte
@ -20,6 +30,8 @@ type mockCryptoSetup struct {
encLevelSealCrypto protocol.EncryptionLevel encLevelSealCrypto protocol.EncryptionLevel
} }
var _ handshake.CryptoSetup = &mockCryptoSetup{}
func (m *mockCryptoSetup) HandleCryptoStream() error { func (m *mockCryptoSetup) HandleCryptoStream() error {
return m.handleErr return m.handleErr
} }
@ -27,25 +39,17 @@ func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumb
return nil, protocol.EncryptionUnspecified, nil return nil, protocol.EncryptionUnspecified, nil
} }
func (m *mockCryptoSetup) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) { func (m *mockCryptoSetup) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) {
return m.encLevelSeal, func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { return m.encLevelSeal, &mockSealer{}
return append(src, bytes.Repeat([]byte{0}, 12)...)
}
} }
func (m *mockCryptoSetup) GetSealerForCryptoStream() (protocol.EncryptionLevel, handshake.Sealer) { func (m *mockCryptoSetup) GetSealerForCryptoStream() (protocol.EncryptionLevel, handshake.Sealer) {
return m.encLevelSealCrypto, func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { return m.encLevelSealCrypto, &mockSealer{}
return append(src, bytes.Repeat([]byte{0}, 12)...)
}
} }
func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error) { func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error) {
return func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { return &mockSealer{}, nil
return append(src, bytes.Repeat([]byte{0}, 12)...)
}, nil
} }
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce } func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce }
func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce } func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce }
var _ handshake.CryptoSetup = &mockCryptoSetup{}
var _ = Describe("Packet packer", func() { var _ = Describe("Packet packer", func() {
var ( var (
packer *packetPacker packer *packetPacker
@ -75,7 +79,7 @@ var _ = Describe("Packet packer", func() {
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
} }
publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number
maxFrameSize = protocol.MaxFrameAndPublicHeaderSize - publicHeaderLen maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen
packer.version = protocol.VersionWhatever packer.version = protocol.VersionWhatever
}) })
@ -261,7 +265,7 @@ var _ = Describe("Packet packer", func() {
f := &wire.AckFrame{LargestAcked: 1} f := &wire.AckFrame{LargestAcked: 1}
b := &bytes.Buffer{} b := &bytes.Buffer{}
f.Write(b, protocol.VersionWhatever) f.Write(b, protocol.VersionWhatever)
maxFramesPerPacket := int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen) / b.Len() maxFramesPerPacket := int(maxFrameSize) / b.Len()
var controlFrames []wire.Frame var controlFrames []wire.Frame
for i := 0; i < maxFramesPerPacket; i++ { for i := 0; i < maxFramesPerPacket; i++ {
controlFrames = append(controlFrames, f) controlFrames = append(controlFrames, f)
@ -280,7 +284,7 @@ var _ = Describe("Packet packer", func() {
StreamID: 0x1337, StreamID: 0x1337,
} }
minLength, _ := blockedFrame.MinLength(0) minLength, _ := blockedFrame.MinLength(0)
maxFramesPerPacket := int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen) / int(minLength) maxFramesPerPacket := int(maxFrameSize) / int(minLength)
var controlFrames []wire.Frame var controlFrames []wire.Frame
for i := 0; i < maxFramesPerPacket+10; i++ { for i := 0; i < maxFramesPerPacket+10; i++ {
controlFrames = append(controlFrames, blockedFrame) controlFrames = append(controlFrames, blockedFrame)
@ -333,7 +337,7 @@ var _ = Describe("Packet packer", func() {
}) })
It("correctly handles a stream frame with one byte less than maximum size", func() { It("correctly handles a stream frame with one byte less than maximum size", func() {
maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLen - (1 + 1 + 2) - 1 maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2) - 1
f1 := &wire.StreamFrame{ f1 := &wire.StreamFrame{
StreamID: 5, StreamID: 5,
Offset: 1, Offset: 1,
@ -395,7 +399,7 @@ var _ = Describe("Packet packer", func() {
Offset: 1, Offset: 1,
} }
minLength, _ := f.MinLength(0) minLength, _ := f.MinLength(0)
maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLen - minLength maxStreamFrameDataLen := maxFrameSize - minLength
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200) f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200)
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
@ -414,7 +418,7 @@ var _ = Describe("Packet packer", func() {
}) })
It("packs 2 stream frames that are too big for one packet correctly", func() { It("packs 2 stream frames that are too big for one packet correctly", func() {
maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLen - (1 + 1 + 2) maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2)
f1 := &wire.StreamFrame{ f1 := &wire.StreamFrame{
StreamID: 5, StreamID: 5,
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100), Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100),
@ -454,7 +458,7 @@ var _ = Describe("Packet packer", func() {
Offset: 1, Offset: 1,
} }
minLength, _ := f.MinLength(0) minLength, _ := f.MinLength(0)
f.Data = bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -468,7 +472,7 @@ var _ = Describe("Packet packer", func() {
Offset: 1, Offset: 1,
} }
minLength, _ := f.MinLength(0) minLength, _ := f.MinLength(0)
f.Data = bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen-minLength+2)) // + 2 since MinceLength is 1 bigger than the actual StreamFrame header f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-minLength+2)) // + 2 since MinceLength is 1 bigger than the actual StreamFrame header
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)

View file

@ -218,7 +218,7 @@ var _ = Describe("Stream Framer", func() {
} }
framer.AddFrameForRetransmission(origFrame) framer.AddFrameForRetransmission(origFrame)
minFrameDataLen := protocol.MaxFrameAndPublicHeaderSize minFrameDataLen := protocol.MaxPacketSize
for i := 0; i < 30; i++ { for i := 0; i < 30; i++ {
frames, currentLen := framer.maybePopFramesForRetransmission(protocol.ByteCount(i)) frames, currentLen := framer.maybePopFramesForRetransmission(protocol.ByteCount(i))