mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
We were relying on the number of goroutines in those tests. This commit replaces that logic with grepping through the current callstack. The 'doesn't return after version negotiation is established if no ConnState is defined' test is disabled, since I can't figure out a way to make it work without leaking sessions. I filed #506 to track that work. Fixes #484.
270 lines
9.8 KiB
Go
270 lines
9.8 KiB
Go
package quic
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"net"
|
|
"reflect"
|
|
"unsafe"
|
|
|
|
"github.com/lucas-clemente/quic-go/protocol"
|
|
"github.com/lucas-clemente/quic-go/qerr"
|
|
|
|
. "github.com/onsi/ginkgo"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
var _ = Describe("Client", func() {
|
|
var (
|
|
cl *client
|
|
config *Config
|
|
sess *mockSession
|
|
packetConn *mockPacketConn
|
|
addr net.Addr
|
|
versionNegotiateConnStateCalled bool
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
Eventually(areSessionsRunning).Should(BeFalse())
|
|
versionNegotiateConnStateCalled = false
|
|
packetConn = &mockPacketConn{}
|
|
config = &Config{
|
|
ConnState: func(_ Session, state ConnState) {
|
|
if state == ConnStateVersionNegotiated {
|
|
versionNegotiateConnStateCalled = true
|
|
}
|
|
},
|
|
}
|
|
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
|
sess = &mockSession{connectionID: 0x1337}
|
|
cl = &client{
|
|
config: config,
|
|
connectionID: 0x1337,
|
|
session: sess,
|
|
version: protocol.Version36,
|
|
conn: &conn{pconn: packetConn, currentAddr: addr},
|
|
}
|
|
})
|
|
|
|
AfterEach(func() {
|
|
if s, ok := cl.session.(*session); ok {
|
|
s.Close(nil)
|
|
}
|
|
Eventually(areSessionsRunning).Should(BeFalse())
|
|
})
|
|
|
|
Context("Dialing", func() {
|
|
It("creates a new client", func() {
|
|
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
|
|
var err error
|
|
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil())
|
|
Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io"))
|
|
sess.Close(nil)
|
|
})
|
|
|
|
It("errors when receiving an invalid first packet from the server", func() {
|
|
packetConn.dataToRead = []byte{0xff}
|
|
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(sess).To(BeNil())
|
|
})
|
|
|
|
It("errors when receiving an error from the connection", func() {
|
|
testErr := errors.New("connection error")
|
|
packetConn.readErr = testErr
|
|
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
|
Expect(err).To(MatchError(testErr))
|
|
})
|
|
|
|
// now we're only testing that Dial doesn't return directly after version negotiation
|
|
PIt("doesn't return after version negotiation is established if no ConnState is defined", func() {
|
|
// TODO(#506): Fix test
|
|
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
|
|
config.ConnState = nil
|
|
var dialReturned bool
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
dialReturned = true
|
|
}()
|
|
Consistently(func() bool { return dialReturned }).Should(BeFalse())
|
|
})
|
|
|
|
It("only establishes a connection once it is forward-secure if no ConnState is defined", func() {
|
|
config.ConnState = nil
|
|
client := &client{conn: &conn{pconn: packetConn, currentAddr: addr}, config: config}
|
|
client.connStateChangeOrErrCond.L = &client.mutex
|
|
var returned bool
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
_, err := client.establishConnection()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
returned = true
|
|
}()
|
|
Consistently(func() bool { return returned }).Should(BeFalse())
|
|
// switch to a secure connection
|
|
client.cryptoChangeCallback(nil, false)
|
|
Consistently(func() bool { return returned }).Should(BeFalse())
|
|
// switch to a forward-secure connection
|
|
client.cryptoChangeCallback(nil, true)
|
|
Eventually(func() bool { return returned }).Should(BeTrue())
|
|
})
|
|
})
|
|
|
|
It("errors on invalid public header", func() {
|
|
err := cl.handlePacket(nil, nil)
|
|
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
|
|
})
|
|
|
|
// this test requires a real session (because it calls the close callback) and a real UDP conn (because it unblocks and errors when it is closed)
|
|
It("properly closes", func(done Done) {
|
|
Eventually(areSessionsRunning).Should(BeFalse())
|
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
cl.conn = &conn{pconn: udpConn}
|
|
err = cl.createNewSession(nil)
|
|
Expect(err).NotTo(HaveOccurred())
|
|
testErr := errors.New("test error")
|
|
Eventually(areSessionsRunning).Should(BeTrue())
|
|
|
|
var stoppedListening bool
|
|
go func() {
|
|
cl.listen()
|
|
stoppedListening = true
|
|
}()
|
|
|
|
err = cl.session.Close(testErr)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Eventually(func() bool { return stoppedListening }).Should(BeTrue())
|
|
Eventually(areSessionsRunning).Should(BeFalse())
|
|
close(done)
|
|
}, 10)
|
|
|
|
It("creates new sessions with the right parameters", func() {
|
|
cl.session = nil
|
|
cl.hostname = "hostname"
|
|
err := cl.createNewSession(nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(cl.session).ToNot(BeNil())
|
|
Expect(cl.session.(*session).connectionID).To(Equal(cl.connectionID))
|
|
Expect(cl.session.(*session).version).To(Equal(cl.version))
|
|
})
|
|
|
|
Context("handling packets", func() {
|
|
It("handles packets", func() {
|
|
ph := PublicHeader{
|
|
PacketNumber: 1,
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
ConnectionID: 0x1337,
|
|
}
|
|
b := &bytes.Buffer{}
|
|
err := ph.Write(b, protocol.Version36, protocol.PerspectiveServer)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
packetConn.dataToRead = b.Bytes()
|
|
|
|
Expect(sess.packetCount).To(BeZero())
|
|
var stoppedListening bool
|
|
go func() {
|
|
cl.listen()
|
|
// it should continue listening when receiving valid packets
|
|
stoppedListening = true
|
|
}()
|
|
|
|
Eventually(func() int { return sess.packetCount }).Should(Equal(1))
|
|
Expect(sess.closed).To(BeFalse())
|
|
Consistently(func() bool { return stoppedListening }).Should(BeFalse())
|
|
})
|
|
|
|
It("closes the session when encountering an error while handling a packet", func() {
|
|
Expect(sess.closeReason).ToNot(HaveOccurred())
|
|
packetConn.dataToRead = bytes.Repeat([]byte{0xff}, 100)
|
|
cl.listen()
|
|
Expect(sess.closed).To(BeTrue())
|
|
Expect(sess.closeReason).To(HaveOccurred())
|
|
})
|
|
|
|
It("closes the session when encountering an error while reading from the connection", func() {
|
|
testErr := errors.New("test error")
|
|
packetConn.readErr = testErr
|
|
cl.listen()
|
|
Expect(sess.closed).To(BeTrue())
|
|
Expect(sess.closeReason).To(MatchError(testErr))
|
|
})
|
|
})
|
|
|
|
Context("version negotiation", func() {
|
|
getVersionNegotiation := func(versions []protocol.VersionNumber) []byte {
|
|
oldVersionNegotiationPacket := composeVersionNegotiation(0x1337)
|
|
oldSupportVersionTags := protocol.SupportedVersionsAsTags
|
|
var b bytes.Buffer
|
|
for _, v := range versions {
|
|
s := make([]byte, 4)
|
|
binary.LittleEndian.PutUint32(s, protocol.VersionNumberToTag(v))
|
|
b.Write(s)
|
|
}
|
|
protocol.SupportedVersionsAsTags = b.Bytes()
|
|
packet := composeVersionNegotiation(cl.connectionID)
|
|
protocol.SupportedVersionsAsTags = oldSupportVersionTags
|
|
Expect(composeVersionNegotiation(0x1337)).To(Equal(oldVersionNegotiationPacket))
|
|
return packet
|
|
}
|
|
|
|
It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() {
|
|
ph := PublicHeader{
|
|
PacketNumber: 1,
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
ConnectionID: 0x1337,
|
|
}
|
|
b := &bytes.Buffer{}
|
|
err := ph.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
err = cl.handlePacket(nil, b.Bytes())
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated))
|
|
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
|
|
})
|
|
|
|
It("changes the version after receiving a version negotiation packet", func() {
|
|
newVersion := protocol.Version35
|
|
Expect(newVersion).ToNot(Equal(cl.version))
|
|
Expect(sess.packetCount).To(BeZero())
|
|
cl.connectionID = 0x1337
|
|
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{newVersion}))
|
|
Expect(cl.version).To(Equal(newVersion))
|
|
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated))
|
|
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
|
|
// it swapped the sessions
|
|
Expect(cl.session).ToNot(Equal(sess))
|
|
Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID
|
|
Expect(err).ToNot(HaveOccurred())
|
|
// it didn't pass the version negoation packet to the session (since it has no payload)
|
|
Expect(sess.packetCount).To(BeZero())
|
|
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35}))
|
|
})
|
|
|
|
It("errors if no matching version is found", func() {
|
|
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
|
|
Expect(err).To(MatchError(qerr.InvalidVersion))
|
|
})
|
|
|
|
It("ignores delayed version negotiation packets", func() {
|
|
// if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test
|
|
cl.connState = ConnStateVersionNegotiated
|
|
Expect(sess.packetCount).To(BeZero())
|
|
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated))
|
|
Expect(sess.packetCount).To(BeZero())
|
|
Consistently(func() bool { return versionNegotiateConnStateCalled }).Should(BeFalse())
|
|
})
|
|
|
|
It("errors if the server should have accepted the offered version", func() {
|
|
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{cl.version}))
|
|
Expect(err).To(MatchError(qerr.Error(qerr.InvalidVersionNegotiationPacket, "Server already supports client's version and should have accepted the connection.")))
|
|
})
|
|
})
|
|
})
|