mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
send Version Negotiation packets in a separate Go routine
This commit is contained in:
parent
6c4116e7f4
commit
5e01c49fdf
3 changed files with 49 additions and 38 deletions
41
conn_test.go
41
conn_test.go
|
@ -1,7 +1,6 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
@ -10,19 +9,24 @@ import (
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type mockPacketConnWrite struct {
|
||||||
|
data []byte
|
||||||
|
to net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
type mockPacketConn struct {
|
type mockPacketConn struct {
|
||||||
addr net.Addr
|
addr net.Addr
|
||||||
dataToRead chan []byte
|
dataToRead chan []byte
|
||||||
dataReadFrom net.Addr
|
dataReadFrom net.Addr
|
||||||
readErr error
|
readErr error
|
||||||
dataWritten bytes.Buffer
|
dataWritten chan mockPacketConnWrite
|
||||||
dataWrittenTo net.Addr
|
closed bool
|
||||||
closed bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMockPacketConn() *mockPacketConn {
|
func newMockPacketConn() *mockPacketConn {
|
||||||
return &mockPacketConn{
|
return &mockPacketConn{
|
||||||
dataToRead: make(chan []byte, 1000),
|
dataToRead: make(chan []byte, 1000),
|
||||||
|
dataWritten: make(chan mockPacketConnWrite, 1000),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,10 +41,16 @@ func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||||
n := copy(b, data)
|
n := copy(b, data)
|
||||||
return n, c.dataReadFrom, nil
|
return n, c.dataReadFrom, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||||
c.dataWrittenTo = addr
|
select {
|
||||||
return c.dataWritten.Write(b)
|
case c.dataWritten <- mockPacketConnWrite{to: addr, data: b}:
|
||||||
|
return len(b), nil
|
||||||
|
default:
|
||||||
|
panic("channel full")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *mockPacketConn) Close() error {
|
func (c *mockPacketConn) Close() error {
|
||||||
if !c.closed {
|
if !c.closed {
|
||||||
close(c.dataToRead)
|
close(c.dataToRead)
|
||||||
|
@ -72,10 +82,11 @@ var _ = Describe("Connection", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("writes", func() {
|
It("writes", func() {
|
||||||
err := c.Write([]byte("foobar"))
|
Expect(c.Write([]byte("foobar"))).To(Succeed())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
var write mockPacketConnWrite
|
||||||
Expect(packetConn.dataWritten.Bytes()).To(Equal([]byte("foobar")))
|
Expect(packetConn.dataWritten).To(Receive(&write))
|
||||||
Expect(packetConn.dataWrittenTo.String()).To(Equal("192.168.100.200:1337"))
|
Expect(write.to.String()).To(Equal("192.168.100.200:1337"))
|
||||||
|
Expect(write.data).To(Equal([]byte("foobar")))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("reads", func() {
|
It("reads", func() {
|
||||||
|
|
23
server.go
23
server.go
|
@ -299,23 +299,17 @@ func (s *server) Addr() net.Addr {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) handlePacket(p *receivedPacket) {
|
func (s *server) handlePacket(p *receivedPacket) {
|
||||||
if err := s.handlePacketImpl(p); err != nil {
|
|
||||||
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) handlePacketImpl(p *receivedPacket) error {
|
|
||||||
hdr := p.hdr
|
hdr := p.hdr
|
||||||
|
|
||||||
// send a Version Negotiation Packet if the client is speaking a different protocol version
|
// send a Version Negotiation Packet if the client is speaking a different protocol version
|
||||||
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
|
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
|
||||||
return s.sendVersionNegotiationPacket(p)
|
go s.sendVersionNegotiationPacket(p)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if hdr.Type == protocol.PacketTypeInitial {
|
if hdr.Type == protocol.PacketTypeInitial {
|
||||||
go s.handleInitial(p)
|
go s.handleInitial(p)
|
||||||
}
|
}
|
||||||
// TODO(#943): send Stateless Reset
|
// TODO(#943): send Stateless Reset
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) handleInitial(p *receivedPacket) {
|
func (s *server) handleInitial(p *receivedPacket) {
|
||||||
|
@ -450,14 +444,15 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
|
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
|
||||||
hdr := p.hdr
|
hdr := p.hdr
|
||||||
s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
|
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
|
||||||
|
|
||||||
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
|
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
s.logger.Debugf("Error composing Version Negotiation: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil {
|
||||||
|
s.logger.Debugf("Error sending Version Negotiation: %s", err)
|
||||||
}
|
}
|
||||||
_, err = s.conn.WriteTo(data, p.remoteAddr)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -113,7 +113,7 @@ var _ = Describe("Server", func() {
|
||||||
Version: serv.config.Versions[0],
|
Version: serv.config.Versions[0],
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
Expect(conn.dataWritten.Len()).To(BeZero())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops too small Initial", func() {
|
It("drops too small Initial", func() {
|
||||||
|
@ -126,7 +126,7 @@ var _ = Describe("Server", func() {
|
||||||
},
|
},
|
||||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100),
|
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100),
|
||||||
})
|
})
|
||||||
Consistently(conn.dataWritten.Len).Should(BeZero())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops packets with a too short connection ID", func() {
|
It("drops packets with a too short connection ID", func() {
|
||||||
|
@ -140,7 +140,7 @@ var _ = Describe("Server", func() {
|
||||||
},
|
},
|
||||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||||
})
|
})
|
||||||
Consistently(conn.dataWritten.Len).Should(BeZero())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops non-Initial packets", func() {
|
It("drops non-Initial packets", func() {
|
||||||
|
@ -208,6 +208,7 @@ var _ = Describe("Server", func() {
|
||||||
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||||
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
|
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
|
||||||
serv.handlePacket(&receivedPacket{
|
serv.handlePacket(&receivedPacket{
|
||||||
|
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
|
||||||
hdr: &wire.Header{
|
hdr: &wire.Header{
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
Type: protocol.PacketTypeInitial,
|
Type: protocol.PacketTypeInitial,
|
||||||
|
@ -216,8 +217,10 @@ var _ = Describe("Server", func() {
|
||||||
Version: 0x42,
|
Version: 0x42,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
Expect(conn.dataWritten.Len()).ToNot(BeZero())
|
var write mockPacketConnWrite
|
||||||
hdr, err := wire.ParseHeader(bytes.NewReader(conn.dataWritten.Bytes()), 0)
|
Eventually(conn.dataWritten).Should(Receive(&write))
|
||||||
|
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
|
||||||
|
hdr, err := wire.ParseHeader(bytes.NewReader(write.data), 0)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(hdr.IsVersionNegotiation()).To(BeTrue())
|
Expect(hdr.IsVersionNegotiation()).To(BeTrue())
|
||||||
Expect(hdr.DestConnectionID).To(Equal(srcConnID))
|
Expect(hdr.DestConnectionID).To(Equal(srcConnID))
|
||||||
|
@ -234,12 +237,14 @@ var _ = Describe("Server", func() {
|
||||||
Version: protocol.VersionTLS,
|
Version: protocol.VersionTLS,
|
||||||
}
|
}
|
||||||
serv.handleInitial(&receivedPacket{
|
serv.handleInitial(&receivedPacket{
|
||||||
remoteAddr: &net.UDPAddr{},
|
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
|
||||||
hdr: hdr,
|
hdr: hdr,
|
||||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||||
})
|
})
|
||||||
Expect(conn.dataWritten.Len()).ToNot(BeZero())
|
var write mockPacketConnWrite
|
||||||
replyHdr := parseHeader(conn.dataWritten.Bytes())
|
Eventually(conn.dataWritten).Should(Receive(&write))
|
||||||
|
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
|
||||||
|
replyHdr := parseHeader(write.data)
|
||||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||||
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
|
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
|
||||||
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||||
|
@ -288,7 +293,7 @@ var _ = Describe("Server", func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
serv.handlePacket(p)
|
serv.handlePacket(p)
|
||||||
// the Handshake packet is written by the session
|
// the Handshake packet is written by the session
|
||||||
Expect(conn.dataWritten.Len()).To(BeZero())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
// make sure we're using a server-generated connection ID
|
// make sure we're using a server-generated connection ID
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue