send Version Negotiation packets in a separate Go routine

This commit is contained in:
Marten Seemann 2018-11-27 09:02:16 +07:00
parent 6c4116e7f4
commit 5e01c49fdf
3 changed files with 49 additions and 38 deletions

View file

@ -1,7 +1,6 @@
package quic package quic
import ( import (
"bytes"
"errors" "errors"
"net" "net"
"time" "time"
@ -10,19 +9,24 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
type mockPacketConnWrite struct {
data []byte
to net.Addr
}
type mockPacketConn struct { type mockPacketConn struct {
addr net.Addr addr net.Addr
dataToRead chan []byte dataToRead chan []byte
dataReadFrom net.Addr dataReadFrom net.Addr
readErr error readErr error
dataWritten bytes.Buffer dataWritten chan mockPacketConnWrite
dataWrittenTo net.Addr closed bool
closed bool
} }
func newMockPacketConn() *mockPacketConn { func newMockPacketConn() *mockPacketConn {
return &mockPacketConn{ return &mockPacketConn{
dataToRead: make(chan []byte, 1000), dataToRead: make(chan []byte, 1000),
dataWritten: make(chan mockPacketConnWrite, 1000),
} }
} }
@ -37,10 +41,16 @@ func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
n := copy(b, data) n := copy(b, data)
return n, c.dataReadFrom, nil return n, c.dataReadFrom, nil
} }
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
c.dataWrittenTo = addr select {
return c.dataWritten.Write(b) case c.dataWritten <- mockPacketConnWrite{to: addr, data: b}:
return len(b), nil
default:
panic("channel full")
}
} }
func (c *mockPacketConn) Close() error { func (c *mockPacketConn) Close() error {
if !c.closed { if !c.closed {
close(c.dataToRead) close(c.dataToRead)
@ -72,10 +82,11 @@ var _ = Describe("Connection", func() {
}) })
It("writes", func() { It("writes", func() {
err := c.Write([]byte("foobar")) Expect(c.Write([]byte("foobar"))).To(Succeed())
Expect(err).ToNot(HaveOccurred()) var write mockPacketConnWrite
Expect(packetConn.dataWritten.Bytes()).To(Equal([]byte("foobar"))) Expect(packetConn.dataWritten).To(Receive(&write))
Expect(packetConn.dataWrittenTo.String()).To(Equal("192.168.100.200:1337")) Expect(write.to.String()).To(Equal("192.168.100.200:1337"))
Expect(write.data).To(Equal([]byte("foobar")))
}) })
It("reads", func() { It("reads", func() {

View file

@ -299,23 +299,17 @@ func (s *server) Addr() net.Addr {
} }
func (s *server) handlePacket(p *receivedPacket) { func (s *server) handlePacket(p *receivedPacket) {
if err := s.handlePacketImpl(p); err != nil {
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
}
}
func (s *server) handlePacketImpl(p *receivedPacket) error {
hdr := p.hdr hdr := p.hdr
// send a Version Negotiation Packet if the client is speaking a different protocol version // send a Version Negotiation Packet if the client is speaking a different protocol version
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
return s.sendVersionNegotiationPacket(p) go s.sendVersionNegotiationPacket(p)
return
} }
if hdr.Type == protocol.PacketTypeInitial { if hdr.Type == protocol.PacketTypeInitial {
go s.handleInitial(p) go s.handleInitial(p)
} }
// TODO(#943): send Stateless Reset // TODO(#943): send Stateless Reset
return nil
} }
func (s *server) handleInitial(p *receivedPacket) { func (s *server) handleInitial(p *receivedPacket) {
@ -450,14 +444,15 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
return nil return nil
} }
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error { func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
hdr := p.hdr hdr := p.hdr
s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
if err != nil { if err != nil {
return err s.logger.Debugf("Error composing Version Negotiation: %s", err)
return
}
if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err)
} }
_, err = s.conn.WriteTo(data, p.remoteAddr)
return err
} }

View file

@ -113,7 +113,7 @@ var _ = Describe("Server", func() {
Version: serv.config.Versions[0], Version: serv.config.Versions[0],
}, },
}) })
Expect(conn.dataWritten.Len()).To(BeZero()) Consistently(conn.dataWritten).ShouldNot(Receive())
}) })
It("drops too small Initial", func() { It("drops too small Initial", func() {
@ -126,7 +126,7 @@ var _ = Describe("Server", func() {
}, },
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100), data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100),
}) })
Consistently(conn.dataWritten.Len).Should(BeZero()) Consistently(conn.dataWritten).ShouldNot(Receive())
}) })
It("drops packets with a too short connection ID", func() { It("drops packets with a too short connection ID", func() {
@ -140,7 +140,7 @@ var _ = Describe("Server", func() {
}, },
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
}) })
Consistently(conn.dataWritten.Len).Should(BeZero()) Consistently(conn.dataWritten).ShouldNot(Receive())
}) })
It("drops non-Initial packets", func() { It("drops non-Initial packets", func() {
@ -208,6 +208,7 @@ var _ = Describe("Server", func() {
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
serv.handlePacket(&receivedPacket{ serv.handlePacket(&receivedPacket{
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
hdr: &wire.Header{ hdr: &wire.Header{
IsLongHeader: true, IsLongHeader: true,
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
@ -216,8 +217,10 @@ var _ = Describe("Server", func() {
Version: 0x42, Version: 0x42,
}, },
}) })
Expect(conn.dataWritten.Len()).ToNot(BeZero()) var write mockPacketConnWrite
hdr, err := wire.ParseHeader(bytes.NewReader(conn.dataWritten.Bytes()), 0) Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
hdr, err := wire.ParseHeader(bytes.NewReader(write.data), 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(hdr.IsVersionNegotiation()).To(BeTrue()) Expect(hdr.IsVersionNegotiation()).To(BeTrue())
Expect(hdr.DestConnectionID).To(Equal(srcConnID)) Expect(hdr.DestConnectionID).To(Equal(srcConnID))
@ -234,12 +237,14 @@ var _ = Describe("Server", func() {
Version: protocol.VersionTLS, Version: protocol.VersionTLS,
} }
serv.handleInitial(&receivedPacket{ serv.handleInitial(&receivedPacket{
remoteAddr: &net.UDPAddr{}, remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
hdr: hdr, hdr: hdr,
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
}) })
Expect(conn.dataWritten.Len()).ToNot(BeZero()) var write mockPacketConnWrite
replyHdr := parseHeader(conn.dataWritten.Bytes()) Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
replyHdr := parseHeader(write.data)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
@ -288,7 +293,7 @@ var _ = Describe("Server", func() {
defer GinkgoRecover() defer GinkgoRecover()
serv.handlePacket(p) serv.handlePacket(p)
// the Handshake packet is written by the session // the Handshake packet is written by the session
Expect(conn.dataWritten.Len()).To(BeZero()) Consistently(conn.dataWritten).ShouldNot(Receive())
close(done) close(done)
}() }()
// make sure we're using a server-generated connection ID // make sure we're using a server-generated connection ID