From 5b42675da259feb5de0b70135964a75cbe66b98c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 15 Feb 2017 18:59:15 +0700 Subject: [PATCH] use a net.PacketConn instead of a net.UDPConn in Server and Session --- benchmark_test.go | 4 +- conn.go | 39 +++++++++++++ conn_test.go | 65 +++++++++++++++++++++ h2quic/server.go | 2 +- h2quic/server_test.go | 2 +- server.go | 8 +-- session.go | 18 ++++-- session_test.go | 132 ++++++++++++++++++++++++++---------------- udp_conn.go | 39 ------------- 9 files changed, 205 insertions(+), 104 deletions(-) create mode 100644 conn.go create mode 100644 conn_test.go delete mode 100644 udp_conn.go diff --git a/benchmark_test.go b/benchmark_test.go index 1b82588b..b74d26eb 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -62,8 +62,8 @@ func (c *linkedConnection) write(p []byte) error { return nil } -func (*linkedConnection) setCurrentRemoteAddr(addr interface{}) {} -func (*linkedConnection) RemoteAddr() *net.UDPAddr { return &net.UDPAddr{} } +func (*linkedConnection) setCurrentRemoteAddr(addr net.Addr) {} +func (*linkedConnection) RemoteAddr() net.Addr { return &net.UDPAddr{} } func setAEAD(cs handshake.CryptoSetup, aead crypto.AEAD) { *(*bool)(unsafe.Pointer(reflect.ValueOf(cs).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true diff --git a/conn.go b/conn.go new file mode 100644 index 00000000..7f81af35 --- /dev/null +++ b/conn.go @@ -0,0 +1,39 @@ +package quic + +import ( + "net" + "sync" +) + +type connection interface { + write([]byte) error + setCurrentRemoteAddr(net.Addr) + RemoteAddr() net.Addr +} + +type conn struct { + mutex sync.RWMutex + + pconn net.PacketConn + currentAddr net.Addr +} + +var _ connection = &conn{} + +func (c *conn) write(p []byte) error { + _, err := c.pconn.WriteTo(p, c.currentAddr) + return err +} + +func (c *conn) setCurrentRemoteAddr(addr net.Addr) { + c.mutex.Lock() + c.currentAddr = addr + c.mutex.Unlock() +} + +func (c *conn) RemoteAddr() net.Addr { + c.mutex.RLock() + addr := c.currentAddr + c.mutex.RUnlock() + return addr +} diff --git a/conn_test.go b/conn_test.go new file mode 100644 index 00000000..c09c8f8b --- /dev/null +++ b/conn_test.go @@ -0,0 +1,65 @@ +package quic + +import ( + "bytes" + "net" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type mockPacketConn struct { + dataWritten bytes.Buffer + dataWrittenTo net.Addr +} + +func (c *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + panic("not implemented") +} +func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + c.dataWrittenTo = addr + return c.dataWritten.Write(b) +} +func (c *mockPacketConn) Close() error { panic("not implemented") } +func (c *mockPacketConn) LocalAddr() net.Addr { panic("not implemented") } +func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") } +func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") } +func (c *mockPacketConn) SetWriteDeadline(t time.Time) error { panic("not implemented") } + +var _ net.PacketConn = &mockPacketConn{} + +var _ = Describe("Connection", func() { + var c *conn + + BeforeEach(func() { + addr := &net.UDPAddr{ + IP: net.IPv4(192, 168, 100, 200), + Port: 1337, + } + c = &conn{ + currentAddr: addr, + pconn: &mockPacketConn{}, + } + }) + + It("writes", func() { + err := c.write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(c.pconn.(*mockPacketConn).dataWritten.Bytes()).To(Equal([]byte("foobar"))) + Expect(c.pconn.(*mockPacketConn).dataWrittenTo.String()).To(Equal("192.168.100.200:1337")) + }) + + It("gets the remote address", func() { + Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337")) + }) + + It("changes the remote address", func() { + addr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 7331, + } + c.setCurrentRemoteAddr(addr) + Expect(c.RemoteAddr().String()).To(Equal(addr.String())) + }) +}) diff --git a/h2quic/server.go b/h2quic/server.go index 82166310..5ec2dd73 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -22,7 +22,7 @@ import ( type streamCreator interface { GetOrOpenStream(protocol.StreamID) (utils.Stream, error) Close(error) error - RemoteAddr() *net.UDPAddr + RemoteAddr() net.Addr } // Server is a HTTP2 server listening for QUIC connections. diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 1f9a22c7..cf0bb5d0 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -36,7 +36,7 @@ func (s *mockSession) Close(e error) error { s.closedWithError = e return nil } -func (s *mockSession) RemoteAddr() *net.UDPAddr { +func (s *mockSession) RemoteAddr() net.Addr { return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42} } diff --git a/server.go b/server.go index d0163104..5af9917f 100644 --- a/server.go +++ b/server.go @@ -132,7 +132,7 @@ func (s *Server) Addr() net.Addr { return s.addr } -func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet []byte) error { +func (s *Server) handlePacket(pconn net.PacketConn, remoteAddr *net.UDPAddr, packet []byte) error { if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize { return qerr.PacketTooLarge } @@ -177,13 +177,13 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet // Send Version Negotiation Packet if the client is speaking a different protocol version if hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) { utils.Infof("Client offered version %d, sending VersionNegotiationPacket", hdr.VersionNumber) - _, err = conn.WriteToUDP(composeVersionNegotiation(hdr.ConnectionID), remoteAddr) + _, err = pconn.WriteTo(composeVersionNegotiation(hdr.ConnectionID), remoteAddr) return err } if !ok { if !hdr.VersionFlag { - _, err = conn.WriteToUDP(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr) + _, err = pconn.WriteTo(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr) return err } version := hdr.VersionNumber @@ -193,7 +193,7 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, version, remoteAddr) session, err = s.newSession( - &udpConn{conn: conn, currentAddr: remoteAddr}, + &conn{pconn: pconn, currentAddr: remoteAddr}, version, hdr.ConnectionID, s.scfg, diff --git a/session.go b/session.go index 3dd421cb..3d750695 100644 --- a/session.go +++ b/session.go @@ -23,7 +23,7 @@ type unpacker interface { } type receivedPacket struct { - remoteAddr interface{} + remoteAddr net.Addr publicHeader *PublicHeader data []byte rcvTime time.Time @@ -116,8 +116,14 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol session.setup() cryptoStream, _ := session.GetOrOpenStream(1) + var sourceAddr []byte + if udpAddr, ok := conn.RemoteAddr().(*net.UDPAddr); ok { + sourceAddr = udpAddr.IP + } else { + sourceAddr = []byte(conn.RemoteAddr().String()) + } var err error - session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.RemoteAddr().IP, v, sCfg, cryptoStream, session.connectionParameters, session.aeadChanged) + session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, session.connectionParameters, session.aeadChanged) if err != nil { return nil, err } @@ -128,9 +134,9 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol return session, err } -func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) { +func newClientSession(pconn net.PacketConn, addr net.Addr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) { session := &Session{ - conn: &udpConn{conn: conn, currentAddr: addr}, + conn: &conn{pconn: pconn, currentAddr: addr}, connectionID: connectionID, perspective: protocol.PerspectiveClient, version: v, @@ -765,7 +771,7 @@ func (s *Session) ackAlarmChanged(t time.Time) { s.maybeResetTimer() } -// RemoteAddr returns the net.UDPAddr of the client -func (s *Session) RemoteAddr() *net.UDPAddr { +// RemoteAddr returns the net.Addr of the client +func (s *Session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } diff --git a/session_test.go b/session_test.go index ca3337f7..b6b7aa83 100644 --- a/session_test.go +++ b/session_test.go @@ -25,7 +25,7 @@ import ( ) type mockConnection struct { - remoteAddr net.IP + remoteAddr net.Addr written [][]byte } @@ -36,12 +36,10 @@ func (m *mockConnection) write(p []byte) error { return nil } -func (m *mockConnection) setCurrentRemoteAddr(addr interface{}) { - if ip, ok := addr.(net.IP); ok { - m.remoteAddr = ip - } +func (m *mockConnection) setCurrentRemoteAddr(addr net.Addr) { + m.remoteAddr = addr } -func (*mockConnection) RemoteAddr() *net.UDPAddr { return &net.UDPAddr{} } +func (*mockConnection) RemoteAddr() net.Addr { return &net.UDPAddr{} } type mockUnpacker struct { unpackErr error @@ -120,22 +118,23 @@ var _ = Describe("Session", func() { clientSession *Session streamCallbackCalled bool closeCallbackCalled bool - conn *mockConnection + scfg *handshake.ServerConfig + mconn *mockConnection cpm *mockConnectionParametersManager ) BeforeEach(func() { - conn = &mockConnection{} + mconn = &mockConnection{} streamCallbackCalled = false closeCallbackCalled = false certChain := crypto.NewCertChain(testdata.GetTLSConfig()) kex, err := crypto.NewCurve25519KEX() Expect(err).NotTo(HaveOccurred()) - scfg, err := handshake.NewServerConfig(kex, certChain) + scfg, err = handshake.NewServerConfig(kex, certChain) Expect(err).NotTo(HaveOccurred()) pSession, err := newSession( - conn, + mconn, protocol.Version35, 0, scfg, @@ -163,7 +162,38 @@ var _ = Describe("Session", func() { ) Expect(err).ToNot(HaveOccurred()) Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream + }) + Context("source address", func() { + It("uses the IP address if given an UDP connection", func() { + conn := &conn{currentAddr: &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}} + session, err := newSession( + conn, + protocol.VersionWhatever, + 0, + scfg, + func(*Session, utils.Stream) { streamCallbackCalled = true }, + func(protocol.ConnectionID) { closeCallbackCalled = true }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(session.(*Session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte{192, 168, 100, 200})) + }) + + It("uses the string representation of the remote addresses if not given a UDP connection", func() { + conn := &conn{ + currentAddr: &net.TCPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}, + } + session, err := newSession( + conn, + protocol.VersionWhatever, + 0, + scfg, + func(*Session, utils.Stream) { streamCallbackCalled = true }, + func(protocol.ConnectionID) { closeCallbackCalled = true }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(session.(*Session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte("192.168.100.200:1337"))) + }) }) Context("when handling stream frames", func() { @@ -617,8 +647,8 @@ var _ = Describe("Session", func() { It("shuts down without error", func() { session.Close(nil) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) - Expect(conn.written).To(HaveLen(1)) - Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})) + Expect(mconn.written).To(HaveLen(1)) + Expect(mconn.written[0][len(mconn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})) Expect(closeCallbackCalled).To(BeTrue()) Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close() }) @@ -627,7 +657,7 @@ var _ = Describe("Session", func() { session.Close(nil) session.Close(nil) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) - Expect(conn.written).To(HaveLen(1)) + Expect(mconn.written).To(HaveLen(1)) Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close() }) @@ -652,7 +682,7 @@ var _ = Describe("Session", func() { Expect(closeCallbackCalled).To(BeFalse()) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) Expect(atomic.LoadUint32(&session.closed) != 0).To(BeTrue()) - Expect(conn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent + Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent }) }) @@ -712,7 +742,7 @@ var _ = Describe("Session", func() { Context("updating the remote address", func() { It("sets the remote address", func() { - remoteIP := net.IPv4(192, 168, 0, 100) + remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} Expect(session.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP)) p := receivedPacket{ remoteAddr: remoteIP, @@ -724,8 +754,8 @@ var _ = Describe("Session", func() { }) It("doesn't change the remote address if authenticating the packet fails", func() { - remoteIP := net.IPv4(192, 168, 0, 100) - attackerIP := net.IPv4(192, 168, 0, 102) + remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} + attackerIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 102)} session.conn.(*mockConnection).remoteAddr = remoteIP // use the real packetUnpacker here, to make sure this test fails if the error code for failed decryption changes session.unpacker = &packetUnpacker{} @@ -742,7 +772,7 @@ var _ = Describe("Session", func() { It("sets the remote address, if the packet is authenticated, but unpacking fails for another reason", func() { testErr := errors.New("testErr") - remoteIP := net.IPv4(192, 168, 0, 100) + remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} Expect(session.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP)) p := receivedPacket{ remoteAddr: remoteIP, @@ -762,8 +792,8 @@ var _ = Describe("Session", func() { session.receivedPacketHandler.ReceivedPacket(packetNumber, true) err := session.sendPacket() Expect(err).NotTo(HaveOccurred()) - Expect(conn.written).To(HaveLen(1)) - Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x5E, 0x03}))) + Expect(mconn.written).To(HaveLen(1)) + Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x5E, 0x03}))) }) It("sends two WindowUpdate frames", func() { @@ -776,16 +806,16 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) err = session.sendPacket() Expect(err).NotTo(HaveOccurred()) - Expect(conn.written).To(HaveLen(2)) - Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0}))) - Expect(conn.written[1]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0}))) + Expect(mconn.written).To(HaveLen(2)) + Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0}))) + Expect(mconn.written[1]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0}))) }) It("sends public reset", func() { err := session.sendPublicReset(1) Expect(err).NotTo(HaveOccurred()) - Expect(conn.written).To(HaveLen(1)) - Expect(conn.written[0]).To(ContainSubstring(string([]byte("PRST")))) + Expect(mconn.written).To(HaveLen(1)) + Expect(mconn.written[0]).To(ContainSubstring(string([]byte("PRST")))) }) }) @@ -808,9 +838,9 @@ var _ = Describe("Session", func() { err := session.sendPacket() Expect(err).NotTo(HaveOccurred()) - Expect(conn.written).To(HaveLen(1)) + Expect(mconn.written).To(HaveLen(1)) Expect(sph.(*mockSentPacketHandler).requestedStopWaiting).To(BeTrue()) - Expect(conn.written[0]).To(ContainSubstring("foobar1234567")) + Expect(mconn.written[0]).To(ContainSubstring("foobar1234567")) }) It("sends a StreamFrame from a packet queued for retransmission", func() { @@ -839,9 +869,9 @@ var _ = Describe("Session", func() { err := session.sendPacket() Expect(err).NotTo(HaveOccurred()) - Expect(conn.written).To(HaveLen(1)) - Expect(conn.written[0]).To(ContainSubstring("foobar")) - Expect(conn.written[0]).To(ContainSubstring("loremipsum")) + Expect(mconn.written).To(HaveLen(1)) + Expect(mconn.written[0]).To(ContainSubstring("foobar")) + Expect(mconn.written[0]).To(ContainSubstring("loremipsum")) }) It("always attaches a StopWaiting to a packet that contains a retransmission", func() { @@ -859,7 +889,7 @@ var _ = Describe("Session", func() { err := session.sendPacket() Expect(err).NotTo(HaveOccurred()) - Expect(conn.written).To(HaveLen(2)) + Expect(mconn.written).To(HaveLen(2)) sentPackets := sph.(*mockSentPacketHandler).sentPackets Expect(sentPackets).To(HaveLen(2)) _, ok := sentPackets[0].Frames[0].(*frames.StopWaitingFrame) @@ -963,8 +993,8 @@ var _ = Describe("Session", func() { go session.run() session.ackAlarmChanged(time.Now().Add(10 * time.Millisecond)) time.Sleep(10 * time.Millisecond) - Eventually(func() int { return len(conn.written) }).ShouldNot(BeZero()) - Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13}))) + Eventually(func() int { return len(mconn.written) }).ShouldNot(BeZero()) + Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13}))) }) Context("bundling of small packets", func() { @@ -981,9 +1011,9 @@ var _ = Describe("Session", func() { session.scheduleSending() go session.run() - Eventually(func() [][]byte { return conn.written }).Should(HaveLen(1)) - Expect(conn.written[0]).To(ContainSubstring("foobar1")) - Expect(conn.written[0]).To(ContainSubstring("foobar2")) + Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1)) + Expect(mconn.written[0]).To(ContainSubstring("foobar1")) + Expect(mconn.written[0]).To(ContainSubstring("foobar2")) }) It("sends out two big frames in two packets", func() { @@ -999,7 +1029,7 @@ var _ = Describe("Session", func() { }() _, err = s2.Write(bytes.Repeat([]byte{'e'}, 1000)) Expect(err).ToNot(HaveOccurred()) - Eventually(func() [][]byte { return conn.written }).Should(HaveLen(2)) + Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(2)) }) It("sends out two small frames that are written to long after one another into two packets", func() { @@ -1008,10 +1038,10 @@ var _ = Describe("Session", func() { go session.run() _, err = s.Write([]byte("foobar1")) Expect(err).NotTo(HaveOccurred()) - Eventually(func() [][]byte { return conn.written }).Should(HaveLen(1)) + Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1)) _, err = s.Write([]byte("foobar2")) Expect(err).NotTo(HaveOccurred()) - Eventually(func() [][]byte { return conn.written }).Should(HaveLen(2)) + Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(2)) }) It("sends a queued ACK frame only once", func() { @@ -1023,13 +1053,13 @@ var _ = Describe("Session", func() { go session.run() _, err = s.Write([]byte("foobar1")) Expect(err).NotTo(HaveOccurred()) - Eventually(func() [][]byte { return conn.written }).Should(HaveLen(1)) + Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1)) _, err = s.Write([]byte("foobar2")) Expect(err).NotTo(HaveOccurred()) - Eventually(func() [][]byte { return conn.written }).Should(HaveLen(2)) - Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13}))) - Expect(conn.written[1]).ToNot(ContainSubstring(string([]byte{0x37, 0x13}))) + Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(2)) + Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13}))) + Expect(mconn.written[1]).ToNot(ContainSubstring(string([]byte{0x37, 0x13}))) }) }) }) @@ -1058,8 +1088,8 @@ var _ = Describe("Session", func() { } session.run() - Expect(conn.written).To(HaveLen(1)) - Expect(conn.written[0]).To(ContainSubstring(string([]byte("PRST")))) + Expect(mconn.written).To(HaveLen(1)) + Expect(mconn.written[0]).To(ContainSubstring(string([]byte("PRST")))) Expect(session.runClosed).To(Receive()) }) @@ -1120,7 +1150,7 @@ var _ = Describe("Session", func() { It("times out due to no network activity", func(done Done) { session.lastNetworkActivityTime = time.Now().Add(-time.Hour) session.run() // Would normally not return - Expect(conn.written[0]).To(ContainSubstring("No recent network activity.")) + Expect(mconn.written[0]).To(ContainSubstring("No recent network activity.")) Expect(closeCallbackCalled).To(BeTrue()) Expect(session.runClosed).To(Receive()) close(done) @@ -1129,7 +1159,7 @@ var _ = Describe("Session", func() { It("times out due to non-completed crypto handshake", func(done Done) { session.sessionCreationTime = time.Now().Add(-time.Hour) session.run() // Would normally not return - Expect(conn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time.")) + Expect(mconn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time.")) Expect(closeCallbackCalled).To(BeTrue()) Expect(session.runClosed).To(Receive()) close(done) @@ -1140,7 +1170,7 @@ var _ = Describe("Session", func() { cpm.idleTime = 99999 * time.Second session.packer.connectionParameters = session.connectionParameters session.run() // Would normally not return - Expect(conn.written[0]).To(ContainSubstring("No recent network activity.")) + Expect(mconn.written[0]).To(ContainSubstring("No recent network activity.")) Expect(closeCallbackCalled).To(BeTrue()) Expect(session.runClosed).To(Receive()) close(done) @@ -1153,7 +1183,7 @@ var _ = Describe("Session", func() { cpm.idleTime = 0 * time.Millisecond session.packer.connectionParameters = session.connectionParameters session.run() // Would normally not return - Expect(conn.written[0]).To(ContainSubstring("No recent network activity.")) + Expect(mconn.written[0]).To(ContainSubstring("No recent network activity.")) Expect(closeCallbackCalled).To(BeTrue()) Expect(session.runClosed).To(Receive()) close(done) @@ -1204,8 +1234,8 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) go session.run() session.scheduleSending() - Eventually(func() [][]byte { return conn.written }).ShouldNot(BeEmpty()) - Expect(conn.written[0]).To(ContainSubstring("foobar")) + Eventually(func() [][]byte { return mconn.written }).ShouldNot(BeEmpty()) + Expect(mconn.written[0]).To(ContainSubstring("foobar")) }) Context("getting streams", func() { diff --git a/udp_conn.go b/udp_conn.go deleted file mode 100644 index efc646d0..00000000 --- a/udp_conn.go +++ /dev/null @@ -1,39 +0,0 @@ -package quic - -import ( - "net" - "sync" -) - -type connection interface { - write([]byte) error - setCurrentRemoteAddr(interface{}) - RemoteAddr() *net.UDPAddr -} - -type udpConn struct { - mutex sync.RWMutex - - conn *net.UDPConn - currentAddr *net.UDPAddr -} - -var _ connection = &udpConn{} - -func (c *udpConn) write(p []byte) error { - _, err := c.conn.WriteToUDP(p, c.currentAddr) - return err -} - -func (c *udpConn) setCurrentRemoteAddr(addr interface{}) { - c.mutex.Lock() - c.currentAddr = addr.(*net.UDPAddr) - c.mutex.Unlock() -} - -func (c *udpConn) RemoteAddr() *net.UDPAddr { - c.mutex.RLock() - addr := c.currentAddr - c.mutex.RUnlock() - return addr -}