use *receivedPacket thoughout the session

This commit is contained in:
Lucas Clemente 2016-09-07 14:10:32 +02:00
parent e3c90c181a
commit 099545521f
5 changed files with 33 additions and 29 deletions

View file

@ -43,7 +43,7 @@ func newLinkedConnection(other *Session) *linkedConnection {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
hdr.Raw = packet[:len(packet)-r.Len()] hdr.Raw = packet[:len(packet)-r.Len()]
conn.other.handlePacket(nil, hdr, packet[len(packet)-r.Len():]) conn.other.handlePacket(&receivedPacket{publicHeader: hdr, data: packet[len(packet)-r.Len():]})
} }
}() }()
return conn return conn

View file

@ -16,7 +16,7 @@ import (
// packetHandler handles packets // packetHandler handles packets
type packetHandler interface { type packetHandler interface {
handlePacket(addr interface{}, hdr *PublicHeader, data []byte) handlePacket(*receivedPacket)
run() run()
Close(error) error Close(error) error
} }
@ -171,7 +171,11 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
// Late packet for closed session // Late packet for closed session
return nil return nil
} }
session.handlePacket(remoteAddr, hdr, packet[len(packet)-r.Len():]) session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
publicHeader: hdr,
data: packet[len(packet)-r.Len():],
})
return nil return nil
} }

View file

@ -20,7 +20,7 @@ type mockSession struct {
closed bool closed bool
} }
func (s *mockSession) handlePacket(addr interface{}, hdr *PublicHeader, data []byte) { func (s *mockSession) handlePacket(*receivedPacket) {
s.packetCount++ s.packetCount++
} }

View file

