use a net.PacketConn instead of a net.UDPConn in Server and Session

This commit is contained in:
Marten Seemann 2017-02-15 18:59:15 +07:00
parent 7fe2a37c76
commit 5b42675da2
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
9 changed files with 205 additions and 104 deletions

View file

@ -62,8 +62,8 @@ func (c *linkedConnection) write(p []byte) error {
return nil return nil
} }
func (*linkedConnection) setCurrentRemoteAddr(addr interface{}) {} func (*linkedConnection) setCurrentRemoteAddr(addr net.Addr) {}
func (*linkedConnection) RemoteAddr() *net.UDPAddr { return &net.UDPAddr{} } func (*linkedConnection) RemoteAddr() net.Addr { return &net.UDPAddr{} }
func setAEAD(cs handshake.CryptoSetup, aead crypto.AEAD) { func setAEAD(cs handshake.CryptoSetup, aead crypto.AEAD) {
*(*bool)(unsafe.Pointer(reflect.ValueOf(cs).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true *(*bool)(unsafe.Pointer(reflect.ValueOf(cs).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true

39
conn.go Normal file
View file

@ -0,0 +1,39 @@
package quic
import (
"net"
"sync"
)
type connection interface {
write([]byte) error
setCurrentRemoteAddr(net.Addr)
RemoteAddr() net.Addr
}
type conn struct {
mutex sync.RWMutex
pconn net.PacketConn
currentAddr net.Addr
}
var _ connection = &conn{}
func (c *conn) write(p []byte) error {
_, err := c.pconn.WriteTo(p, c.currentAddr)
return err
}
func (c *conn) setCurrentRemoteAddr(addr net.Addr) {
c.mutex.Lock()
c.currentAddr = addr
c.mutex.Unlock()
}
func (c *conn) RemoteAddr() net.Addr {
c.mutex.RLock()
addr := c.currentAddr
c.mutex.RUnlock()
return addr
}

65
conn_test.go Normal file
View file

@ -0,0 +1,65 @@
package quic
import (
"bytes"
"net"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockPacketConn struct {
dataWritten bytes.Buffer
dataWrittenTo net.Addr
}
func (c *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
panic("not implemented")
}
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
c.dataWrittenTo = addr
return c.dataWritten.Write(b)
}
func (c *mockPacketConn) Close() error { panic("not implemented") }
func (c *mockPacketConn) LocalAddr() net.Addr { panic("not implemented") }
func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") }
func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
func (c *mockPacketConn) SetWriteDeadline(t time.Time) error { panic("not implemented") }
var _ net.PacketConn = &mockPacketConn{}
var _ = Describe("Connection", func() {
var c *conn
BeforeEach(func() {
addr := &net.UDPAddr{
IP: net.IPv4(192, 168, 100, 200),
Port: 1337,
}
c = &conn{
currentAddr: addr,
pconn: &mockPacketConn{},
}
})
It("writes", func() {
err := c.write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(c.pconn.(*mockPacketConn).dataWritten.Bytes()).To(Equal([]byte("foobar")))
Expect(c.pconn.(*mockPacketConn).dataWrittenTo.String()).To(Equal("192.168.100.200:1337"))
})
It("gets the remote address", func() {
Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337"))
})
It("changes the remote address", func() {
addr := &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 7331,
}
c.setCurrentRemoteAddr(addr)
Expect(c.RemoteAddr().String()).To(Equal(addr.String()))
})
})

View file

@ -22,7 +22,7 @@ import (
type streamCreator interface { type streamCreator interface {
GetOrOpenStream(protocol.StreamID) (utils.Stream, error) GetOrOpenStream(protocol.StreamID) (utils.Stream, error)
Close(error) error Close(error) error
RemoteAddr() *net.UDPAddr RemoteAddr() net.Addr
} }
// Server is a HTTP2 server listening for QUIC connections. // Server is a HTTP2 server listening for QUIC connections.

View file

@ -36,7 +36,7 @@ func (s *mockSession) Close(e error) error {
s.closedWithError = e s.closedWithError = e
return nil return nil
} }
func (s *mockSession) RemoteAddr() *net.UDPAddr { func (s *mockSession) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42} return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
} }

View file

@ -132,7 +132,7 @@ func (s *Server) Addr() net.Addr {
return s.addr return s.addr
} }
func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet []byte) error { func (s *Server) handlePacket(pconn net.PacketConn, remoteAddr *net.UDPAddr, packet []byte) error {
if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize { if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize {
return qerr.PacketTooLarge return qerr.PacketTooLarge
} }
@ -177,13 +177,13 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
// Send Version Negotiation Packet if the client is speaking a different protocol version // Send Version Negotiation Packet if the client is speaking a different protocol version
if hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) { if hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) {
utils.Infof("Client offered version %d, sending VersionNegotiationPacket", hdr.VersionNumber) utils.Infof("Client offered version %d, sending VersionNegotiationPacket", hdr.VersionNumber)
_, err = conn.WriteToUDP(composeVersionNegotiation(hdr.ConnectionID), remoteAddr) _, err = pconn.WriteTo(composeVersionNegotiation(hdr.ConnectionID), remoteAddr)
return err return err
} }
if !ok { if !ok {
if !hdr.VersionFlag { if !hdr.VersionFlag {
_, err = conn.WriteToUDP(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr) _, err = pconn.WriteTo(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr)
return err return err
} }
version := hdr.VersionNumber version := hdr.VersionNumber
@ -193,7 +193,7 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, version, remoteAddr) utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, version, remoteAddr)
session, err = s.newSession( session, err = s.newSession(
&udpConn{conn: conn, currentAddr: remoteAddr}, &conn{pconn: pconn, currentAddr: remoteAddr},
version, version,
hdr.ConnectionID, hdr.ConnectionID,
s.scfg, s.scfg,

View file

@ -23,7 +23,7 @@ type unpacker interface {
} }
type receivedPacket struct { type receivedPacket struct {
remoteAddr interface{} remoteAddr net.Addr
publicHeader *PublicHeader publicHeader *PublicHeader
data []byte data []byte
rcvTime time.Time rcvTime time.Time
@ -116,8 +116,14 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
session.setup() session.setup()
cryptoStream, _ := session.GetOrOpenStream(1) cryptoStream, _ := session.GetOrOpenStream(1)
var sourceAddr []byte
if udpAddr, ok := conn.RemoteAddr().(*net.UDPAddr); ok {
sourceAddr = udpAddr.IP
} else {
sourceAddr = []byte(conn.RemoteAddr().String())
}
var err error var err error
session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.RemoteAddr().IP, v, sCfg, cryptoStream, session.connectionParameters, session.aeadChanged) session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, session.connectionParameters, session.aeadChanged)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -128,9 +134,9 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
return session, err return session, err
} }
func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) { func newClientSession(pconn net.PacketConn, addr net.Addr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) {
session := &Session{ session := &Session{
conn: &udpConn{conn: conn, currentAddr: addr}, conn: &conn{pconn: pconn, currentAddr: addr},
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
version: v, version: v,
@ -765,7 +771,7 @@ func (s *Session) ackAlarmChanged(t time.Time) {
s.maybeResetTimer() s.maybeResetTimer()
} }
// RemoteAddr returns the net.UDPAddr of the client // RemoteAddr returns the net.Addr of the client
func (s *Session) RemoteAddr() *net.UDPAddr { func (s *Session) RemoteAddr() net.Addr {
return s.conn.RemoteAddr() return s.conn.RemoteAddr()
} }

View file

@ -25,7 +25,7 @@ import (
) )
type mockConnection struct { type mockConnection struct {
remoteAddr net.IP remoteAddr net.Addr
written [][]byte written [][]byte
} }
@ -36,12 +36,10 @@ func (m *mockConnection) write(p []byte) error {
return nil return nil
} }
func (m *mockConnection) setCurrentRemoteAddr(addr interface{}) { func (m *mockConnection) setCurrentRemoteAddr(addr net.Addr) {
if ip, ok := addr.(net.IP); ok { m.remoteAddr = addr
m.remoteAddr = ip
}
} }
func (*mockConnection) RemoteAddr() *net.UDPAddr { return &net.UDPAddr{} } func (*mockConnection) RemoteAddr() net.Addr { return &net.UDPAddr{} }
type mockUnpacker struct { type mockUnpacker struct {
unpackErr error unpackErr error
@ -120,22 +118,23 @@ var _ = Describe("Session", func() {
clientSession *Session clientSession *Session
streamCallbackCalled bool streamCallbackCalled bool
closeCallbackCalled bool closeCallbackCalled bool
conn *mockConnection scfg *handshake.ServerConfig
mconn *mockConnection
cpm *mockConnectionParametersManager cpm *mockConnectionParametersManager
) )
BeforeEach(func() { BeforeEach(func() {
conn = &mockConnection{} mconn = &mockConnection{}
streamCallbackCalled = false streamCallbackCalled = false
closeCallbackCalled = false closeCallbackCalled = false
certChain := crypto.NewCertChain(testdata.GetTLSConfig()) certChain := crypto.NewCertChain(testdata.GetTLSConfig())
kex, err := crypto.NewCurve25519KEX() kex, err := crypto.NewCurve25519KEX()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
scfg, err := handshake.NewServerConfig(kex, certChain) scfg, err = handshake.NewServerConfig(kex, certChain)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
pSession, err := newSession( pSession, err := newSession(
conn, mconn,
protocol.Version35, protocol.Version35,
0, 0,
scfg, scfg,
@ -163,7 +162,38 @@ var _ = Describe("Session", func() {
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream
})
Context("source address", func() {
It("uses the IP address if given an UDP connection", func() {
conn := &conn{currentAddr: &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}}
session, err := newSession(
conn,
protocol.VersionWhatever,
0,
scfg,
func(*Session, utils.Stream) { streamCallbackCalled = true },
func(protocol.ConnectionID) { closeCallbackCalled = true },
)
Expect(err).ToNot(HaveOccurred())
Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(session.(*Session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte{192, 168, 100, 200}))
})
It("uses the string representation of the remote addresses if not given a UDP connection", func() {
conn := &conn{
currentAddr: &net.TCPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337},
}
session, err := newSession(
conn,
protocol.VersionWhatever,
0,
scfg,
func(*Session, utils.Stream) { streamCallbackCalled = true },
func(protocol.ConnectionID) { closeCallbackCalled = true },
)
Expect(err).ToNot(HaveOccurred())
Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(session.(*Session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte("192.168.100.200:1337")))
})
}) })
Context("when handling stream frames", func() { Context("when handling stream frames", func() {
@ -617,8 +647,8 @@ var _ = Describe("Session", func() {
It("shuts down without error", func() { It("shuts down without error", func() {
session.Close(nil) session.Close(nil)
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
Expect(conn.written).To(HaveLen(1)) Expect(mconn.written).To(HaveLen(1))
Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})) Expect(mconn.written[0][len(mconn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
Expect(closeCallbackCalled).To(BeTrue()) Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close() Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close()
}) })
@ -627,7 +657,7 @@ var _ = Describe("Session", func() {
session.Close(nil) session.Close(nil)
session.Close(nil) session.Close(nil)
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
Expect(conn.written).To(HaveLen(1)) Expect(mconn.written).To(HaveLen(1))
Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close() Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close()
}) })
@ -652,7 +682,7 @@ var _ = Describe("Session", func() {
Expect(closeCallbackCalled).To(BeFalse()) Expect(closeCallbackCalled).To(BeFalse())
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
Expect(atomic.LoadUint32(&session.closed) != 0).To(BeTrue()) Expect(atomic.LoadUint32(&session.closed) != 0).To(BeTrue())
Expect(conn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
}) })
}) })
@ -712,7 +742,7 @@ var _ = Describe("Session", func() {
Context("updating the remote address", func() { Context("updating the remote address", func() {
It("sets the remote address", func() { It("sets the remote address", func() {
remoteIP := net.IPv4(192, 168, 0, 100) remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
Expect(session.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP)) Expect(session.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP))
p := receivedPacket{ p := receivedPacket{
remoteAddr: remoteIP, remoteAddr: remoteIP,
@ -724,8 +754,8 @@ var _ = Describe("Session", func() {
}) })
It("doesn't change the remote address if authenticating the packet fails", func() { It("doesn't change the remote address if authenticating the packet fails", func() {
remoteIP := net.IPv4(192, 168, 0, 100) remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
attackerIP := net.IPv4(192, 168, 0, 102) attackerIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 102)}
session.conn.(*mockConnection).remoteAddr = remoteIP session.conn.(*mockConnection).remoteAddr = remoteIP
// use the real packetUnpacker here, to make sure this test fails if the error code for failed decryption changes // use the real packetUnpacker here, to make sure this test fails if the error code for failed decryption changes
session.unpacker = &packetUnpacker{} session.unpacker = &packetUnpacker{}
@ -742,7 +772,7 @@ var _ = Describe("Session", func() {
It("sets the remote address, if the packet is authenticated, but unpacking fails for another reason", func() { It("sets the remote address, if the packet is authenticated, but unpacking fails for another reason", func() {
testErr := errors.New("testErr") testErr := errors.New("testErr")
remoteIP := net.IPv4(192, 168, 0, 100) remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
Expect(session.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP)) Expect(session.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP))
p := receivedPacket{ p := receivedPacket{
remoteAddr: remoteIP, remoteAddr: remoteIP,
@ -762,8 +792,8 @@ var _ = Describe("Session", func() {
session.receivedPacketHandler.ReceivedPacket(packetNumber, true) session.receivedPacketHandler.ReceivedPacket(packetNumber, true)
err := session.sendPacket() err := session.sendPacket()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(1)) Expect(mconn.written).To(HaveLen(1))
Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x5E, 0x03}))) Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x5E, 0x03})))
}) })
It("sends two WindowUpdate frames", func() { It("sends two WindowUpdate frames", func() {
@ -776,16 +806,16 @@ var _ = Describe("Session", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = session.sendPacket() err = session.sendPacket()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(2)) Expect(mconn.written).To(HaveLen(2))
Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0}))) Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0})))
Expect(conn.written[1]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0}))) Expect(mconn.written[1]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0})))
}) })
It("sends public reset", func() { It("sends public reset", func() {
err := session.sendPublicReset(1) err := session.sendPublicReset(1)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(1)) Expect(mconn.written).To(HaveLen(1))
Expect(conn.written[0]).To(ContainSubstring(string([]byte("PRST")))) Expect(mconn.written[0]).To(ContainSubstring(string([]byte("PRST"))))
}) })
}) })
@ -808,9 +838,9 @@ var _ = Describe("Session", func() {
err := session.sendPacket() err := session.sendPacket()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(1)) Expect(mconn.written).To(HaveLen(1))
Expect(sph.(*mockSentPacketHandler).requestedStopWaiting).To(BeTrue()) Expect(sph.(*mockSentPacketHandler).requestedStopWaiting).To(BeTrue())
Expect(conn.written[0]).To(ContainSubstring("foobar1234567")) Expect(mconn.written[0]).To(ContainSubstring("foobar1234567"))
}) })
It("sends a StreamFrame from a packet queued for retransmission", func() { It("sends a StreamFrame from a packet queued for retransmission", func() {
@ -839,9 +869,9 @@ var _ = Describe("Session", func() {
err := session.sendPacket() err := session.sendPacket()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(1)) Expect(mconn.written).To(HaveLen(1))
Expect(conn.written[0]).To(ContainSubstring("foobar")) Expect(mconn.written[0]).To(ContainSubstring("foobar"))
Expect(conn.written[0]).To(ContainSubstring("loremipsum")) Expect(mconn.written[0]).To(ContainSubstring("loremipsum"))
}) })
It("always attaches a StopWaiting to a packet that contains a retransmission", func() { It("always attaches a StopWaiting to a packet that contains a retransmission", func() {
@ -859,7 +889,7 @@ var _ = Describe("Session", func() {
err := session.sendPacket() err := session.sendPacket()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(2)) Expect(mconn.written).To(HaveLen(2))
sentPackets := sph.(*mockSentPacketHandler).sentPackets sentPackets := sph.(*mockSentPacketHandler).sentPackets
Expect(sentPackets).To(HaveLen(2)) Expect(sentPackets).To(HaveLen(2))
_, ok := sentPackets[0].Frames[0].(*frames.StopWaitingFrame) _, ok := sentPackets[0].Frames[0].(*frames.StopWaitingFrame)
@ -963,8 +993,8 @@ var _ = Describe("Session", func() {
go session.run() go session.run()
session.ackAlarmChanged(time.Now().Add(10 * time.Millisecond)) session.ackAlarmChanged(time.Now().Add(10 * time.Millisecond))
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
Eventually(func() int { return len(conn.written) }).ShouldNot(BeZero()) Eventually(func() int { return len(mconn.written) }).ShouldNot(BeZero())
Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13}))) Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13})))
}) })
Context("bundling of small packets", func() { Context("bundling of small packets", func() {
@ -981,9 +1011,9 @@ var _ = Describe("Session", func() {
session.scheduleSending() session.scheduleSending()
go session.run() go session.run()
Eventually(func() [][]byte { return conn.written }).Should(HaveLen(1)) Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1))
Expect(conn.written[0]).To(ContainSubstring("foobar1")) Expect(mconn.written[0]).To(ContainSubstring("foobar1"))
Expect(conn.written[0]).To(ContainSubstring("foobar2")) Expect(mconn.written[0]).To(ContainSubstring("foobar2"))
}) })
It("sends out two big frames in two packets", func() { It("sends out two big frames in two packets", func() {
@ -999,7 +1029,7 @@ var _ = Describe("Session", func() {
}() }()
_, err = s2.Write(bytes.Repeat([]byte{'e'}, 1000)) _, err = s2.Write(bytes.Repeat([]byte{'e'}, 1000))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(func() [][]byte { return conn.written }).Should(HaveLen(2)) Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(2))
}) })
It("sends out two small frames that are written to long after one another into two packets", func() { It("sends out two small frames that are written to long after one another into two packets", func() {
@ -1008,10 +1038,10 @@ var _ = Describe("Session", func() {
go session.run() go session.run()
_, err = s.Write([]byte("foobar1")) _, err = s.Write([]byte("foobar1"))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(func() [][]byte { return conn.written }).Should(HaveLen(1)) Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1))
_, err = s.Write([]byte("foobar2")) _, err = s.Write([]byte("foobar2"))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(func() [][]byte { return conn.written }).Should(HaveLen(2)) Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(2))
}) })
It("sends a queued ACK frame only once", func() { It("sends a queued ACK frame only once", func() {
@ -1023,13 +1053,13 @@ var _ = Describe("Session", func() {
go session.run() go session.run()
_, err = s.Write([]byte("foobar1")) _, err = s.Write([]byte("foobar1"))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(func() [][]byte { return conn.written }).Should(HaveLen(1)) Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1))
_, err = s.Write([]byte("foobar2")) _, err = s.Write([]byte("foobar2"))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(func() [][]byte { return conn.written }).Should(HaveLen(2)) Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(2))
Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13}))) Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13})))
Expect(conn.written[1]).ToNot(ContainSubstring(string([]byte{0x37, 0x13}))) Expect(mconn.written[1]).ToNot(ContainSubstring(string([]byte{0x37, 0x13})))
}) })
}) })
}) })
@ -1058,8 +1088,8 @@ var _ = Describe("Session", func() {
} }
session.run() session.run()
Expect(conn.written).To(HaveLen(1)) Expect(mconn.written).To(HaveLen(1))
Expect(conn.written[0]).To(ContainSubstring(string([]byte("PRST")))) Expect(mconn.written[0]).To(ContainSubstring(string([]byte("PRST"))))
Expect(session.runClosed).To(Receive()) Expect(session.runClosed).To(Receive())
}) })
@ -1120,7 +1150,7 @@ var _ = Describe("Session", func() {
It("times out due to no network activity", func(done Done) { It("times out due to no network activity", func(done Done) {
session.lastNetworkActivityTime = time.Now().Add(-time.Hour) session.lastNetworkActivityTime = time.Now().Add(-time.Hour)
session.run() // Would normally not return session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("No recent network activity.")) Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue()) Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).To(Receive()) Expect(session.runClosed).To(Receive())
close(done) close(done)
@ -1129,7 +1159,7 @@ var _ = Describe("Session", func() {
It("times out due to non-completed crypto handshake", func(done Done) { It("times out due to non-completed crypto handshake", func(done Done) {
session.sessionCreationTime = time.Now().Add(-time.Hour) session.sessionCreationTime = time.Now().Add(-time.Hour)
session.run() // Would normally not return session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time.")) Expect(mconn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time."))
Expect(closeCallbackCalled).To(BeTrue()) Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).To(Receive()) Expect(session.runClosed).To(Receive())
close(done) close(done)
@ -1140,7 +1170,7 @@ var _ = Describe("Session", func() {
cpm.idleTime = 99999 * time.Second cpm.idleTime = 99999 * time.Second
session.packer.connectionParameters = session.connectionParameters session.packer.connectionParameters = session.connectionParameters
session.run() // Would normally not return session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("No recent network activity.")) Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue()) Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).To(Receive()) Expect(session.runClosed).To(Receive())
close(done) close(done)
@ -1153,7 +1183,7 @@ var _ = Describe("Session", func() {
cpm.idleTime = 0 * time.Millisecond cpm.idleTime = 0 * time.Millisecond
session.packer.connectionParameters = session.connectionParameters session.packer.connectionParameters = session.connectionParameters
session.run() // Would normally not return session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("No recent network activity.")) Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue()) Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).To(Receive()) Expect(session.runClosed).To(Receive())
close(done) close(done)
@ -1204,8 +1234,8 @@ var _ = Describe("Session", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
go session.run() go session.run()
session.scheduleSending() session.scheduleSending()
Eventually(func() [][]byte { return conn.written }).ShouldNot(BeEmpty()) Eventually(func() [][]byte { return mconn.written }).ShouldNot(BeEmpty())
Expect(conn.written[0]).To(ContainSubstring("foobar")) Expect(mconn.written[0]).To(ContainSubstring("foobar"))
}) })
Context("getting streams", func() { Context("getting streams", func() {

View file

@ -1,39 +0,0 @@
package quic
import (
"net"
"sync"
)
type connection interface {
write([]byte) error
setCurrentRemoteAddr(interface{})
RemoteAddr() *net.UDPAddr
}
type udpConn struct {
mutex sync.RWMutex
conn *net.UDPConn
currentAddr *net.UDPAddr
}
var _ connection = &udpConn{}
func (c *udpConn) write(p []byte) error {
_, err := c.conn.WriteToUDP(p, c.currentAddr)
return err
}
func (c *udpConn) setCurrentRemoteAddr(addr interface{}) {
c.mutex.Lock()
c.currentAddr = addr.(*net.UDPAddr)
c.mutex.Unlock()
}
func (c *udpConn) RemoteAddr() *net.UDPAddr {
c.mutex.RLock()
addr := c.currentAddr
c.mutex.RUnlock()
return addr
}