diff --git a/packet_packer.go b/packet_packer.go new file mode 100644 index 00000000..46eb91e6 --- /dev/null +++ b/packet_packer.go @@ -0,0 +1,120 @@ +package quic + +import ( + "bytes" + "sync" + "sync/atomic" + + "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/utils" +) + +type packedPacket struct { + number protocol.PacketNumber + entropyBit bool + raw []byte + payload []byte +} + +type packetPacker struct { + connectionID protocol.ConnectionID + aead crypto.AEAD + + queuedFrames []frames.Frame + mutex sync.Mutex + + lastPacketNumber protocol.PacketNumber +} + +func (p *packetPacker) AddFrame(f frames.Frame) { + p.mutex.Lock() + p.queuedFrames = append(p.queuedFrames, f) + p.mutex.Unlock() +} + +func (p *packetPacker) PackPacket() (*packedPacket, error) { + p.mutex.Lock() + defer p.mutex.Unlock() // TODO: Split up? + + if len(p.queuedFrames) == 0 { + return nil, nil + } + + payload, err := p.composeNextPayload() + if err != nil { + return nil, err + } + + entropyBit, err := utils.RandomBit() + if err != nil { + return nil, err + } + if entropyBit { + payload[0] = 1 + } + + currentPacketNumber := protocol.PacketNumber(atomic.AddUint64( + (*uint64)(&p.lastPacketNumber), + 1, + )) + var raw bytes.Buffer + responsePublicHeader := PublicHeader{ + ConnectionID: p.connectionID, + PacketNumber: currentPacketNumber, + } + if err := responsePublicHeader.WritePublicHeader(&raw); err != nil { + return nil, err + } + + ciphertext := p.aead.Seal(p.lastPacketNumber, raw.Bytes(), payload) + raw.Write(ciphertext) + + if raw.Len() > protocol.MaxPacketSize { + panic("internal inconsistency: packet too large") + } + + return &packedPacket{ + number: currentPacketNumber, + entropyBit: entropyBit, + raw: raw.Bytes(), + payload: payload[1:], + }, nil +} + +func (p *packetPacker) composeNextPayload() ([]byte, error) { + var payload bytes.Buffer + payload.WriteByte(0) // The entropy bit is set in sendPayload + + for len(p.queuedFrames) > 0 { + frame := p.queuedFrames[0] + + if payload.Len()-1 > protocol.MaxFrameSize { + panic("internal inconsistency: packet payload too large") + } + + // Does the frame fit into the remaining space? + if payload.Len()-1+frame.MaxLength() > protocol.MaxFrameSize { + return payload.Bytes(), nil + } + + if streamframe, isStreamFrame := frame.(*frames.StreamFrame); isStreamFrame { + // Split stream frames if necessary + previousFrame := streamframe.MaybeSplitOffFrame(protocol.MaxFrameSize - (payload.Len() - 1)) + if previousFrame != nil { + // Don't pop the queue, leave the modified frame in + frame = previousFrame + } else { + p.queuedFrames = p.queuedFrames[1:] + } + } else { + p.queuedFrames = p.queuedFrames[1:] + } + + if err := frame.Write(&payload); err != nil { + return nil, err + } + } + return payload.Bytes(), nil +} diff --git a/packet_packer_test.go b/packet_packer_test.go new file mode 100644 index 00000000..548a5d6a --- /dev/null +++ b/packet_packer_test.go @@ -0,0 +1,91 @@ +package quic + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Packet packer", func() { + var ( + packer *packetPacker + ) + + BeforeEach(func() { + aead := &crypto.NullAEAD{} + packer = &packetPacker{aead: aead} + }) + + It("returns nil when no packet is queued", func() { + p, err := packer.PackPacket() + Expect(p).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("packs single packets", func() { + f := &frames.AckFrame{} + packer.AddFrame(f) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + b := &bytes.Buffer{} + f.Write(b) + Expect(p.payload).To(Equal(b.Bytes())) + Expect(p.raw).To(ContainSubstring(string(b.Bytes()))) + }) + + It("packs multiple frames into single packet", func() { + f1 := &frames.AckFrame{LargestObserved: 1} + f2 := &frames.AckFrame{LargestObserved: 2} + packer.AddFrame(f1) + packer.AddFrame(f2) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + b := &bytes.Buffer{} + f1.Write(b) + f2.Write(b) + Expect(p.payload).To(Equal(b.Bytes())) + Expect(p.raw).To(ContainSubstring(string(b.Bytes()))) + }) + + It("packs many normal frames into 2 packets", func() { + f := &frames.AckFrame{LargestObserved: 1} + b := &bytes.Buffer{} + f.Write(b) + for i := 0; i <= (protocol.MaxFrameSize-1)/b.Len()+1; i++ { + packer.AddFrame(f) + } + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(len(p.payload) % b.Len()).To(BeZero()) + Expect(p.raw).To(ContainSubstring(string(b.Bytes()))) + p, err = packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.payload).To(Equal(b.Bytes())) + Expect(p.raw).To(ContainSubstring(string(b.Bytes()))) + }) + + It("splits stream frames", func() { + f := &frames.StreamFrame{ + Data: bytes.Repeat([]byte{'f'}, protocol.MaxFrameSize), + Offset: 1, + } + b := &bytes.Buffer{} + f.Write(b) + packer.AddFrame(f) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(len(p.raw)).To(Equal(protocol.MaxPacketSize)) + Expect(err).ToNot(HaveOccurred()) + p, err = packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + }) +}) diff --git a/protocol/protocol.go b/protocol/protocol.go index d666f43a..4bb0dddc 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -11,3 +11,9 @@ type StreamID uint32 // An ErrorCode in QUIC type ErrorCode uint32 + +// MaxPacketSize is the maximum packet size, including the public header +const MaxPacketSize = 1452 + +// MaxFrameSize is the maximum size of a QUIC frame +const MaxFrameSize = MaxPacketSize - (1 + 8 + 6) /*public header*/ - 1 /*private header*/ - 12 /*crypto signature*/ diff --git a/session.go b/session.go index dd7eec5b..0c61519d 100644 --- a/session.go +++ b/session.go @@ -13,7 +13,6 @@ import ( "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/utils" ) // StreamCallback gets a stream frame and returns a reply frame @@ -24,60 +23,46 @@ type Session struct { VersionNumber protocol.VersionNumber ConnectionID protocol.ConnectionID + streamCallback StreamCallback + Connection *net.UDPConn CurrentRemoteAddr *net.UDPAddr ServerConfig *handshake.ServerConfig cryptoSetup *handshake.CryptoSetup - EntropyReceived ackhandler.EntropyAccumulator - EntropySent ackhandler.EntropyAccumulator - EntropyHistory map[protocol.PacketNumber]ackhandler.EntropyAccumulator // ToDo: store this with the packet itself - entropyHistoryMutex sync.Mutex - - lastSentPacketNumber protocol.PacketNumber - lastObservedPacketNumber protocol.PacketNumber - Streams map[protocol.StreamID]*Stream streamsMutex sync.RWMutex - AckQueue []*frames.AckFrame + outgoingAckHandler ackhandler.OutgoingPacketAckHandler + incomingAckHandler ackhandler.IncomingPacketAckHandler - streamCallback StreamCallback + packer *packetPacker + batchMode bool } // NewSession makes a new session func NewSession(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) *Session { session := &Session{ - Connection: conn, - VersionNumber: v, - ConnectionID: connectionID, - ServerConfig: sCfg, - streamCallback: streamCallback, - lastObservedPacketNumber: 0, - Streams: make(map[protocol.StreamID]*Stream), - EntropyHistory: make(map[protocol.PacketNumber]ackhandler.EntropyAccumulator), + Connection: conn, + VersionNumber: v, + ConnectionID: connectionID, + ServerConfig: sCfg, + streamCallback: streamCallback, + Streams: make(map[protocol.StreamID]*Stream), } cryptoStream, _ := session.NewStream(1) session.cryptoSetup = handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream) go session.cryptoSetup.HandleCryptoStream() + session.packer = &packetPacker{aead: session.cryptoSetup} + return session } // HandlePacket handles a packet func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, publicHeader *PublicHeader, r *bytes.Reader) error { - if s.lastObservedPacketNumber > 0 { // the first packet doesn't neccessarily need to have packetNumber 1 - if publicHeader.PacketNumber < s.lastObservedPacketNumber || publicHeader.PacketNumber > s.lastObservedPacketNumber+1 { - return errors.New("Out of order packet") - } - if publicHeader.PacketNumber == s.lastObservedPacketNumber { - return errors.New("Duplicate packet") - } - } - s.lastObservedPacketNumber = publicHeader.PacketNumber - // TODO: Only do this after authenticating if addr != s.CurrentRemoteAddr { s.CurrentRemoteAddr = addr @@ -94,12 +79,11 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub if err != nil { return err } - s.EntropyReceived.Add(publicHeader.PacketNumber, privateFlag&0x01 > 0) - s.queueAck(&frames.AckFrame{ - LargestObserved: publicHeader.PacketNumber, - Entropy: s.EntropyReceived.Get(), - }) + s.incomingAckHandler.ReceivedPacket(publicHeader.PacketNumber, privateFlag&0x01 > 0) + + s.batchMode = true + s.SendFrame(s.incomingAckHandler.DequeueAckFrame()) // read all frames in the packet for r.Len() > 0 { @@ -147,7 +131,9 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub } } } - return nil + + s.batchMode = false + return s.sendPackets() } func (s *Session) handleStreamFrame(r *bytes.Reader) error { @@ -156,11 +142,9 @@ func (s *Session) handleStreamFrame(r *bytes.Reader) error { return err } fmt.Printf("Got %d bytes for stream %d\n", len(frame.Data), frame.StreamID) - if frame.StreamID == 0 { return errors.New("Session: 0 is not a valid Stream ID") } - s.streamsMutex.RLock() stream, newStream := s.Streams[frame.StreamID] s.streamsMutex.RUnlock() @@ -172,7 +156,6 @@ func (s *Session) handleStreamFrame(r *bytes.Reader) error { if err != nil { return err } - if !newStream { s.streamCallback(s, stream) } @@ -184,26 +167,10 @@ func (s *Session) handleAckFrame(r *bytes.Reader) error { if err != nil { return err } - - s.entropyHistoryMutex.Lock() - defer s.entropyHistoryMutex.Unlock() - expectedEntropy, ok := s.EntropyHistory[frame.LargestObserved] - if !ok { - return errors.New("No entropy value saved for received ACK packet") - } - - if byte(expectedEntropy) != frame.Entropy { - return errors.New("Incorrect entropy value in ACK package") - } - - delete(s.EntropyHistory, frame.LargestObserved) + s.outgoingAckHandler.ReceivedAck(frame) return nil } -func (s *Session) queueAck(f *frames.AckFrame) { - s.AckQueue = append(s.AckQueue, f) -} - func (s *Session) handleConnectionCloseFrame(r *bytes.Reader) error { fmt.Println("Detected CONNECTION_CLOSE") frame, err := frames.ParseConnectionCloseFrame(r) @@ -215,10 +182,11 @@ func (s *Session) handleConnectionCloseFrame(r *bytes.Reader) error { } func (s *Session) handleStopWaitingFrame(r *bytes.Reader, publicHeader *PublicHeader) error { - _, err := frames.ParseStopWaitingFrame(r, publicHeader.PacketNumberLen) + frame, err := frames.ParseStopWaitingFrame(r, publicHeader.PacketNumberLen) if err != nil { return err } + fmt.Printf("%#v\n", frame) return nil } @@ -235,83 +203,44 @@ func (s *Session) handleRstStreamFrame(r *bytes.Reader) error { func (s *Session) Close(e error) error { errorCode := protocol.ErrorCode(1) reasonPhrase := e.Error() - quicError, ok := e.(*protocol.QuicError) if ok { errorCode = quicError.ErrorCode } - frame := &frames.ConnectionCloseFrame{ + return s.SendFrame(&frames.ConnectionCloseFrame{ ErrorCode: errorCode, ReasonPhrase: reasonPhrase, + }) +} + +func (s *Session) sendPackets() error { + for { + packet, err := s.packer.PackPacket() + if err != nil { + return err + } + if packet == nil { + return nil + } + s.outgoingAckHandler.SentPacket(&ackhandler.Packet{ + PacketNumber: packet.number, + Plaintext: packet.payload, + EntropyBit: packet.entropyBit, + }) + _, err = s.Connection.WriteToUDP(packet.raw, s.CurrentRemoteAddr) + if err != nil { + return err + } } - return s.SendFrame(frame) } // SendFrame sends a frame to the client func (s *Session) SendFrame(frame frames.Frame) error { - streamframe, ok := frame.(*frames.StreamFrame) - if ok { - maxlength := 1000 - if len(streamframe.Data) > maxlength { - frame1 := &frames.StreamFrame{ - StreamID: streamframe.StreamID, - Offset: streamframe.Offset, - Data: streamframe.Data[:maxlength], - } - frame2 := &frames.StreamFrame{ - StreamID: streamframe.StreamID, - Offset: streamframe.Offset + uint64(maxlength), - Data: streamframe.Data[maxlength:], - FinBit: streamframe.FinBit, - } - err := s.SendFrame(frame1) - if err != nil { - return err - } - return s.SendFrame(frame2) - } + s.packer.AddFrame(frame) + if s.batchMode { + return nil } - - var framesData bytes.Buffer - entropyBit, err := utils.RandomBit() - if err != nil { - return err - } - if entropyBit { - framesData.WriteByte(1) - } else { - framesData.WriteByte(0) - } - - // add all outstanding ACKs - for _, ackFrame := range s.AckQueue { - ackFrame.Write(&framesData) - } - s.AckQueue = s.AckQueue[:0] - - if err := frame.Write(&framesData); err != nil { - return err - } - - s.lastSentPacketNumber++ - - var fullReply bytes.Buffer - packetNumber := s.lastSentPacketNumber - responsePublicHeader := PublicHeader{ConnectionID: s.ConnectionID, PacketNumber: packetNumber} - if err := responsePublicHeader.WritePublicHeader(&fullReply); err != nil { - return err - } - s.EntropySent.Add(packetNumber, entropyBit) - s.entropyHistoryMutex.Lock() - defer s.entropyHistoryMutex.Unlock() - s.EntropyHistory[packetNumber] = s.EntropySent - - ciphertext := s.cryptoSetup.Seal(s.lastSentPacketNumber, fullReply.Bytes(), framesData.Bytes()) - fullReply.Write(ciphertext) - - fmt.Printf("-> Sending packet %d (%d bytes) to %v\n", responsePublicHeader.PacketNumber, len(fullReply.Bytes()), s.CurrentRemoteAddr) - _, err = s.Connection.WriteToUDP(fullReply.Bytes(), s.CurrentRemoteAddr) - return err + return s.sendPackets() } // NewStream creates a new strean open for reading and writing