From 5e01c49fdf1f3a721db3daadaff367d6ff7eb995 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 27 Nov 2018 09:02:16 +0700 Subject: [PATCH] send Version Negotiation packets in a separate Go routine --- conn_test.go | 41 ++++++++++++++++++++++++++--------------- server.go | 23 +++++++++-------------- server_test.go | 23 ++++++++++++++--------- 3 files changed, 49 insertions(+), 38 deletions(-) diff --git a/conn_test.go b/conn_test.go index f7e00cac..d5dedfda 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,7 +1,6 @@ package quic import ( - "bytes" "errors" "net" "time" @@ -10,19 +9,24 @@ import ( . "github.com/onsi/gomega" ) +type mockPacketConnWrite struct { + data []byte + to net.Addr +} + type mockPacketConn struct { - addr net.Addr - dataToRead chan []byte - dataReadFrom net.Addr - readErr error - dataWritten bytes.Buffer - dataWrittenTo net.Addr - closed bool + addr net.Addr + dataToRead chan []byte + dataReadFrom net.Addr + readErr error + dataWritten chan mockPacketConnWrite + closed bool } func newMockPacketConn() *mockPacketConn { return &mockPacketConn{ - dataToRead: make(chan []byte, 1000), + dataToRead: make(chan []byte, 1000), + dataWritten: make(chan mockPacketConnWrite, 1000), } } @@ -37,10 +41,16 @@ func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { n := copy(b, data) return n, c.dataReadFrom, nil } + func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - c.dataWrittenTo = addr - return c.dataWritten.Write(b) + select { + case c.dataWritten <- mockPacketConnWrite{to: addr, data: b}: + return len(b), nil + default: + panic("channel full") + } } + func (c *mockPacketConn) Close() error { if !c.closed { close(c.dataToRead) @@ -72,10 +82,11 @@ var _ = Describe("Connection", func() { }) It("writes", func() { - err := c.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(packetConn.dataWritten.Bytes()).To(Equal([]byte("foobar"))) - Expect(packetConn.dataWrittenTo.String()).To(Equal("192.168.100.200:1337")) + Expect(c.Write([]byte("foobar"))).To(Succeed()) + var write mockPacketConnWrite + Expect(packetConn.dataWritten).To(Receive(&write)) + Expect(write.to.String()).To(Equal("192.168.100.200:1337")) + Expect(write.data).To(Equal([]byte("foobar"))) }) It("reads", func() { diff --git a/server.go b/server.go index bcbe041b..cdfb2439 100644 --- a/server.go +++ b/server.go @@ -299,23 +299,17 @@ func (s *server) Addr() net.Addr { } func (s *server) handlePacket(p *receivedPacket) { - if err := s.handlePacketImpl(p); err != nil { - s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err) - } -} - -func (s *server) handlePacketImpl(p *receivedPacket) error { hdr := p.hdr // send a Version Negotiation Packet if the client is speaking a different protocol version if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - return s.sendVersionNegotiationPacket(p) + go s.sendVersionNegotiationPacket(p) + return } if hdr.Type == protocol.PacketTypeInitial { go s.handleInitial(p) } // TODO(#943): send Stateless Reset - return nil } func (s *server) handleInitial(p *receivedPacket) { @@ -450,14 +444,15 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { return nil } -func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error { +func (s *server) sendVersionNegotiationPacket(p *receivedPacket) { hdr := p.hdr - s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) - + s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) if err != nil { - return err + s.logger.Debugf("Error composing Version Negotiation: %s", err) + return + } + if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil { + s.logger.Debugf("Error sending Version Negotiation: %s", err) } - _, err = s.conn.WriteTo(data, p.remoteAddr) - return err } diff --git a/server_test.go b/server_test.go index 0643da39..e9336fb9 100644 --- a/server_test.go +++ b/server_test.go @@ -113,7 +113,7 @@ var _ = Describe("Server", func() { Version: serv.config.Versions[0], }, }) - Expect(conn.dataWritten.Len()).To(BeZero()) + Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops too small Initial", func() { @@ -126,7 +126,7 @@ var _ = Describe("Server", func() { }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100), }) - Consistently(conn.dataWritten.Len).Should(BeZero()) + Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops packets with a too short connection ID", func() { @@ -140,7 +140,7 @@ var _ = Describe("Server", func() { }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), }) - Consistently(conn.dataWritten.Len).Should(BeZero()) + Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops non-Initial packets", func() { @@ -208,6 +208,7 @@ var _ = Describe("Server", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} serv.handlePacket(&receivedPacket{ + remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, @@ -216,8 +217,10 @@ var _ = Describe("Server", func() { Version: 0x42, }, }) - Expect(conn.dataWritten.Len()).ToNot(BeZero()) - hdr, err := wire.ParseHeader(bytes.NewReader(conn.dataWritten.Bytes()), 0) + var write mockPacketConnWrite + Eventually(conn.dataWritten).Should(Receive(&write)) + Expect(write.to.String()).To(Equal("127.0.0.1:1337")) + hdr, err := wire.ParseHeader(bytes.NewReader(write.data), 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsVersionNegotiation()).To(BeTrue()) Expect(hdr.DestConnectionID).To(Equal(srcConnID)) @@ -234,12 +237,14 @@ var _ = Describe("Server", func() { Version: protocol.VersionTLS, } serv.handleInitial(&receivedPacket{ - remoteAddr: &net.UDPAddr{}, + remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, hdr: hdr, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), }) - Expect(conn.dataWritten.Len()).ToNot(BeZero()) - replyHdr := parseHeader(conn.dataWritten.Bytes()) + var write mockPacketConnWrite + Eventually(conn.dataWritten).Should(Receive(&write)) + Expect(write.to.String()).To(Equal("127.0.0.1:1337")) + replyHdr := parseHeader(write.data) Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) @@ -288,7 +293,7 @@ var _ = Describe("Server", func() { defer GinkgoRecover() serv.handlePacket(p) // the Handshake packet is written by the session - Expect(conn.dataWritten.Len()).To(BeZero()) + Consistently(conn.dataWritten).ShouldNot(Receive()) close(done) }() // make sure we're using a server-generated connection ID