add methods to add and remove reset tokens to the packet handler map

This commit is contained in:
Marten Seemann 2019-03-05 14:16:07 +09:00
parent 733dcb75eb
commit dd8c590b13
2 changed files with 28 additions and 53 deletions

View file

@ -11,11 +11,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)
type packetHandlerEntry struct {
handler packetHandler
resetToken *[16]byte
}
// The packetHandlerMap stores packetHandlers, identified by connection ID.
// It is used:
// * by the server to store sessions
@ -26,7 +21,7 @@ type packetHandlerMap struct {
conn net.PacketConn
connIDLen int
handlers map[string] /* string(ConnectionID)*/ packetHandlerEntry
handlers map[string] /* string(ConnectionID)*/ packetHandler
resetTokens map[[16]byte] /* stateless reset token */ packetHandler
server unknownPacketHandler
@ -45,7 +40,7 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger
conn: conn,
connIDLen: connIDLen,
listening: make(chan struct{}),
handlers: make(map[string]packetHandlerEntry),
handlers: make(map[string]packetHandler),
resetTokens: make(map[[16]byte]packetHandler),
deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout,
logger: logger,
@ -56,14 +51,7 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
h.mutex.Lock()
h.handlers[string(id)] = packetHandlerEntry{handler: handler}
h.mutex.Unlock()
}
func (h *packetHandlerMap) AddWithResetToken(id protocol.ConnectionID, handler packetHandler, token [16]byte) {
h.mutex.Lock()
h.handlers[string(id)] = packetHandlerEntry{handler: handler, resetToken: &token}
h.resetTokens[token] = handler
h.handlers[string(id)] = handler
h.mutex.Unlock()
}
@ -73,12 +61,7 @@ func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
h.mutex.Lock()
if handlerEntry, ok := h.handlers[id]; ok {
if token := handlerEntry.resetToken; token != nil {
delete(h.resetTokens, *token)
}
delete(h.handlers, id)
}
delete(h.handlers, id)
h.mutex.Unlock()
}
@ -92,6 +75,18 @@ func (h *packetHandlerMap) retireByConnectionIDAsString(id string) {
})
}
func (h *packetHandlerMap) AddResetToken(token [16]byte, handler packetHandler) {
h.mutex.Lock()
h.resetTokens[token] = handler
h.mutex.Unlock()
}
func (h *packetHandlerMap) RemoveResetToken(token [16]byte) {
h.mutex.Lock()
delete(h.resetTokens, token)
h.mutex.Unlock()
}
func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
h.mutex.Lock()
h.server = s
@ -102,8 +97,7 @@ func (h *packetHandlerMap) CloseServer() {
h.mutex.Lock()
h.server = nil
var wg sync.WaitGroup
for id, handlerEntry := range h.handlers {
handler := handlerEntry.handler
for id, handler := range h.handlers {
if handler.getPerspective() == protocol.PerspectiveServer {
wg.Add(1)
go func(id string, handler packetHandler) {
@ -136,12 +130,12 @@ func (h *packetHandlerMap) close(e error) error {
h.closed = true
var wg sync.WaitGroup
for _, handlerEntry := range h.handlers {
for _, handler := range h.handlers {
wg.Add(1)
go func(handlerEntry packetHandlerEntry) {
handlerEntry.handler.destroy(e)
go func(handler packetHandler) {
handler.destroy(e)
wg.Done()
}(handlerEntry)
}(handler)
}
if h.server != nil {
@ -187,7 +181,7 @@ func (h *packetHandlerMap) handlePacket(
return
}
handlerEntry, handlerFound := h.handlers[string(connID)]
handler, handlerFound := h.handlers[string(connID)]
p := &receivedPacket{
remoteAddr: addr,
@ -196,7 +190,7 @@ func (h *packetHandlerMap) handlePacket(
data: data,
}
if handlerFound { // existing session
handlerEntry.handler.handlePacket(p)
handler.handlePacket(p)
return
}
if data[0]&0x80 == 0 {

View file

@ -163,28 +163,10 @@ var _ = Describe("Packet Handler Map", func() {
})
Context("stateless reset handling", func() {
It("handles packets for connections added with a reset token", func() {
packetHandler := NewMockPacketHandler(mockCtrl)
connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddWithResetToken(connID, packetHandler, token)
// first send a normal packet
handledPacket := make(chan struct{})
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
cid, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(cid).To(Equal(connID))
close(handledPacket)
})
conn.dataToRead <- getPacket(connID)
Eventually(handledPacket).Should(BeClosed())
})
It("handles stateless resets", func() {
packetHandler := NewMockPacketHandler(mockCtrl)
connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddWithResetToken(connID, packetHandler, token)
handler.AddResetToken(token, packetHandler)
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
destroyed := make(chan struct{})
@ -199,7 +181,7 @@ var _ = Describe("Packet Handler Map", func() {
handler.connIDLen = 0
packetHandler := NewMockPacketHandler(mockCtrl)
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddWithResetToken(protocol.ConnectionID{}, packetHandler, token)
handler.AddResetToken(token, packetHandler)
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
destroyed := make(chan struct{})
@ -210,13 +192,12 @@ var _ = Describe("Packet Handler Map", func() {
Eventually(destroyed).Should(BeClosed())
})
It("deletes reset tokens when the session is retired", func() {
It("deletes reset tokens", func() {
handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond)
connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42}
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token)
handler.Retire(connID)
time.Sleep(scaleDuration(30 * time.Millisecond))
handler.AddResetToken(token, NewMockPacketHandler(mockCtrl))
handler.RemoveResetToken(token)
handler.handlePacket(nil, nil, getPacket(connID))
// don't EXPECT any calls to handlePacket of the MockPacketHandler
packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...)