mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
simplify session by moving packet packing to separate class
This commit is contained in:
parent
e5559d37d3
commit
3b2d0efea5
4 changed files with 267 additions and 121 deletions
120
packet_packer.go
Normal file
120
packet_packer.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
)
|
||||
|
||||
type packedPacket struct {
|
||||
number protocol.PacketNumber
|
||||
entropyBit bool
|
||||
raw []byte
|
||||
payload []byte
|
||||
}
|
||||
|
||||
type packetPacker struct {
|
||||
connectionID protocol.ConnectionID
|
||||
aead crypto.AEAD
|
||||
|
||||
queuedFrames []frames.Frame
|
||||
mutex sync.Mutex
|
||||
|
||||
lastPacketNumber protocol.PacketNumber
|
||||
}
|
||||
|
||||
func (p *packetPacker) AddFrame(f frames.Frame) {
|
||||
p.mutex.Lock()
|
||||
p.queuedFrames = append(p.queuedFrames, f)
|
||||
p.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (p *packetPacker) PackPacket() (*packedPacket, error) {
|
||||
p.mutex.Lock()
|
||||
defer p.mutex.Unlock() // TODO: Split up?
|
||||
|
||||
if len(p.queuedFrames) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
payload, err := p.composeNextPayload()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entropyBit, err := utils.RandomBit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entropyBit {
|
||||
payload[0] = 1
|
||||
}
|
||||
|
||||
currentPacketNumber := protocol.PacketNumber(atomic.AddUint64(
|
||||
(*uint64)(&p.lastPacketNumber),
|
||||
1,
|
||||
))
|
||||
var raw bytes.Buffer
|
||||
responsePublicHeader := PublicHeader{
|
||||
ConnectionID: p.connectionID,
|
||||
PacketNumber: currentPacketNumber,
|
||||
}
|
||||
if err := responsePublicHeader.WritePublicHeader(&raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext := p.aead.Seal(p.lastPacketNumber, raw.Bytes(), payload)
|
||||
raw.Write(ciphertext)
|
||||
|
||||
if raw.Len() > protocol.MaxPacketSize {
|
||||
panic("internal inconsistency: packet too large")
|
||||
}
|
||||
|
||||
return &packedPacket{
|
||||
number: currentPacketNumber,
|
||||
entropyBit: entropyBit,
|
||||
raw: raw.Bytes(),
|
||||
payload: payload[1:],
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *packetPacker) composeNextPayload() ([]byte, error) {
|
||||
var payload bytes.Buffer
|
||||
payload.WriteByte(0) // The entropy bit is set in sendPayload
|
||||
|
||||
for len(p.queuedFrames) > 0 {
|
||||
frame := p.queuedFrames[0]
|
||||
|
||||
if payload.Len()-1 > protocol.MaxFrameSize {
|
||||
panic("internal inconsistency: packet payload too large")
|
||||
}
|
||||
|
||||
// Does the frame fit into the remaining space?
|
||||
if payload.Len()-1+frame.MaxLength() > protocol.MaxFrameSize {
|
||||
return payload.Bytes(), nil
|
||||
}
|
||||
|
||||
if streamframe, isStreamFrame := frame.(*frames.StreamFrame); isStreamFrame {
|
||||
// Split stream frames if necessary
|
||||
previousFrame := streamframe.MaybeSplitOffFrame(protocol.MaxFrameSize - (payload.Len() - 1))
|
||||
if previousFrame != nil {
|
||||
// Don't pop the queue, leave the modified frame in
|
||||
frame = previousFrame
|
||||
} else {
|
||||
p.queuedFrames = p.queuedFrames[1:]
|
||||
}
|
||||
} else {
|
||||
p.queuedFrames = p.queuedFrames[1:]
|
||||
}
|
||||
|
||||
if err := frame.Write(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return payload.Bytes(), nil
|
||||
}
|
91
packet_packer_test.go
Normal file
91
packet_packer_test.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Packet packer", func() {
|
||||
var (
|
||||
packer *packetPacker
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
aead := &crypto.NullAEAD{}
|
||||
packer = &packetPacker{aead: aead}
|
||||
})
|
||||
|
||||
It("returns nil when no packet is queued", func() {
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("packs single packets", func() {
|
||||
f := &frames.AckFrame{}
|
||||
packer.AddFrame(f)
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := &bytes.Buffer{}
|
||||
f.Write(b)
|
||||
Expect(p.payload).To(Equal(b.Bytes()))
|
||||
Expect(p.raw).To(ContainSubstring(string(b.Bytes())))
|
||||
})
|
||||
|
||||
It("packs multiple frames into single packet", func() {
|
||||
f1 := &frames.AckFrame{LargestObserved: 1}
|
||||
f2 := &frames.AckFrame{LargestObserved: 2}
|
||||
packer.AddFrame(f1)
|
||||
packer.AddFrame(f2)
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := &bytes.Buffer{}
|
||||
f1.Write(b)
|
||||
f2.Write(b)
|
||||
Expect(p.payload).To(Equal(b.Bytes()))
|
||||
Expect(p.raw).To(ContainSubstring(string(b.Bytes())))
|
||||
})
|
||||
|
||||
It("packs many normal frames into 2 packets", func() {
|
||||
f := &frames.AckFrame{LargestObserved: 1}
|
||||
b := &bytes.Buffer{}
|
||||
f.Write(b)
|
||||
for i := 0; i <= (protocol.MaxFrameSize-1)/b.Len()+1; i++ {
|
||||
packer.AddFrame(f)
|
||||
}
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(p.payload) % b.Len()).To(BeZero())
|
||||
Expect(p.raw).To(ContainSubstring(string(b.Bytes())))
|
||||
p, err = packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.payload).To(Equal(b.Bytes()))
|
||||
Expect(p.raw).To(ContainSubstring(string(b.Bytes())))
|
||||
})
|
||||
|
||||
It("splits stream frames", func() {
|
||||
f := &frames.StreamFrame{
|
||||
Data: bytes.Repeat([]byte{'f'}, protocol.MaxFrameSize),
|
||||
Offset: 1,
|
||||
}
|
||||
b := &bytes.Buffer{}
|
||||
f.Write(b)
|
||||
packer.AddFrame(f)
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(len(p.raw)).To(Equal(protocol.MaxPacketSize))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
p, err = packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
|
@ -11,3 +11,9 @@ type StreamID uint32
|
|||
|
||||
// An ErrorCode in QUIC
|
||||
type ErrorCode uint32
|
||||
|
||||
// MaxPacketSize is the maximum packet size, including the public header
|
||||
const MaxPacketSize = 1452
|
||||
|
||||
// MaxFrameSize is the maximum size of a QUIC frame
|
||||
const MaxFrameSize = MaxPacketSize - (1 + 8 + 6) /*public header*/ - 1 /*private header*/ - 12 /*crypto signature*/
|
||||
|
|
171
session.go
171
session.go
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/handshake"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
)
|
||||
|
||||
// StreamCallback gets a stream frame and returns a reply frame
|
||||
|
@ -24,60 +23,46 @@ type Session struct {
|
|||
VersionNumber protocol.VersionNumber
|
||||
ConnectionID protocol.ConnectionID
|
||||
|
||||
streamCallback StreamCallback
|
||||
|
||||
Connection *net.UDPConn
|
||||
CurrentRemoteAddr *net.UDPAddr
|
||||
|
||||
ServerConfig *handshake.ServerConfig
|
||||
cryptoSetup *handshake.CryptoSetup
|
||||
|
||||
EntropyReceived ackhandler.EntropyAccumulator
|
||||
EntropySent ackhandler.EntropyAccumulator
|
||||
EntropyHistory map[protocol.PacketNumber]ackhandler.EntropyAccumulator // ToDo: store this with the packet itself
|
||||
entropyHistoryMutex sync.Mutex
|
||||
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
lastObservedPacketNumber protocol.PacketNumber
|
||||
|
||||
Streams map[protocol.StreamID]*Stream
|
||||
streamsMutex sync.RWMutex
|
||||
|
||||
AckQueue []*frames.AckFrame
|
||||
outgoingAckHandler ackhandler.OutgoingPacketAckHandler
|
||||
incomingAckHandler ackhandler.IncomingPacketAckHandler
|
||||
|
||||
streamCallback StreamCallback
|
||||
packer *packetPacker
|
||||
batchMode bool
|
||||
}
|
||||
|
||||
// NewSession makes a new session
|
||||
func NewSession(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) *Session {
|
||||
session := &Session{
|
||||
Connection: conn,
|
||||
VersionNumber: v,
|
||||
ConnectionID: connectionID,
|
||||
ServerConfig: sCfg,
|
||||
streamCallback: streamCallback,
|
||||
lastObservedPacketNumber: 0,
|
||||
Streams: make(map[protocol.StreamID]*Stream),
|
||||
EntropyHistory: make(map[protocol.PacketNumber]ackhandler.EntropyAccumulator),
|
||||
Connection: conn,
|
||||
VersionNumber: v,
|
||||
ConnectionID: connectionID,
|
||||
ServerConfig: sCfg,
|
||||
streamCallback: streamCallback,
|
||||
Streams: make(map[protocol.StreamID]*Stream),
|
||||
}
|
||||
|
||||
cryptoStream, _ := session.NewStream(1)
|
||||
session.cryptoSetup = handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream)
|
||||
go session.cryptoSetup.HandleCryptoStream()
|
||||
|
||||
session.packer = &packetPacker{aead: session.cryptoSetup}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
// HandlePacket handles a packet
|
||||
func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, publicHeader *PublicHeader, r *bytes.Reader) error {
|
||||
if s.lastObservedPacketNumber > 0 { // the first packet doesn't neccessarily need to have packetNumber 1
|
||||
if publicHeader.PacketNumber < s.lastObservedPacketNumber || publicHeader.PacketNumber > s.lastObservedPacketNumber+1 {
|
||||
return errors.New("Out of order packet")
|
||||
}
|
||||
if publicHeader.PacketNumber == s.lastObservedPacketNumber {
|
||||
return errors.New("Duplicate packet")
|
||||
}
|
||||
}
|
||||
s.lastObservedPacketNumber = publicHeader.PacketNumber
|
||||
|
||||
// TODO: Only do this after authenticating
|
||||
if addr != s.CurrentRemoteAddr {
|
||||
s.CurrentRemoteAddr = addr
|
||||
|
@ -94,12 +79,11 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.EntropyReceived.Add(publicHeader.PacketNumber, privateFlag&0x01 > 0)
|
||||
|
||||
s.queueAck(&frames.AckFrame{
|
||||
LargestObserved: publicHeader.PacketNumber,
|
||||
Entropy: s.EntropyReceived.Get(),
|
||||
})
|
||||
s.incomingAckHandler.ReceivedPacket(publicHeader.PacketNumber, privateFlag&0x01 > 0)
|
||||
|
||||
s.batchMode = true
|
||||
s.SendFrame(s.incomingAckHandler.DequeueAckFrame())
|
||||
|
||||
// read all frames in the packet
|
||||
for r.Len() > 0 {
|
||||
|
@ -147,7 +131,9 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub
|
|||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
s.batchMode = false
|
||||
return s.sendPackets()
|
||||
}
|
||||
|
||||
func (s *Session) handleStreamFrame(r *bytes.Reader) error {
|
||||
|
@ -156,11 +142,9 @@ func (s *Session) handleStreamFrame(r *bytes.Reader) error {
|
|||
return err
|
||||
}
|
||||
fmt.Printf("Got %d bytes for stream %d\n", len(frame.Data), frame.StreamID)
|
||||
|
||||
if frame.StreamID == 0 {
|
||||
return errors.New("Session: 0 is not a valid Stream ID")
|
||||
}
|
||||
|
||||
s.streamsMutex.RLock()
|
||||
stream, newStream := s.Streams[frame.StreamID]
|
||||
s.streamsMutex.RUnlock()
|
||||
|
@ -172,7 +156,6 @@ func (s *Session) handleStreamFrame(r *bytes.Reader) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !newStream {
|
||||
s.streamCallback(s, stream)
|
||||
}
|
||||
|
@ -184,26 +167,10 @@ func (s *Session) handleAckFrame(r *bytes.Reader) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.entropyHistoryMutex.Lock()
|
||||
defer s.entropyHistoryMutex.Unlock()
|
||||
expectedEntropy, ok := s.EntropyHistory[frame.LargestObserved]
|
||||
if !ok {
|
||||
return errors.New("No entropy value saved for received ACK packet")
|
||||
}
|
||||
|
||||
if byte(expectedEntropy) != frame.Entropy {
|
||||
return errors.New("Incorrect entropy value in ACK package")
|
||||
}
|
||||
|
||||
delete(s.EntropyHistory, frame.LargestObserved)
|
||||
s.outgoingAckHandler.ReceivedAck(frame)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) queueAck(f *frames.AckFrame) {
|
||||
s.AckQueue = append(s.AckQueue, f)
|
||||
}
|
||||
|
||||
func (s *Session) handleConnectionCloseFrame(r *bytes.Reader) error {
|
||||
fmt.Println("Detected CONNECTION_CLOSE")
|
||||
frame, err := frames.ParseConnectionCloseFrame(r)
|
||||
|
@ -215,10 +182,11 @@ func (s *Session) handleConnectionCloseFrame(r *bytes.Reader) error {
|
|||
}
|
||||
|
||||
func (s *Session) handleStopWaitingFrame(r *bytes.Reader, publicHeader *PublicHeader) error {
|
||||
_, err := frames.ParseStopWaitingFrame(r, publicHeader.PacketNumberLen)
|
||||
frame, err := frames.ParseStopWaitingFrame(r, publicHeader.PacketNumberLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("%#v\n", frame)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -235,83 +203,44 @@ func (s *Session) handleRstStreamFrame(r *bytes.Reader) error {
|
|||
func (s *Session) Close(e error) error {
|
||||
errorCode := protocol.ErrorCode(1)
|
||||
reasonPhrase := e.Error()
|
||||
|
||||
quicError, ok := e.(*protocol.QuicError)
|
||||
if ok {
|
||||
errorCode = quicError.ErrorCode
|
||||
}
|
||||
frame := &frames.ConnectionCloseFrame{
|
||||
return s.SendFrame(&frames.ConnectionCloseFrame{
|
||||
ErrorCode: errorCode,
|
||||
ReasonPhrase: reasonPhrase,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) sendPackets() error {
|
||||
for {
|
||||
packet, err := s.packer.PackPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if packet == nil {
|
||||
return nil
|
||||
}
|
||||
s.outgoingAckHandler.SentPacket(&ackhandler.Packet{
|
||||
PacketNumber: packet.number,
|
||||
Plaintext: packet.payload,
|
||||
EntropyBit: packet.entropyBit,
|
||||
})
|
||||
_, err = s.Connection.WriteToUDP(packet.raw, s.CurrentRemoteAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return s.SendFrame(frame)
|
||||
}
|
||||
|
||||
// SendFrame sends a frame to the client
|
||||
func (s *Session) SendFrame(frame frames.Frame) error {
|
||||
streamframe, ok := frame.(*frames.StreamFrame)
|
||||
if ok {
|
||||
maxlength := 1000
|
||||
if len(streamframe.Data) > maxlength {
|
||||
frame1 := &frames.StreamFrame{
|
||||
StreamID: streamframe.StreamID,
|
||||
Offset: streamframe.Offset,
|
||||
Data: streamframe.Data[:maxlength],
|
||||
}
|
||||
frame2 := &frames.StreamFrame{
|
||||
StreamID: streamframe.StreamID,
|
||||
Offset: streamframe.Offset + uint64(maxlength),
|
||||
Data: streamframe.Data[maxlength:],
|
||||
FinBit: streamframe.FinBit,
|
||||
}
|
||||
err := s.SendFrame(frame1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.SendFrame(frame2)
|
||||
}
|
||||
s.packer.AddFrame(frame)
|
||||
if s.batchMode {
|
||||
return nil
|
||||
}
|
||||
|
||||
var framesData bytes.Buffer
|
||||
entropyBit, err := utils.RandomBit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if entropyBit {
|
||||
framesData.WriteByte(1)
|
||||
} else {
|
||||
framesData.WriteByte(0)
|
||||
}
|
||||
|
||||
// add all outstanding ACKs
|
||||
for _, ackFrame := range s.AckQueue {
|
||||
ackFrame.Write(&framesData)
|
||||
}
|
||||
s.AckQueue = s.AckQueue[:0]
|
||||
|
||||
if err := frame.Write(&framesData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.lastSentPacketNumber++
|
||||
|
||||
var fullReply bytes.Buffer
|
||||
packetNumber := s.lastSentPacketNumber
|
||||
responsePublicHeader := PublicHeader{ConnectionID: s.ConnectionID, PacketNumber: packetNumber}
|
||||
if err := responsePublicHeader.WritePublicHeader(&fullReply); err != nil {
|
||||
return err
|
||||
}
|
||||
s.EntropySent.Add(packetNumber, entropyBit)
|
||||
s.entropyHistoryMutex.Lock()
|
||||
defer s.entropyHistoryMutex.Unlock()
|
||||
s.EntropyHistory[packetNumber] = s.EntropySent
|
||||
|
||||
ciphertext := s.cryptoSetup.Seal(s.lastSentPacketNumber, fullReply.Bytes(), framesData.Bytes())
|
||||
fullReply.Write(ciphertext)
|
||||
|
||||
fmt.Printf("-> Sending packet %d (%d bytes) to %v\n", responsePublicHeader.PacketNumber, len(fullReply.Bytes()), s.CurrentRemoteAddr)
|
||||
_, err = s.Connection.WriteToUDP(fullReply.Bytes(), s.CurrentRemoteAddr)
|
||||
return err
|
||||
return s.sendPackets()
|
||||
}
|
||||
|
||||
// NewStream creates a new strean open for reading and writing
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue