From e924f0ecb3c57d18b4f36c4fd67d82e56908669f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 16 Feb 2017 22:30:17 +0700 Subject: [PATCH] use the net.PacketConn everywhere in the server --- conn_test.go | 16 ++++- server.go | 10 ++-- server_test.go | 156 ++++++++++++++----------------------------------- 3 files changed, 61 insertions(+), 121 deletions(-) diff --git a/conn_test.go b/conn_test.go index c09c8f8b..7490b072 100644 --- a/conn_test.go +++ b/conn_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "io" "net" "time" @@ -10,18 +11,27 @@ import ( ) type mockPacketConn struct { + dataToRead []byte + dataReadFrom net.Addr dataWritten bytes.Buffer dataWrittenTo net.Addr + closed bool } -func (c *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - panic("not implemented") +func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + if c.dataToRead == nil { // block if there's no data + time.Sleep(time.Hour) + return 0, nil, io.EOF + } + n := copy(b, c.dataToRead) + c.dataToRead = nil + return n, c.dataReadFrom, nil } func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { c.dataWrittenTo = addr return c.dataWritten.Write(b) } -func (c *mockPacketConn) Close() error { panic("not implemented") } +func (c *mockPacketConn) Close() error { c.closed = true; return nil } func (c *mockPacketConn) LocalAddr() net.Addr { panic("not implemented") } func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") } func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") } diff --git a/server.go b/server.go index 5af9917f..93d0f4b0 100644 --- a/server.go +++ b/server.go @@ -28,7 +28,7 @@ type packetHandler interface { type Server struct { addr *net.UDPAddr - conn *net.UDPConn + conn net.PacketConn connMutex sync.Mutex certChain crypto.CertChain @@ -81,8 +81,8 @@ func (s *Server) ListenAndServe() error { return s.Serve(conn) } -// Serve on an existing UDP connection. -func (s *Server) Serve(conn *net.UDPConn) error { +// Serve on an existing PacketConn +func (s *Server) Serve(conn net.PacketConn) error { s.connMutex.Lock() s.conn = conn s.connMutex.Unlock() @@ -90,7 +90,7 @@ func (s *Server) Serve(conn *net.UDPConn) error { for { data := getPacketBuffer() data = data[:protocol.MaxPacketSize] - n, remoteAddr, err := conn.ReadFromUDP(data) + n, remoteAddr, err := conn.ReadFrom(data) if err != nil { if strings.HasSuffix(err.Error(), "use of closed network connection") { return nil @@ -132,7 +132,7 @@ func (s *Server) Addr() net.Addr { return s.addr } -func (s *Server) handlePacket(pconn net.PacketConn, remoteAddr *net.UDPAddr, packet []byte) error { +func (s *Server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet []byte) error { if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize { return qerr.PacketTooLarge } diff --git a/server_test.go b/server_test.go index 88c00a8e..815e8910 100644 --- a/server_test.go +++ b/server_test.go @@ -44,23 +44,28 @@ func newMockSession(conn connection, v protocol.VersionNumber, connectionID prot } var _ = Describe("Server", func() { - Describe("with mock session", func() { - var ( - server *Server - firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 - ) + var ( + server *Server + conn *mockPacketConn + udpAddr *net.UDPAddr + firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 + ) - BeforeEach(func() { - server = &Server{ - sessions: map[protocol.ConnectionID]packetHandler{}, - newSession: newMockSession, - } - b := &bytes.Buffer{} - utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.SupportedVersions[0])) - firstPacket = []byte{0x09, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c} - firstPacket = append(append(firstPacket, b.Bytes()...), 0x01) - }) + BeforeEach(func() { + server = &Server{ + sessions: map[protocol.ConnectionID]packetHandler{}, + newSession: newMockSession, + conn: &mockPacketConn{}, + } + conn = server.conn.(*mockPacketConn) + udpAddr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} + b := &bytes.Buffer{} + utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.SupportedVersions[0])) + firstPacket = []byte{0x09, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c} + firstPacket = append(append(firstPacket, b.Bytes()...), 0x01) + }) + Context("with mock session", func() { It("returns the address", func() { server.addr = &net.UDPAddr{ IP: net.IPv4(192, 168, 13, 37), @@ -117,12 +122,13 @@ var _ = Describe("Server", func() { Eventually(func() map[protocol.ConnectionID]packetHandler { return server.sessions }).ShouldNot(HaveKey(protocol.ConnectionID(0x4cfa9f9b668619f6))) }) - It("closes sessions when Close is called", func() { + It("closes sessions and the connection when Close is called", func() { session := &mockSession{} server.sessions[1] = session err := server.Close() Expect(err).NotTo(HaveOccurred()) Expect(session.closed).To(BeTrue()) + Expect(conn.closed).To(BeTrue()) }) It("ignores packets for closed sessions", func() { @@ -145,7 +151,7 @@ var _ = Describe("Server", func() { err = server.handlePacket(nil, nil, data) Expect(err).ToNot(HaveOccurred()) // if we didn't ignore the packet, the server would try to send a version negotation packet, which would make the test panic because it doesn't have a udpConn - // TODO: test that really doesn't send anything on the udpConn + Expect(conn.dataWritten.Bytes()).To(BeEmpty()) // make sure the packet was *not* passed to session.handlePacket() Expect(server.sessions[0x4cfa9f9b668619f6].(*mockSession).packetCount).To(Equal(1)) }) @@ -189,119 +195,43 @@ var _ = Describe("Server", func() { }) It("setups with the right values", func() { - server, err := NewServer("", testdata.GetTLSConfig(), nil) + s, err := NewServer("", testdata.GetTLSConfig(), nil) Expect(err).ToNot(HaveOccurred()) - Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout)) - Expect(server.sessions).ToNot(BeNil()) - Expect(server.scfg).ToNot(BeNil()) + Expect(s.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout)) + Expect(s.sessions).ToNot(BeNil()) + Expect(s.scfg).ToNot(BeNil()) }) - It("setups and responds with version negotiation", func(done Done) { - addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") - Expect(err).ToNot(HaveOccurred()) - - server, err := NewServer("", testdata.GetTLSConfig(), nil) - Expect(err).ToNot(HaveOccurred()) - - serverConn, err := net.ListenUDP("udp", addr) - Expect(err).NotTo(HaveOccurred()) - - addr = serverConn.LocalAddr().(*net.UDPAddr) - + It("setups and responds with version negotiation", func() { + conn.dataToRead = []byte{0x09, 0x01, 0, 0, 0, 0, 0, 0, 0, 0x01, 0x01, 'Q', '0', '0', '0', 0x01} + conn.dataReadFrom = udpAddr go func() { defer GinkgoRecover() - err2 := server.Serve(serverConn) - Expect(err2).ToNot(HaveOccurred()) - close(done) + err := server.Serve(conn) + Expect(err).ToNot(HaveOccurred()) }() - clientConn, err := net.DialUDP("udp", nil, addr) - Expect(err).ToNot(HaveOccurred()) - - _, err = clientConn.Write([]byte{0x09, 0x01, 0, 0, 0, 0, 0, 0, 0, 0x01, 0x01, 'Q', '0', '0', '0', 0x01}) - Expect(err).NotTo(HaveOccurred()) - data := make([]byte, 1000) - var n int - n, _, err = clientConn.ReadFromUDP(data) - Expect(err).NotTo(HaveOccurred()) - data = data[:n] + Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero()) + Expect(conn.dataWrittenTo).To(Equal(udpAddr)) expected := append( []byte{0x9, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, protocol.SupportedVersionsAsTags..., ) - Expect(data).To(Equal(expected)) - - err = server.Close() - Expect(err).ToNot(HaveOccurred()) + Expect(conn.dataWritten.Bytes()).To(Equal(expected)) }) - It("sends a public reset for new connections that don't have the VersionFlag set", func(done Done) { - addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") - Expect(err).ToNot(HaveOccurred()) - - server, err := NewServer("", testdata.GetTLSConfig(), nil) - Expect(err).ToNot(HaveOccurred()) - - serverConn, err := net.ListenUDP("udp", addr) - Expect(err).NotTo(HaveOccurred()) - - addr = serverConn.LocalAddr().(*net.UDPAddr) - + It("sends a PublicReset for new connections that don't have the VersionFlag set", func() { + conn.dataReadFrom = udpAddr + conn.dataToRead = []byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01} go func() { defer GinkgoRecover() - err2 := server.Serve(serverConn) - Expect(err2).ToNot(HaveOccurred()) - close(done) + err := server.Serve(conn) + Expect(err).ToNot(HaveOccurred()) }() - clientConn, err := net.DialUDP("udp", nil, addr) - Expect(err).ToNot(HaveOccurred()) - - _, err = clientConn.Write([]byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01}) - Expect(err).ToNot(HaveOccurred()) - data := make([]byte, 1000) - var n int - n, _, err = clientConn.ReadFromUDP(data) - Expect(err).NotTo(HaveOccurred()) - Expect(n).ToNot(BeZero()) - Expect(data[0] & 0x02).ToNot(BeZero()) // check that the ResetFlag is set + Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero()) + Expect(conn.dataWrittenTo).To(Equal(udpAddr)) + Expect(conn.dataWritten.Bytes()[0] & 0x02).ToNot(BeZero()) // check that the ResetFlag is set Expect(server.sessions).To(BeEmpty()) - - err = server.Close() - Expect(err).ToNot(HaveOccurred()) - }) - - It("setups and responds with error on invalid frame", func(done Done) { - addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") - Expect(err).ToNot(HaveOccurred()) - - server, err := NewServer("", testdata.GetTLSConfig(), nil) - Expect(err).ToNot(HaveOccurred()) - - serverConn, err := net.ListenUDP("udp", addr) - Expect(err).NotTo(HaveOccurred()) - - addr = serverConn.LocalAddr().(*net.UDPAddr) - - go func() { - defer GinkgoRecover() - err2 := server.Serve(serverConn) - Expect(err2).ToNot(HaveOccurred()) - close(done) - }() - - clientConn, err := net.DialUDP("udp", nil, addr) - Expect(err).ToNot(HaveOccurred()) - - _, err = clientConn.Write([]byte{0x09, 0x01, 0, 0, 0, 0, 0, 0, 0, 0x01, 0x01, 'Q', '0', '0', '0', 0x01, 0x00}) - Expect(err).NotTo(HaveOccurred()) - data := make([]byte, 1000) - var n int - n, _, err = clientConn.ReadFromUDP(data) - Expect(err).NotTo(HaveOccurred()) - Expect(n).ToNot(BeZero()) - - err = server.Close() - Expect(err).ToNot(HaveOccurred()) }) })