From e345270e848a2333541c39da80b4fd3ed00fe62d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 10 May 2016 23:40:22 +0700 Subject: [PATCH] use ByteCount type for Frame, Packet and PublicHeader lengths --- frames/ack_frame.go | 4 +-- frames/ack_frame_test.go | 4 +-- frames/blocked_frame.go | 2 +- frames/blocked_frame_test.go | 2 +- frames/connection_close_frame.go | 4 +-- frames/connection_close_frame_test.go | 2 +- frames/frame.go | 2 +- frames/ping_frame.go | 2 +- frames/ping_frame_test.go | 2 +- frames/rst_stream_frame.go | 2 +- frames/stop_waiting_frame.go | 2 +- frames/stream_frame.go | 8 +++--- frames/stream_frame_test.go | 2 +- frames/window_update_frame.go | 2 +- packet_packer.go | 12 ++++---- packet_packer_test.go | 40 +++++++++++++-------------- protocol/protocol.go | 2 +- public_header.go | 6 ++-- public_header_test.go | 6 ++-- 19 files changed, 53 insertions(+), 53 deletions(-) diff --git a/frames/ack_frame.go b/frames/ack_frame.go index a271a938..97266010 100644 --- a/frames/ack_frame.go +++ b/frames/ack_frame.go @@ -94,13 +94,13 @@ func (f *AckFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, pa } // MinLength of a written frame -func (f *AckFrame) MinLength() int { +func (f *AckFrame) MinLength() protocol.ByteCount { l := 1 + 1 + 6 + 2 + 1 + 1 + 4 l += (1 + 2) * 0 /* TODO: num_timestamps */ if f.HasNACK() { l += 1 + (6+1)*len(f.NackRanges) } - return l + return protocol.ByteCount(l) } // HasNACK returns if the frame has NACK ranges diff --git a/frames/ack_frame_test.go b/frames/ack_frame_test.go index be09121a..ac583ebe 100644 --- a/frames/ack_frame_test.go +++ b/frames/ack_frame_test.go @@ -404,7 +404,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 1, } f.Write(b, 1, protocol.PacketNumberLen6, 32) - Expect(f.MinLength()).To(Equal(b.Len())) + Expect(f.MinLength()).To(Equal(protocol.ByteCount(b.Len()))) }) It("has proper min length with nack ranges", func() { @@ -415,7 +415,7 @@ var _ = Describe("AckFrame", func() { } err := f.Write(b, 1, protocol.PacketNumberLen6, 32) Expect(err).ToNot(HaveOccurred()) - Expect(f.MinLength()).To(Equal(b.Len())) + Expect(f.MinLength()).To(Equal(protocol.ByteCount(b.Len()))) }) }) }) diff --git a/frames/blocked_frame.go b/frames/blocked_frame.go index c3bc906c..f3769683 100644 --- a/frames/blocked_frame.go +++ b/frames/blocked_frame.go @@ -26,7 +26,7 @@ func (f *BlockedFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber } // MinLength of a written frame -func (f *BlockedFrame) MinLength() int { +func (f *BlockedFrame) MinLength() protocol.ByteCount { return 1 + 4 } diff --git a/frames/blocked_frame_test.go b/frames/blocked_frame_test.go index c1f5db4a..c0ea23c3 100644 --- a/frames/blocked_frame_test.go +++ b/frames/blocked_frame_test.go @@ -28,7 +28,7 @@ var _ = Describe("BlockedFrame", func() { It("has the correct min length", func() { frame := BlockedFrame{StreamID: 3} - Expect(frame.MinLength()).To(Equal(5)) + Expect(frame.MinLength()).To(Equal(protocol.ByteCount(5))) }) }) }) diff --git a/frames/connection_close_frame.go b/frames/connection_close_frame.go index 7de367d9..b65d6d6d 100644 --- a/frames/connection_close_frame.go +++ b/frames/connection_close_frame.go @@ -47,8 +47,8 @@ func ParseConnectionCloseFrame(r *bytes.Reader) (*ConnectionCloseFrame, error) { } // MinLength of a written frame -func (f *ConnectionCloseFrame) MinLength() int { - return 1 + 4 + 2 + len(f.ReasonPhrase) +func (f *ConnectionCloseFrame) MinLength() protocol.ByteCount { + return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)) } // Write writes an CONNECTION_CLOSE frame. diff --git a/frames/connection_close_frame_test.go b/frames/connection_close_frame_test.go index 68825ec9..dd326789 100644 --- a/frames/connection_close_frame_test.go +++ b/frames/connection_close_frame_test.go @@ -78,7 +78,7 @@ var _ = Describe("ConnectionCloseFrame", func() { ReasonPhrase: "foobar", } f.Write(b, 1, protocol.PacketNumberLen6, 0) - Expect(f.MinLength()).To(Equal(b.Len())) + Expect(f.MinLength()).To(Equal(protocol.ByteCount(b.Len()))) }) }) diff --git a/frames/frame.go b/frames/frame.go index 077ec86a..7bf2c0e3 100644 --- a/frames/frame.go +++ b/frames/frame.go @@ -9,5 +9,5 @@ import ( // A Frame in QUIC type Frame interface { Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen, version protocol.VersionNumber) error - MinLength() int + MinLength() protocol.ByteCount } diff --git a/frames/ping_frame.go b/frames/ping_frame.go index e860c5b7..ed69be85 100644 --- a/frames/ping_frame.go +++ b/frames/ping_frame.go @@ -28,6 +28,6 @@ func (f *PingFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, p } // MinLength of a written frame -func (f *PingFrame) MinLength() int { +func (f *PingFrame) MinLength() protocol.ByteCount { return 1 } diff --git a/frames/ping_frame_test.go b/frames/ping_frame_test.go index 8e706657..dae243c8 100644 --- a/frames/ping_frame_test.go +++ b/frames/ping_frame_test.go @@ -28,7 +28,7 @@ var _ = Describe("PingFrame", func() { It("has the correct min length", func() { frame := PingFrame{} - Expect(frame.MinLength()).To(Equal(1)) + Expect(frame.MinLength()).To(Equal(protocol.ByteCount(1))) }) }) }) diff --git a/frames/rst_stream_frame.go b/frames/rst_stream_frame.go index 5e4cfb5e..4b6d35e0 100644 --- a/frames/rst_stream_frame.go +++ b/frames/rst_stream_frame.go @@ -20,7 +20,7 @@ func (f *RstStreamFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumb } // MinLength of a written frame -func (f *RstStreamFrame) MinLength() int { +func (f *RstStreamFrame) MinLength() protocol.ByteCount { panic("RstStreamFrame: Write not yet implemented") } diff --git a/frames/stop_waiting_frame.go b/frames/stop_waiting_frame.go index a8ea83db..01c4e786 100644 --- a/frames/stop_waiting_frame.go +++ b/frames/stop_waiting_frame.go @@ -31,7 +31,7 @@ func (f *StopWaitingFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNu } // MinLength of a written frame -func (f *StopWaitingFrame) MinLength() int { +func (f *StopWaitingFrame) MinLength() protocol.ByteCount { return 1 + 1 + 6 } diff --git a/frames/stream_frame.go b/frames/stream_frame.go index 19c7d33d..8f5d9323 100644 --- a/frames/stream_frame.go +++ b/frames/stream_frame.go @@ -91,20 +91,20 @@ func (f *StreamFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, } // MinLength of a written frame -func (f *StreamFrame) MinLength() int { +func (f *StreamFrame) MinLength() protocol.ByteCount { return 1 + 4 + 8 + 2 + 1 } // MaybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(n), nil is returned and nothing is modified. -func (f *StreamFrame) MaybeSplitOffFrame(n int) *StreamFrame { - if n >= f.MinLength()-1+len(f.Data) { +func (f *StreamFrame) MaybeSplitOffFrame(n protocol.ByteCount) *StreamFrame { + if n >= f.MinLength()-1+protocol.ByteCount(len(f.Data)) { return nil } n -= f.MinLength() - 1 defer func() { f.Data = f.Data[n:] - f.Offset += protocol.ByteCount(n) + f.Offset += n }() return &StreamFrame{ diff --git a/frames/stream_frame_test.go b/frames/stream_frame_test.go index b17fa9da..1ed7790d 100644 --- a/frames/stream_frame_test.go +++ b/frames/stream_frame_test.go @@ -59,7 +59,7 @@ var _ = Describe("StreamFrame", func() { Offset: 1, } f.Write(b, 1, protocol.PacketNumberLen6, 0) - Expect(f.MinLength()).To(Equal(b.Len())) + Expect(f.MinLength()).To(Equal(protocol.ByteCount(b.Len()))) }) }) diff --git a/frames/window_update_frame.go b/frames/window_update_frame.go index 8212fe87..18df65e8 100644 --- a/frames/window_update_frame.go +++ b/frames/window_update_frame.go @@ -19,7 +19,7 @@ func (f *WindowUpdateFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketN } // MinLength of a written frame -func (f *WindowUpdateFrame) MinLength() int { +func (f *WindowUpdateFrame) MinLength() protocol.ByteCount { panic("WindowUpdateFrame: Write not yet implemented") } diff --git a/packet_packer.go b/packet_packer.go index 03bc28ba..d4263f78 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -91,7 +91,7 @@ func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, con ciphertext := p.aead.Seal(currentPacketNumber, raw.Bytes(), payload) raw.Write(ciphertext) - if raw.Len() > protocol.MaxPacketSize { + if protocol.ByteCount(raw.Len()) > protocol.MaxPacketSize { panic("internal inconsistency: packet too large") } @@ -120,8 +120,8 @@ func (p *packetPacker) getPayload(frames []frames.Frame, currentPacketNumber pro // // } -func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, publicHeaderLength uint8, includeStreamFrames bool) ([]frames.Frame, error) { - payloadLength := 0 +func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, publicHeaderLength protocol.ByteCount, includeStreamFrames bool) ([]frames.Frame, error) { + payloadLength := protocol.ByteCount(0) var payloadFrames []frames.Frame // TODO: handle the case where there are more controlFrames than we can put into one packet @@ -137,7 +137,7 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra controlFrames = controlFrames[1:] } - maxFrameSize := protocol.MaxFrameAndPublicHeaderSize - int(publicHeaderLength) + maxFrameSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength if payloadLength > maxFrameSize { panic("internal inconsistency: packet payload too large") @@ -164,10 +164,10 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra if previousFrame != nil { // Don't pop the queue, leave the modified frame in frame = previousFrame - payloadLength += len(previousFrame.Data) - 1 + payloadLength += protocol.ByteCount(len(previousFrame.Data)) - 1 } else { p.streamFrameQueue.Pop() - payloadLength += len(frame.Data) - 1 + payloadLength += protocol.ByteCount(len(frame.Data)) - 1 } payloadLength += frame.MinLength() diff --git a/packet_packer_test.go b/packet_packer_test.go index 25c53116..55f427e8 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -110,11 +110,11 @@ var _ = Describe("Packet packer", func() { }) It("packs many control frames into 1 packets", func() { - publicHeaderLength := uint8(10) + publicHeaderLength := protocol.ByteCount(10) f := &frames.AckFrame{LargestObserved: 1} b := &bytes.Buffer{} f.Write(b, 3, protocol.PacketNumberLen6, 32) - maxFramesPerPacket := (protocol.MaxFrameAndPublicHeaderSize - int(publicHeaderLength)) / b.Len() + maxFramesPerPacket := int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLength) / b.Len() var controlFrames []frames.Frame for i := 0; i < maxFramesPerPacket; i++ { controlFrames = append(controlFrames, f) @@ -150,10 +150,10 @@ var _ = Describe("Packet packer", func() { Context("Stream Frame handling", func() { It("does not splits a stream frame with maximum size", func() { - publicHeaderLength := uint8(12) - maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - int(publicHeaderLength) - (1 + 4 + 8 + 2) + publicHeaderLength := protocol.ByteCount(12) + maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength - (1 + 4 + 8 + 2) f := frames.StreamFrame{ - Data: bytes.Repeat([]byte{'f'}, maxStreamFrameDataLen), + Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)), Offset: 1, } packer.AddStreamFrame(f) @@ -187,17 +187,17 @@ var _ = Describe("Packet packer", func() { }) It("splits one stream frame larger than maximum size", func() { - publicHeaderLength := uint8(5) - maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - int(publicHeaderLength) - (1 + 4 + 8 + 2) + publicHeaderLength := protocol.ByteCount(5) + maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength - (1 + 4 + 8 + 2) f := frames.StreamFrame{ - Data: bytes.Repeat([]byte{'f'}, maxStreamFrameDataLen+200), + Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200), Offset: 1, } packer.AddStreamFrame(f) payloadFrames, err := packer.composeNextPacket(nil, []frames.Frame{}, publicHeaderLength, true) Expect(err).ToNot(HaveOccurred()) Expect(len(payloadFrames)).To(Equal(1)) - Expect(len(payloadFrames[0].(*frames.StreamFrame).Data)).To(Equal(maxStreamFrameDataLen)) + Expect(protocol.ByteCount(len(payloadFrames[0].(*frames.StreamFrame).Data))).To(Equal(maxStreamFrameDataLen)) payloadFrames, err = packer.composeNextPacket(nil, []frames.Frame{}, publicHeaderLength, true) Expect(err).ToNot(HaveOccurred()) Expect(len(payloadFrames)).To(Equal(1)) @@ -208,24 +208,24 @@ var _ = Describe("Packet packer", func() { }) It("packs 2 stream frames that are too big for one packet correctly", func() { - publicHeaderLength := uint8(5) - maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - int(publicHeaderLength) - (1 + 4 + 8 + 2) + publicHeaderLength := protocol.ByteCount(5) + maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength - (1 + 4 + 8 + 2) f1 := frames.StreamFrame{ - Data: bytes.Repeat([]byte{'f'}, maxStreamFrameDataLen+100), + Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100), Offset: 1, } f2 := frames.StreamFrame{ - Data: bytes.Repeat([]byte{'f'}, maxStreamFrameDataLen+100), + Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100), Offset: 1, } packer.AddStreamFrame(f1) packer.AddStreamFrame(f2) p, err := packer.PackPacket(nil, []frames.Frame{}, true) Expect(err).ToNot(HaveOccurred()) - Expect(len(p.raw)).To(Equal(protocol.MaxPacketSize)) + Expect(protocol.ByteCount(len(p.raw))).To(Equal(protocol.MaxPacketSize)) p, err = packer.PackPacket(nil, []frames.Frame{}, true) Expect(err).ToNot(HaveOccurred()) - Expect(len(p.raw)).To(Equal(protocol.MaxPacketSize)) + Expect(protocol.ByteCount(len(p.raw))).To(Equal(protocol.MaxPacketSize)) p, err = packer.PackPacket(nil, []frames.Frame{}, true) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -235,22 +235,22 @@ var _ = Describe("Packet packer", func() { }) It("packs a packet that has the maximum packet size when given a large enough stream frame", func() { - publicHeaderLength := uint8(3) + publicHeaderLength := protocol.ByteCount(3) f := frames.StreamFrame{ - Data: bytes.Repeat([]byte{'f'}, protocol.MaxFrameAndPublicHeaderSize-int(publicHeaderLength)-(1+4+8+2)), + Data: bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLength-(1+4+8+2))), Offset: 1, } packer.AddStreamFrame(f) p, err := packer.PackPacket(nil, []frames.Frame{}, true) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(len(p.raw)).To(Equal(protocol.MaxPacketSize)) + Expect(protocol.ByteCount(len(p.raw))).To(Equal(protocol.MaxPacketSize)) }) It("splits a stream frame larger than the maximum size", func() { - publicHeaderLength := uint8(13) + publicHeaderLength := protocol.ByteCount(13) f := frames.StreamFrame{ - Data: bytes.Repeat([]byte{'f'}, protocol.MaxFrameAndPublicHeaderSize-int(publicHeaderLength)-(1+4+8+2)+1), + Data: bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLength-(1+4+8+2)+1)), Offset: 1, } packer.AddStreamFrame(f) diff --git a/protocol/protocol.go b/protocol/protocol.go index 205a94fd..a3773e8b 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -30,7 +30,7 @@ type ByteCount uint64 type ErrorCode uint32 // MaxPacketSize is the maximum packet size, including the public header -const MaxPacketSize = 1452 +const MaxPacketSize ByteCount = 1452 // MaxFrameAndPublicHeaderSize is the maximum size of a QUIC frame plus PublicHeader const MaxFrameAndPublicHeaderSize = MaxPacketSize - 1 /*private header*/ - 12 /*crypto signature*/ diff --git a/public_header.go b/public_header.go index ee214388..ec6e1579 100644 --- a/public_header.go +++ b/public_header.go @@ -147,18 +147,18 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { // GetLength gets the length of the PublicHeader in bytes // can only be called for regular packets -func (h *PublicHeader) GetLength() (uint8, error) { +func (h *PublicHeader) GetLength() (protocol.ByteCount, error) { if h.VersionFlag || h.ResetFlag { return 0, errGetLengthOnlyForRegularPackets } - length := uint8(1) // 1 byte for public flags + length := protocol.ByteCount(1) // 1 byte for public flags if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { return 0, errPacketNumberLenNotSet } if !h.TruncateConnectionID { length += 8 // 8 bytes for the connection ID } - length += uint8(h.PacketNumberLen) + length += protocol.ByteCount(h.PacketNumberLen) return length, nil } diff --git a/public_header_test.go b/public_header_test.go index 80890704..8a2c3445 100644 --- a/public_header_test.go +++ b/public_header_test.go @@ -181,7 +181,7 @@ var _ = Describe("Public Header", func() { } length, err := publicHeader.GetLength() Expect(err).ToNot(HaveOccurred()) - Expect(length).To(Equal(uint8(1 + 8 + 6))) // 1 byte public flag, 8 bytes connectionID, and packet number + Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 6))) // 1 byte public flag, 8 bytes connectionID, and packet number }) It("gets the length of a packet with longest packet number length and truncated connectionID", func() { @@ -193,7 +193,7 @@ var _ = Describe("Public Header", func() { } length, err := publicHeader.GetLength() Expect(err).ToNot(HaveOccurred()) - Expect(length).To(Equal(uint8(1 + 6))) // 1 byte public flag, and packet number + Expect(length).To(Equal(protocol.ByteCount(1 + 6))) // 1 byte public flag, and packet number }) It("gets the length of a packet 2 byte packet number length ", func() { @@ -204,7 +204,7 @@ var _ = Describe("Public Header", func() { } length, err := publicHeader.GetLength() Expect(err).ToNot(HaveOccurred()) - Expect(length).To(Equal(uint8(1 + 8 + 2))) // 1 byte public flag, 8 byte connectionID, and packet number + Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 2))) // 1 byte public flag, 8 byte connectionID, and packet number }) })