diff --git a/internal/protocol/server_parameters.go b/internal/protocol/server_parameters.go index e89d5a9a..0b295f24 100644 --- a/internal/protocol/server_parameters.go +++ b/internal/protocol/server_parameters.go @@ -2,9 +2,11 @@ package protocol import "time" -// MaxPacketSize is the maximum packet size that we use for sending packets. -// It includes the QUIC packet header, but excludes the UDP and IP header. -const MaxPacketSize ByteCount = 1200 +// MaxPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets. +const MaxPacketSizeIPv4 = 1252 + +// MaxPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets. +const MaxPacketSizeIPv6 = 1232 // NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet // This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames diff --git a/internal/wire/goaway_frame.go b/internal/wire/goaway_frame.go index fd5aca92..8a91e502 100644 --- a/internal/wire/goaway_frame.go +++ b/internal/wire/goaway_frame.go @@ -41,7 +41,7 @@ func ParseGoawayFrame(r *bytes.Reader, _ protocol.VersionNumber) (*GoawayFrame, return nil, err } - if reasonPhraseLen > uint16(protocol.MaxPacketSize) { + if reasonPhraseLen > uint16(protocol.MaxReceivePacketSize) { return nil, qerr.Error(qerr.InvalidGoawayData, "reason phrase too long") } diff --git a/packet_packer.go b/packet_packer.go index 535708af..22a0adc3 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "net" "sync" "github.com/lucas-clemente/quic-go/internal/ackhandler" @@ -41,6 +42,7 @@ type packetPacker struct { stopWaiting *wire.StopWaitingFrame ackFrame *wire.AckFrame omitConnectionID bool + maxPacketSize protocol.ByteCount hasSentPacket bool // has the packetPacker already sent a packet numNonRetransmittableAcks int } @@ -48,11 +50,25 @@ type packetPacker struct { func newPacketPacker(connectionID protocol.ConnectionID, initialPacketNumber protocol.PacketNumber, getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen, + remoteAddr net.Addr, // only used for determining the max packet size cryptoSetup handshake.CryptoSetup, streamFramer streamFrameSource, perspective protocol.Perspective, version protocol.VersionNumber, ) *packetPacker { + maxPacketSize := protocol.ByteCount(protocol.MinInitialPacketSize) + // If this is not a UDP address, we don't know anything about the MTU. + // Use the minimum size of an Initial packet as the max packet size. + if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { + // If ip is not an IPv4 address, To4 returns nil. + // Note that there might be some corner cases, where this is not correct. + // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6. + if udpAddr.IP.To4() == nil { + maxPacketSize = protocol.MaxPacketSizeIPv6 + } else { + maxPacketSize = protocol.MaxPacketSizeIPv4 + } + } return &packetPacker{ cryptoSetup: cryptoSetup, connectionID: connectionID, @@ -61,6 +77,7 @@ func newPacketPacker(connectionID protocol.ConnectionID, streams: streamFramer, getPacketNumberLen: getPacketNumberLen, packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), + maxPacketSize: maxPacketSize, } } @@ -132,7 +149,7 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP if err != nil { return nil, err } - maxSize := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength + maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength // for gQUIC: add a STOP_WAITING for *every* retransmission if p.version.UsesStopWaitingFrames() { @@ -263,7 +280,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { p.stopWaiting.PacketNumberLen = header.PacketNumberLen } - maxSize := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength + maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength payloadFrames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel)) if err != nil { return nil, err @@ -312,7 +329,7 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { if err != nil { return nil, err } - maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength + maxLen := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength sf := p.streams.PopCryptoStreamFrame(maxLen) sf.DataLenPresent = false frames := []wire.Frame{sf} @@ -475,8 +492,8 @@ func (p *packetPacker) writeAndSealPacket( } } - if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > protocol.MaxPacketSize { - return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, protocol.MaxPacketSize) + if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize { + return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) } raw = raw[0:buffer.Len()] diff --git a/packet_packer_test.go b/packet_packer_test.go index 3b2e1fa5..7ff37ba6 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "math" + "net" "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/ackhandler" @@ -52,6 +53,7 @@ func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce func (m *mockCryptoSetup) ConnectionState() ConnectionState { panic("not implemented") } var _ = Describe("Packet packer", func() { + const maxPacketSize protocol.ByteCount = 1357 var ( packer *packetPacker publicHeaderLen protocol.ByteCount @@ -69,15 +71,38 @@ var _ = Describe("Packet packer", func() { 0x1337, 1, func(protocol.PacketNumber) protocol.PacketNumberLen { return protocol.PacketNumberLen2 }, + &net.TCPAddr{}, &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, mockStreamFramer, protocol.PerspectiveServer, version, ) publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number - maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen + maxFrameSize = maxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen packer.hasSentPacket = true packer.version = version + packer.maxPacketSize = maxPacketSize + }) + + Context("determining the maximum packet size", func() { + It("uses the minimum initial size, if it can't determine if the remote address is IPv4 or IPv6", func() { + remoteAddr := &net.TCPAddr{} + packer = newPacketPacker(0x1337, 1, nil, remoteAddr, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) + Expect(packer.maxPacketSize).To(BeEquivalentTo(protocol.MinInitialPacketSize)) + }) + + It("uses the maximum IPv4 packet size, if the remote address is IPv4", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(11, 12, 13, 14), Port: 1337} + packer = newPacketPacker(0x1337, 1, nil, remoteAddr, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) + Expect(packer.maxPacketSize).To(BeEquivalentTo(protocol.MaxPacketSizeIPv4)) + }) + + It("uses the maximum IPv6 packet size, if the remote address is IPv6", func() { + ip := net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334") + remoteAddr := &net.UDPAddr{IP: ip, Port: 1337} + packer = newPacketPacker(0x1337, 1, nil, remoteAddr, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) + Expect(packer.maxPacketSize).To(BeEquivalentTo(protocol.MaxPacketSizeIPv6)) + }) }) It("returns nil when no packet is queued", func() { @@ -449,7 +474,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) - Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) + Expect(p.raw).To(HaveLen(int(maxPacketSize))) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -472,7 +497,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) - Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) + Expect(p.raw).To(HaveLen(int(maxPacketSize))) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -677,7 +702,7 @@ var _ = Describe("Packet packer", func() { Frames: []wire.Frame{ &wire.StreamFrame{ StreamID: 1, - Data: bytes.Repeat([]byte{'f'}, int(protocol.MaxPacketSize-5)), + Data: bytes.Repeat([]byte{'f'}, int(maxPacketSize-5)), }, }, } @@ -771,7 +796,7 @@ var _ = Describe("Packet packer", func() { var frames []wire.Frame var totalLen protocol.ByteCount // pack a bunch of control frames, such that the packet is way bigger than a single packet - for i := 0; totalLen < protocol.MaxPacketSize*3/2; i++ { + for i := 0; totalLen < maxPacketSize*3/2; i++ { f := &wire.MaxStreamDataFrame{StreamID: protocol.StreamID(i), ByteOffset: protocol.ByteCount(i)} frames = append(frames, f) totalLen += f.Length(packer.version) @@ -789,7 +814,7 @@ var _ = Describe("Packet packer", func() { Expect(packets[1].frames[1:]).To(Equal(frames[len(packets[0].frames)-1:])) // check that the first packet was filled up as far as possible: // if the first frame (after the STOP_WAITING) was packed into the first packet, it would have overflown the MaxPacketSize - Expect(len(packets[0].raw) + int(packets[1].frames[1].Length(packer.version))).To(BeNumerically(">", protocol.MaxPacketSize)) + Expect(len(packets[0].raw) + int(packets[1].frames[1].Length(packer.version))).To(BeNumerically(">", maxPacketSize)) }) It("splits a STREAM frame that doesn't fit", func() { @@ -798,7 +823,7 @@ var _ = Describe("Packet packer", func() { Frames: []wire.Frame{&wire.StreamFrame{ StreamID: 42, Offset: 1337, - Data: bytes.Repeat([]byte{'a'}, int(protocol.MaxPacketSize)*3/2), + Data: bytes.Repeat([]byte{'a'}, int(maxPacketSize)*3/2), }}, }) Expect(err).ToNot(HaveOccurred()) @@ -815,15 +840,15 @@ var _ = Describe("Packet packer", func() { Expect(sf2.StreamID).To(Equal(protocol.StreamID(42))) Expect(sf2.Offset).To(Equal(protocol.ByteCount(1337) + sf1.DataLen())) Expect(sf2.DataLenPresent).To(BeFalse()) - Expect(sf1.DataLen() + sf2.DataLen()).To(Equal(protocol.MaxPacketSize * 3 / 2)) - Expect(packets[0].raw).To(HaveLen(int(protocol.MaxPacketSize))) + Expect(sf1.DataLen() + sf2.DataLen()).To(Equal(maxPacketSize * 3 / 2)) + Expect(packets[0].raw).To(HaveLen(int(maxPacketSize))) }) It("packs two packets for retransmission if the original packet contained many STREAM frames", func() { var frames []wire.Frame var totalLen protocol.ByteCount // pack a bunch of control frames, such that the packet is way bigger than a single packet - for i := 0; totalLen < protocol.MaxPacketSize*3/2; i++ { + for i := 0; totalLen < maxPacketSize*3/2; i++ { f := &wire.StreamFrame{ StreamID: protocol.StreamID(i), Data: []byte("foobar"), @@ -845,7 +870,7 @@ var _ = Describe("Packet packer", func() { Expect(packets[1].frames[1:]).To(Equal(frames[len(packets[0].frames)-1:])) // check that the first packet was filled up as far as possible: // if the first frame (after the STOP_WAITING) was packed into the first packet, it would have overflown the MaxPacketSize - Expect(len(packets[0].raw) + int(packets[1].frames[1].Length(packer.version))).To(BeNumerically(">", protocol.MaxPacketSize-protocol.MinStreamFrameSize)) + Expect(len(packets[0].raw) + int(packets[1].frames[1].Length(packer.version))).To(BeNumerically(">", maxPacketSize-protocol.MinStreamFrameSize)) }) It("correctly sets the DataLenPresent on STREAM frames", func() { diff --git a/session.go b/session.go index 069d0b44..d9e5bf13 100644 --- a/session.go +++ b/session.go @@ -339,6 +339,7 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { s.packer = newPacketPacker(s.connectionID, initialPacketNumber, s.sentPacketHandler.GetPacketNumberLen, + s.RemoteAddr(), s.cryptoSetup, s.streamFramer, s.perspective, @@ -1067,7 +1068,6 @@ func (s *session) LocalAddr() net.Addr { return s.conn.LocalAddr() } -// RemoteAddr returns the net.Addr of the client func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } diff --git a/session_test.go b/session_test.go index dffdeff2..019f6562 100644 --- a/session_test.go +++ b/session_test.go @@ -1051,7 +1051,7 @@ var _ = Describe("Session", func() { sess.packer.version = versionIETFFrames f := &wire.StreamFrame{ StreamID: 0x5, - Data: bytes.Repeat([]byte{'b'}, int(protocol.MaxPacketSize)*3/2), + Data: bytes.Repeat([]byte{'b'}, int(protocol.MaxPacketSizeIPv4)*3/2), } sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ Frames: []wire.Frame{f},