diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index 5dedf1f5..e0c946cb 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -9,6 +9,8 @@ import ( ) type frameParser struct { + ackDelayExponent uint8 + version protocol.VersionNumber } @@ -19,7 +21,7 @@ func NewFrameParser(v protocol.VersionNumber) FrameParser { // ParseNextFrame parses the next frame // It skips PADDING frames. -func (p *frameParser) ParseNext(r *bytes.Reader) (Frame, error) { +func (p *frameParser) ParseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel) (Frame, error) { for r.Len() != 0 { typeByte, _ := r.ReadByte() if typeByte == 0x0 { // PADDING frame @@ -27,12 +29,12 @@ func (p *frameParser) ParseNext(r *bytes.Reader) (Frame, error) { } r.UnreadByte() - return p.parseFrame(r, typeByte) + return p.parseFrame(r, typeByte, encLevel) } return nil, nil } -func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte) (Frame, error) { +func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protocol.EncryptionLevel) (Frame, error) { var frame Frame var err error if typeByte&0xf8 == 0x8 { @@ -46,7 +48,11 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte) (Frame, error) case 0x1: frame, err = parsePingFrame(r, p.version) case 0x2, 0x3: - frame, err = parseAckFrame(r, protocol.AckDelayExponent, p.version) + ackDelayExponent := p.ackDelayExponent + if encLevel != protocol.Encryption1RTT { + ackDelayExponent = protocol.DefaultAckDelayExponent + } + frame, err = parseAckFrame(r, ackDelayExponent, p.version) case 0x4: frame, err = parseResetStreamFrame(r, p.version) case 0x5: @@ -85,3 +91,7 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte) (Frame, error) } return frame, nil } + +func (p *frameParser) SetAckDelayExponent(exp uint8) { + p.ackDelayExponent = exp +} diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index adca648d..44a8987a 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "time" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" @@ -21,7 +22,7 @@ var _ = Describe("Frame parsing", func() { }) It("returns nil if there's nothing more to read", func() { - f, err := parser.ParseNext(bytes.NewReader(nil)) + f, err := parser.ParseNext(bytes.NewReader(nil), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeNil()) }) @@ -29,14 +30,14 @@ var _ = Describe("Frame parsing", func() { It("skips PADDING frames", func() { buf.Write([]byte{0}) // PADDING frame (&PingFrame{}).Write(buf, versionIETFFrames) - f, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + f, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(f).To(Equal(&PingFrame{})) }) It("handles PADDING at the end", func() { r := bytes.NewReader([]byte{0, 0, 0}) - f, err := parser.ParseNext(r) + f, err := parser.ParseNext(r, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeNil()) Expect(r.Len()).To(BeZero()) @@ -46,13 +47,39 @@ var _ = Describe("Frame parsing", func() { f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) Expect(frame.(*AckFrame).LargestAcked()).To(Equal(protocol.PacketNumber(0x13))) }) + It("uses the custom ack delay exponent for 1RTT packets", func() { + parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: time.Second, + } + Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + // The ACK frame is always written using the protocol.AckDelayExponent. + // That's why we expect a different value when parsing. + Expect(frame.(*AckFrame).DelayTime).To(Equal(4 * time.Second)) + }) + + It("uses the default ack delay exponent for non-1RTT packets", func() { + parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: time.Second, + } + Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.(*AckFrame).DelayTime).To(Equal(time.Second)) + }) + It("unpacks RESET_STREAM frames", func() { f := &ResetStreamFrame{ StreamID: 0xdeadbeef, @@ -61,7 +88,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -71,7 +98,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -83,7 +110,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -93,7 +120,7 @@ var _ = Describe("Frame parsing", func() { f := &NewTokenFrame{Token: []byte("foobar")} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -108,7 +135,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -121,7 +148,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -134,7 +161,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -147,7 +174,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -157,7 +184,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -169,7 +196,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -182,7 +209,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -195,7 +222,7 @@ var _ = Describe("Frame parsing", func() { } buf := &bytes.Buffer{} Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -204,7 +231,7 @@ var _ = Describe("Frame parsing", func() { f := &RetireConnectionIDFrame{SequenceNumber: 0x1337} buf := &bytes.Buffer{} Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -213,7 +240,7 @@ var _ = Describe("Frame parsing", func() { f := &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -224,7 +251,7 @@ var _ = Describe("Frame parsing", func() { f := &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -239,13 +266,13 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) It("errors on invalid type", func() { - _, err := parser.ParseNext(bytes.NewReader([]byte{0x42})) + _, err := parser.ParseNext(bytes.NewReader([]byte{0x42}), protocol.Encryption1RTT) Expect(err).To(MatchError("InvalidFrameData: unknown type byte 0x42")) }) @@ -256,7 +283,7 @@ var _ = Describe("Frame parsing", func() { } b := &bytes.Buffer{} f.Write(b, versionIETFFrames) - _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2])) + _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2]), protocol.Encryption1RTT) Expect(err).To(HaveOccurred()) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidFrameData)) }) diff --git a/internal/wire/interface.go b/internal/wire/interface.go index d9837a6a..99fdc80f 100644 --- a/internal/wire/interface.go +++ b/internal/wire/interface.go @@ -14,5 +14,6 @@ type Frame interface { // A FrameParser parses QUIC frames, one by one. type FrameParser interface { - ParseNext(r *bytes.Reader) (Frame, error) + ParseNext(*bytes.Reader, protocol.EncryptionLevel) (Frame, error) + SetAckDelayExponent(uint8) } diff --git a/packet_packer_test.go b/packet_packer_test.go index 55a7f193..dda434bd 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -822,7 +822,7 @@ var _ = Describe("Packet packer", func() { Expect(firstPayloadByte).To(Equal(byte(0))) // ... followed by the stream frame frameParser := wire.NewFrameParser(packer.version) - frame, err := frameParser.ParseNext(r) + frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(r.Len()).To(BeZero()) diff --git a/session.go b/session.go index 4122d670..ba79ef65 100644 --- a/session.go +++ b/session.go @@ -553,7 +553,7 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time r := bytes.NewReader(packet.data) var isRetransmittable bool for { - frame, err := s.frameParser.ParseNext(r) + frame, err := s.frameParser.ParseNext(r, packet.encryptionLevel) if err != nil { return err } @@ -814,6 +814,7 @@ func (s *session) processTransportParameters(params *handshake.TransportParamete s.peerParams = params s.streamsMap.UpdateLimits(params) s.packer.HandleTransportParameters(params) + s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.connFlowController.UpdateSendWindow(params.InitialMaxData) // the crypto stream is the only open stream at this moment // so we don't need to update stream flow control windows