use a net.PacketConn instead of a net.UDPConn in Server and Session

This commit is contained in:
Marten Seemann 2017-02-15 18:59:15 +07:00
parent 7fe2a37c76
commit 5b42675da2
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
9 changed files with 205 additions and 104 deletions

View file

@ -62,8 +62,8 @@ func (c *linkedConnection) write(p []byte) error {
return nil
}
func (*linkedConnection) setCurrentRemoteAddr(addr interface{}) {}
func (*linkedConnection) RemoteAddr() *net.UDPAddr { return &net.UDPAddr{} }
func (*linkedConnection) setCurrentRemoteAddr(addr net.Addr) {}
func (*linkedConnection) RemoteAddr() net.Addr { return &net.UDPAddr{} }
func setAEAD(cs handshake.CryptoSetup, aead crypto.AEAD) {
*(*bool)(unsafe.Pointer(reflect.ValueOf(cs).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true

39
conn.go Normal file
View file

@ -0,0 +1,39 @@
package quic
import (
"net"
"sync"
)
type connection interface {
write([]byte) error
setCurrentRemoteAddr(net.Addr)
RemoteAddr() net.Addr
}
type conn struct {
mutex sync.RWMutex
pconn net.PacketConn
currentAddr net.Addr
}
var _ connection = &conn{}
func (c *conn) write(p []byte) error {
_, err := c.pconn.WriteTo(p, c.currentAddr)
return err
}
func (c *conn) setCurrentRemoteAddr(addr net.Addr) {
c.mutex.Lock()
c.currentAddr = addr
c.mutex.Unlock()
}
func (c *conn) RemoteAddr() net.Addr {
c.mutex.RLock()
addr := c.currentAddr
c.mutex.RUnlock()
return addr
}

65
conn_test.go Normal file
View file

@ -0,0 +1,65 @@
package quic
import (
"bytes"
"net"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockPacketConn struct {
dataWritten bytes.Buffer
dataWrittenTo net.Addr
}
func (c *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
panic("not implemented")
}
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
c.dataWrittenTo = addr
return c.dataWritten.Write(b)
}
func (c *mockPacketConn) Close() error { panic("not implemented") }
func (c *mockPacketConn) LocalAddr() net.Addr { panic("not implemented") }
func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") }
func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
func (c *mockPacketConn) SetWriteDeadline(t time.Time) error { panic("not implemented") }
var _ net.PacketConn = &mockPacketConn{}
var _ = Describe("Connection", func() {
var c *conn
BeforeEach(func() {
addr := &net.UDPAddr{
IP: net.IPv4(192, 168, 100, 200),
Port: 1337,
}
c = &conn{
currentAddr: addr,
pconn: &mockPacketConn{},
}
})
It("writes", func() {
err := c.write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(c.pconn.(*mockPacketConn).dataWritten.Bytes()).To(Equal([]byte("foobar")))
Expect(c.pconn.(*mockPacketConn).dataWrittenTo.String()).To(Equal("192.168.100.200:1337"))
})
It("gets the remote address", func() {
Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337"))
})
It("changes the remote address", func() {
addr := &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 7331,
}
c.setCurrentRemoteAddr(addr)
Expect(c.RemoteAddr().String()).To(Equal(addr.String()))
})
})

View file

@ -22,7 +22,7 @@ import (
type streamCreator interface {
GetOrOpenStream(protocol.StreamID) (utils.Stream, error)
Close(error) error
RemoteAddr() *net.UDPAddr
RemoteAddr() net.Addr
}
// Server is a HTTP2 server listening for QUIC connections.

View file

@ -36,7 +36,7 @@ func (s *mockSession) Close(e error) error {
s.closedWithError = e
return nil
}
func (s *mockSession) RemoteAddr() *net.UDPAddr {
func (s *mockSession) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
}

View file

@ -132,7 +132,7 @@ func (s *Server) Addr() net.Addr {
return s.addr
}
func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet []byte) error {
func (s *Server) handlePacket(pconn net.PacketConn, remoteAddr *net.UDPAddr, packet []byte) error {
if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize {
return qerr.PacketTooLarge
}
@ -177,13 +177,13 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
// Send Version Negotiation Packet if the client is speaking a different protocol version
if hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) {
utils.Infof("Client offered version %d, sending VersionNegotiationPacket", hdr.VersionNumber)
_, err = conn.WriteToUDP(composeVersionNegotiation(hdr.ConnectionID), remoteAddr)
_, err = pconn.WriteTo(composeVersionNegotiation(hdr.ConnectionID), remoteAddr)
return err
}
if !ok {
if !hdr.VersionFlag {
_, err = conn.WriteToUDP(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr)
_, err = pconn.WriteTo(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr)
return err
}
version := hdr.VersionNumber
@ -193,7 +193,7 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, version, remoteAddr)
session, err = s.newSession(
&udpConn{conn: conn, currentAddr: remoteAddr},
&conn{pconn: pconn, currentAddr: remoteAddr},
version,
hdr.ConnectionID,
s.scfg,

View file

@ -23,7 +23,7 @@ type unpacker interface {
}
type receivedPacket struct {
remoteAddr interface{}
remoteAddr net.Addr
publicHeader *PublicHeader
data []byte
rcvTime time.Time
@ -116,8 +116,14 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
session.setup()
cryptoStream, _ := session.GetOrOpenStream(1)
var sourceAddr []byte
if udpAddr, ok := conn.RemoteAddr().(*net.UDPAddr); ok {
sourceAddr = udpAddr.IP
} else {
sourceAddr = []byte(conn.RemoteAddr().String())
}
var err error
session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.RemoteAddr().IP, v, sCfg, cryptoStream, session.connectionParameters, session.aeadChanged)
session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, session.connectionParameters, session.aeadChanged)
if err != nil {
return nil, err
}
@ -128,9 +134,9 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
return session, err
}
func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) {
func newClientSession(pconn net.PacketConn, addr net.Addr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) {
session := &Session{
conn: &udpConn{conn: conn, currentAddr: addr},
conn: &conn{pconn: pconn, currentAddr: addr},
connectionID: connectionID,
perspective: protocol.PerspectiveClient,
version: v,
@ -765,7 +771,7 @@ func (s *Session) ackAlarmChanged(t time.Time) {
s.maybeResetTimer()
}
// RemoteAddr returns the net.UDPAddr of the client
func (s *Session) RemoteAddr() *net.UDPAddr {
// RemoteAddr returns the net.Addr of the client
func (s *Session) RemoteAddr() net.Addr {
return s.conn.RemoteAddr()
}

View file

@ -25,7 +25,7 @@ import (
)
type mockConnection struct {
remoteAddr net.IP
remoteAddr net.Addr
written [][]byte
}
@ -36,12 +36,10 @@ func (m *mockConnection) write(p []byte) error {
return nil
}
func (m *mockConnection) setCurrentRemoteAddr(addr interface{}) {
if ip, ok := addr.(net.IP); ok {
m.remoteAddr = ip
}
func (m *mockConnection) setCurrentRemoteAddr(addr net.Addr) {
m.remoteAddr = addr
}
func (*mockConnection) RemoteAddr() *net.UDPAddr { return &net.UDPAddr{} }
func (*mockConnection) RemoteAddr() net.Addr { return &net.UDPAddr{} }
type mockUnpacker struct {
unpackErr error
@ -120,22 +118,23 @@ var _ = Describe("Session", func() {
clientSession *Session
streamCallbackCalled bool
closeCallbackCalled bool
conn *mockConnection
scfg *handshake.ServerConfig
mconn *mockConnection
cpm *mockConnectionParametersManager
)
BeforeEach(func() {
conn = &mockConnection{}
mconn = &mockConnection{}
streamCallbackCalled = false
closeCallbackCalled = false
certChain := crypto.NewCertChain(testdata.GetTLSConfig())
kex, err := crypto.NewCurve25519KEX()
Expect(err).NotTo(HaveOccurred())
scfg, err := handshake.NewServerConfig(kex, certChain)
scfg, err = handshake.NewServerConfig(kex, certChain)
Expect(err).NotTo(HaveOccurred())
pSession, err := newSession(
conn,
mconn,
protocol.Version35,
0,
scfg,
@ -163,7 +162,38 @@ var _ = Describe("Session", func() {
)
Expect(err).ToNot(HaveOccurred())
Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream
})
Context("source address", func() {
It("uses the IP address if given an UDP connection", func() {
conn := &conn{currentAddr: &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}}
session, err := newSession(
conn,
protocol.VersionWhatever,
0,
scfg,
func(*Session, utils.Stream) { streamCallbackCalled = true },
func(protocol.ConnectionID) { closeCallbackCalled = true },
)
Expect(err).ToNot(HaveOccurred())
Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(session.(*Session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte{192, 168, 100, 200}))
})
It("uses the string representation of the remote addresses if not given a UDP connection", func() {
conn := &conn{
currentAddr: &net.TCPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337},
}
session, err := newSession(
conn,
protocol.VersionWhatever,
0,
scfg,
func(*Session, utils.Stream) { streamCallbackCalled = true },
func(protocol.ConnectionID) { closeCallbackCalled = true },
)
Expect(err).ToNot(HaveOccurred())
Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(session.(*Session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte("192.168.100.200:1337")))
})
})
Context("when handling stream frames", func() {
@ -617,8 +647,8 @@ var _ = Describe("Session", func() {
It("shuts down without error", func() {
session.Close(nil)
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
Expect(conn.written).To(HaveLen(1))
Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written[0][len(mconn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close()
})
@ -627,7 +657,7 @@ var _ = Describe("Session", func() {
session.Close(nil)
session.Close(nil)
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
Expect(conn.written).To(HaveLen(1))
Expect(mconn.written).To(HaveLen(1))
Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close()
})
@ -652,7 +682,7 @@ var _ = Describe("Session", func() {
Expect(closeCallbackCalled).To(BeFalse())
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
Expect(atomic.LoadUint32(&session.closed) != 0).To(BeTrue())
Expect(conn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
})
})
@ -712,7 +742,7 @@ var _ = Describe("Session", func() {
Context("updating the remote address", func() {
It("sets the remote address", func() {
remoteIP := net.IPv4(192, 168, 0, 100)
remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
Expect(session.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP))
p := receivedPacket{
remoteAddr: remoteIP,
@ -724,8 +754,8 @@ var _ = Describe("Session", func() {
})
It("doesn't change the remote address if authenticating the packet fails", func() {
remoteIP := net.IPv4(192, 168, 0, 100)
attackerIP := net.IPv4(192, 168, 0, 102)
remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
attackerIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 102)}
session.conn.(*mockConnection).remoteAddr = remoteIP
// use the real packetUnpacker here, to make sure this test fails if the error code for failed decryption changes
session.unpacker = &packetUnpacker{}
@ -742,7 +772,7 @@ var _ = Describe("Session", func() {
It("sets the remote address, if the packet is authenticated, but unpacking fails for another reason", func() {
testErr := errors.New("testErr")
remoteIP := net.IPv4(192, 168, 0, 100)
remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
Expect(session.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP))
p := receivedPacket{
remoteAddr: remoteIP,
@ -762,8 +792,8 @@ var _ = Describe("Session", func() {
session.receivedPacketHandler.ReceivedPacket(packetNumber, true)
err := session.sendPacket()
Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(1))
Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x5E, 0x03})))
Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x5E, 0x03})))
})
It("sends two WindowUpdate frames", func() {
@ -776,16 +806,16 @@ var _ = Describe("Session", func() {
Expect(err).NotTo(HaveOccurred())
err = session.sendPacket()
Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(2))
Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0})))
Expect(conn.written[1]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0})))
Expect(mconn.written).To(HaveLen(2))
Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0})))
Expect(mconn.written[1]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0})))
})
It("sends public reset", func() {
err := session.sendPublicReset(1)
Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(1))
Expect(conn.written[0]).To(ContainSubstring(string([]byte("PRST"))))
Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written[0]).To(ContainSubstring(string([]byte("PRST"))))
})
})
@ -808,9 +838,9 @@ var _ = Describe("Session", func() {
err := session.sendPacket()
Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(1))
Expect(mconn.written).To(HaveLen(1))
Expect(sph.(*mockSentPacketHandler).requestedStopWaiting).To(BeTrue())
Expect(conn.written[0]).To(ContainSubstring("foobar1234567"))
Expect(mconn.written[0]).To(ContainSubstring("foobar1234567"))
})
It("sends a StreamFrame from a packet queued for retransmission", func() {
@ -839,9 +869,9 @@ var _ = Describe("Session", func() {
err := session.sendPacket()
Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(1))
Expect(conn.written[0]).To(ContainSubstring("foobar"))
Expect(conn.written[0]).To(ContainSubstring("loremipsum"))
Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written[0]).To(ContainSubstring("foobar"))
Expect(mconn.written[0]).To(ContainSubstring("loremipsum"))
})
It("always attaches a StopWaiting to a packet that contains a retransmission", func() {
@ -859,7 +889,7 @@ var _ = Describe("Session", func() {
err := session.sendPacket()
Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(2))
Expect(mconn.written).To(HaveLen(2))
sentPackets := sph.(*mockSentPacketHandler).sentPackets
Expect(sentPackets).To(HaveLen(2))
_, ok := sentPackets[0].Frames[0].(*frames.StopWaitingFrame)
@ -963,8 +993,8 @@ var _ = Describe("Session", func() {
go session.run()
session.ackAlarmChanged(time.Now().Add(10 * time.Millisecond))
time.Sleep(10 * time.Millisecond)
Eventually(func() int { return len(conn.written) }).ShouldNot(BeZero())
Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13})))
Eventually(func() int { return len(mconn.written) }).ShouldNot(BeZero())
Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13})))
})
Context("bundling of small packets", func() {
@ -981,9 +1011,9 @@ var _ = Describe("Session", func() {
session.scheduleSending()
go session.run()
Eventually(func() [][]byte { return conn.written }).Should(HaveLen(1))
Expect(conn.written[0]).To(ContainSubstring("foobar1"))
Expect(conn.written[0]).To(ContainSubstring("foobar2"))
Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1))
Expect(mconn.written[0]).To(ContainSubstring("foobar1"))
Expect(mconn.written[0]).To(ContainSubstring("foobar2"))
})
It("sends out two big frames in two packets", func() {
@ -999,7 +1029,7 @@ var _ = Describe("Session", func() {
}()
_, err = s2.Write(bytes.Repeat([]byte{'e'}, 1000))
Expect(err).ToNot(HaveOccurred())
Eventually(func() [][]byte { return conn.written }).Should(HaveLen(2))
Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(2))
})
It("sends out two small frames that are written to long after one another into two packets", func() {
@ -1008,10 +1038,10 @@ var _ = Describe("Session", func() {
go session.run()
_, err = s.Write([]byte("foobar1"))
Expect(err).NotTo(HaveOccurred())
Eventually(func() [][]byte { return conn.written }).Should(HaveLen(1))
Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1))
_, err = s.Write([]byte("foobar2"))
Expect(err).NotTo(HaveOccurred())
Eventually(func() [][]byte { return conn.written }).Should(HaveLen(2))
Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(2))
})
It("sends a queued ACK frame only once", func() {
@ -1023,13 +1053,13 @@ var _ = Describe("Session", func() {
go session.run()
_, err = s.Write([]byte("foobar1"))
Expect(err).NotTo(HaveOccurred())
Eventually(func() [][]byte { return conn.written }).Should(HaveLen(1))
Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1))
_, err = s.Write([]byte("foobar2"))
Expect(err).NotTo(HaveOccurred())
Eventually(func() [][]byte { return conn.written }).Should(HaveLen(2))
Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13})))
Expect(conn.written[1]).ToNot(ContainSubstring(string([]byte{0x37, 0x13})))
Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(2))
Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13})))
Expect(mconn.written[1]).ToNot(ContainSubstring(string([]byte{0x37, 0x13})))
})
})
})
@ -1058,8 +1088,8 @@ var _ = Describe("Session", func() {
}
session.run()
Expect(conn.written).To(HaveLen(1))
Expect(conn.written[0]).To(ContainSubstring(string([]byte("PRST"))))
Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written[0]).To(ContainSubstring(string([]byte("PRST"))))
Expect(session.runClosed).To(Receive())
})
@ -1120,7 +1150,7 @@ var _ = Describe("Session", func() {
It("times out due to no network activity", func(done Done) {
session.lastNetworkActivityTime = time.Now().Add(-time.Hour)
session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).To(Receive())
close(done)
@ -1129,7 +1159,7 @@ var _ = Describe("Session", func() {
It("times out due to non-completed crypto handshake", func(done Done) {
session.sessionCreationTime = time.Now().Add(-time.Hour)
session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time."))
Expect(mconn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time."))
Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).To(Receive())
close(done)
@ -1140,7 +1170,7 @@ var _ = Describe("Session", func() {
cpm.idleTime = 99999 * time.Second
session.packer.connectionParameters = session.connectionParameters
session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).To(Receive())
close(done)
@ -1153,7 +1183,7 @@ var _ = Describe("Session", func() {
cpm.idleTime = 0 * time.Millisecond
session.packer.connectionParameters = session.connectionParameters
session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).To(Receive())
close(done)
@ -1204,8 +1234,8 @@ var _ = Describe("Session", func() {
Expect(err).NotTo(HaveOccurred())
go session.run()
session.scheduleSending()
Eventually(func() [][]byte { return conn.written }).ShouldNot(BeEmpty())
Expect(conn.written[0]).To(ContainSubstring("foobar"))
Eventually(func() [][]byte { return mconn.written }).ShouldNot(BeEmpty())
Expect(mconn.written[0]).To(ContainSubstring("foobar"))
})
Context("getting streams", func() {

View file

@ -1,39 +0,0 @@
package quic
import (
"net"
"sync"
)
type connection interface {
write([]byte) error
setCurrentRemoteAddr(interface{})
RemoteAddr() *net.UDPAddr
}
type udpConn struct {
mutex sync.RWMutex
conn *net.UDPConn
currentAddr *net.UDPAddr
}
var _ connection = &udpConn{}
func (c *udpConn) write(p []byte) error {
_, err := c.conn.WriteToUDP(p, c.currentAddr)
return err
}
func (c *udpConn) setCurrentRemoteAddr(addr interface{}) {
c.mutex.Lock()
c.currentAddr = addr.(*net.UDPAddr)
c.mutex.Unlock()
}
func (c *udpConn) RemoteAddr() *net.UDPAddr {
c.mutex.RLock()
addr := c.currentAddr
c.mutex.RUnlock()
return addr
}