@ -61,14 +61,14 @@ type Session struct {
cryptoSetup *handshake.CryptoSetup cryptoSetup *handshake.CryptoSetup
receivedPackets chan receivedPacket receivedPackets chan *receivedPacket
sendingScheduled chan struct{} sendingScheduled chan struct{}
// closeChan is used to notify the run loop that it should terminate. // closeChan is used to notify the run loop that it should terminate.
// If the value is not nil, the error is sent as a CONNECTION_CLOSE. // If the value is not nil, the error is sent as a CONNECTION_CLOSE.
closeChan chan *qerr.QuicError closeChan chan *qerr.QuicError
closed uint32 // atomic bool closed uint32 // atomic bool
undecryptablePackets []receivedPacket undecryptablePackets []*receivedPacket
aeadChanged chan struct{} aeadChanged chan struct{}
delayedAckOriginTime time.Time delayedAckOriginTime time.Time
@ -107,11 +107,11 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
sentPacketHandler: sentPacketHandler, sentPacketHandler: sentPacketHandler,
receivedPacketHandler: receivedPacketHandler, receivedPacketHandler: receivedPacketHandler,
flowControlManager: flowControlManager, flowControlManager: flowControlManager,
receivedPackets: make(chan receivedPacket, protocol.MaxSessionUnprocessedPackets), receivedPackets: make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets),
closeChan: make(chan *qerr.QuicError, 1), closeChan: make(chan *qerr.QuicError, 1),
sendingScheduled: make(chan struct{}, 1), sendingScheduled: make(chan struct{}, 1),
connectionParametersManager: connectionParametersManager, connectionParametersManager: connectionParametersManager,
undecryptablePackets: make([]receivedPacket, 0, protocol.MaxUndecryptablePackets), undecryptablePackets: make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets),
aeadChanged: make(chan struct{}, 1), aeadChanged: make(chan struct{}, 1),
timer: time.NewTimer(0), timer: time.NewTimer(0),
lastNetworkActivityTime: time.Now(), lastNetworkActivityTime: time.Now(),
@ -170,7 +170,7 @@ func (s *Session) run() {
// We do all the interesting stuff after the switch statement, so // We do all the interesting stuff after the switch statement, so
// nothing to see here. // nothing to see here.
case p := <-s.receivedPackets: case p := <-s.receivedPackets:
err = s.handlePacketImpl(p.remoteAddr, p.publicHeader, p.data) err = s.handlePacketImpl(p)
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure { if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
s.tryQueueingUndecryptablePacket(p) s.tryQueueingUndecryptablePacket(p)
continue continue
@ -225,8 +225,10 @@ func (s *Session) maybeResetTimer() {
s.currentDeadline = nextDeadline s.currentDeadline = nextDeadline
} }
func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *PublicHeader, data []byte) error { func (s *Session) handlePacketImpl(p *receivedPacket) error {
s.lastNetworkActivityTime = time.Now() s.lastNetworkActivityTime = time.Now()
hdr := p.publicHeader
data := p.data
// Calculate packet number // Calculate packet number
hdr.PacketNumber = protocol.InferPacketNumber( hdr.PacketNumber = protocol.InferPacketNumber(
@ -239,7 +241,7 @@ func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *PublicHeader, da
} }
// TODO: Only do this after authenticating // TODO: Only do this after authenticating
s.conn.setCurrentRemoteAddr(remoteAddr) s.conn.setCurrentRemoteAddr(p.remoteAddr)
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
if err != nil { if err != nil {
@ -312,12 +314,12 @@ func (s *Session) handleFrames(fs []frames.Frame) error {
return nil return nil
} }
// handlePacket handles a packet // handlePacket is called by the server with a new packet
func (s *Session) handlePacket(remoteAddr interface{}, hdr *PublicHeader, data []byte) { func (s *Session) handlePacket(p *receivedPacket) {
// Discard packets once the amount of queued packets is larger than // Discard packets once the amount of queued packets is larger than
// the channel size, protocol.MaxSessionUnprocessedPackets // the channel size, protocol.MaxSessionUnprocessedPackets
select { select {
case s.receivedPackets <- receivedPacket{remoteAddr: remoteAddr, publicHeader: hdr, data: data}: case s.receivedPackets <- p:
default: default:
} }
} }
@ -611,7 +613,7 @@ func (s *Session) scheduleSending() {
} }
} }
func (s *Session) tryQueueingUndecryptablePacket(p receivedPacket) { func (s *Session) tryQueueingUndecryptablePacket(p *receivedPacket) {
if s.cryptoSetup.HandshakeComplete() { if s.cryptoSetup.HandshakeComplete() {
return return
} }
@ -624,7 +626,7 @@ func (s *Session) tryQueueingUndecryptablePacket(p receivedPacket) {
func (s *Session) tryDecryptingQueuedPackets() { func (s *Session) tryDecryptingQueuedPackets() {
for _, p := range s.undecryptablePackets { for _, p := range s.undecryptablePackets {
s.handlePacket(p.remoteAddr, p.publicHeader, p.data) s.handlePacket(p)
} }
s.undecryptablePackets = s.undecryptablePackets[:0] s.undecryptablePackets = s.undecryptablePackets[:0]
} }

View file

@ -459,7 +459,7 @@ var _ = Describe("Session", func() {
It("sets the {last,largest}RcvdPacketNumber", func() { It("sets the {last,largest}RcvdPacketNumber", func() {
hdr.PacketNumber = 5 hdr.PacketNumber = 5
err := session.handlePacketImpl(nil, hdr, nil) err := session.handlePacketImpl(&receivedPacket{publicHeader: hdr})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(session.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) Expect(session.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5)))
Expect(session.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) Expect(session.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5)))
@ -467,12 +467,12 @@ var _ = Describe("Session", func() {
It("sets the {last,largest}RcvdPacketNumber, for an out-of-order packet", func() { It("sets the {last,largest}RcvdPacketNumber, for an out-of-order packet", func() {
hdr.PacketNumber = 5 hdr.PacketNumber = 5
err := session.handlePacketImpl(nil, hdr, nil) err := session.handlePacketImpl(&receivedPacket{publicHeader: hdr})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(session.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) Expect(session.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5)))
Expect(session.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) Expect(session.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5)))
hdr.PacketNumber = 3 hdr.PacketNumber = 3
err = session.handlePacketImpl(nil, hdr, nil) err = session.handlePacketImpl(&receivedPacket{publicHeader: hdr})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(session.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(3))) Expect(session.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(3)))
Expect(session.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) Expect(session.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5)))
@ -480,9 +480,9 @@ var _ = Describe("Session", func() {
It("ignores duplicate packets", func() { It("ignores duplicate packets", func() {
hdr.PacketNumber = 5 hdr.PacketNumber = 5
err := session.handlePacketImpl(nil, hdr, nil) err := session.handlePacketImpl(&receivedPacket{publicHeader: hdr})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = session.handlePacketImpl(nil, hdr, nil) err = session.handlePacketImpl(&receivedPacket{publicHeader: hdr})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
@ -490,7 +490,7 @@ var _ = Describe("Session", func() {
err := session.receivedPacketHandler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10}) err := session.receivedPacketHandler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
hdr.PacketNumber = 5 hdr.PacketNumber = 5
err = session.handlePacketImpl(nil, hdr, nil) err = session.handlePacketImpl(&receivedPacket{publicHeader: hdr})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
}) })
@ -717,7 +717,7 @@ var _ = Describe("Session", func() {
hdr := &PublicHeader{ hdr := &PublicHeader{
PacketNumber: protocol.PacketNumber(i + 1), PacketNumber: protocol.PacketNumber(i + 1),
} }
session.handlePacket(nil, hdr, []byte("foobar")) session.handlePacket(&receivedPacket{publicHeader: hdr, data: []byte("foobar")})
} }
session.run() session.run()
@ -731,7 +731,7 @@ var _ = Describe("Session", func() {
hdr := &PublicHeader{ hdr := &PublicHeader{
PacketNumber: protocol.PacketNumber(i + 1), PacketNumber: protocol.PacketNumber(i + 1),
} }
session.handlePacket(nil, hdr, []byte("foobar")) session.handlePacket(&receivedPacket{publicHeader: hdr, data: []byte("foobar")})
} }
go session.run() go session.run()
Consistently(session.undecryptablePackets).Should(HaveLen(0)) Consistently(session.undecryptablePackets).Should(HaveLen(0))
@ -739,10 +739,8 @@ var _ = Describe("Session", func() {
}) })
It("unqueues undecryptable packets for later decryption", func() { It("unqueues undecryptable packets for later decryption", func() {
session.undecryptablePackets = []receivedPacket{{ session.undecryptablePackets = []*receivedPacket{{
nil, publicHeader: &PublicHeader{PacketNumber: protocol.PacketNumber(42)},
&PublicHeader{PacketNumber: protocol.PacketNumber(42)},
nil,
}} }}
Expect(session.receivedPackets).NotTo(Receive()) Expect(session.receivedPackets).NotTo(Receive())
session.tryDecryptingQueuedPackets() session.tryDecryptingQueuedPackets()
@ -775,7 +773,7 @@ var _ = Describe("Session", func() {
It("stores up to MaxSessionUnprocessedPackets packets", func(done Done) { It("stores up to MaxSessionUnprocessedPackets packets", func(done Done) {
// Nothing here should block // Nothing here should block
for i := protocol.PacketNumber(0); i < protocol.MaxSessionUnprocessedPackets+10; i++ { for i := protocol.PacketNumber(0); i < protocol.MaxSessionUnprocessedPackets+10; i++ {
session.handlePacket(nil, nil, nil) session.handlePacket(&receivedPacket{})
} }
close(done) close(done)
}, 0.5) }, 0.5)