add a callback to client that is called after the version is negotiated

This commit is contained in:
Marten Seemann 2016-12-14 18:35:33 +07:00
parent 2377b3a111
commit dc05de3312
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
3 changed files with 37 additions and 10 deletions

View file

@ -24,9 +24,14 @@ type Client struct {
version protocol.VersionNumber
versionNegotiated bool
versionNegotiateCallback VersionNegotiateCallback
session packetHandler
}
// VersionNegotiateCallback is called once the client has a negotiated version
type VersionNegotiateCallback func() error
var errHostname = errors.New("Invalid hostname")
var (
@ -35,7 +40,7 @@ var (
)
// NewClient makes a new client
func NewClient(addr string) (*Client, error) {
func NewClient(addr string, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) {
hostname, err := utils.HostnameFromAddr(addr)
if err != nil || len(hostname) == 0 {
return nil, errHostname
@ -62,11 +67,12 @@ func NewClient(addr string) (*Client, error) {
connectionID := protocol.ConnectionID(rand.Int63())
client := &Client{
addr: udpAddr,
conn: conn,
hostname: hostname,
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
connectionID: connectionID,
addr: udpAddr,
conn: conn,
hostname: hostname,
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
connectionID: connectionID,
versionNegotiateCallback: versionNegotiateCallback,
}
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version)
@ -133,6 +139,10 @@ func (c *Client) handlePacket(packet []byte) error {
// if the server doesn't send a version negotiation packet, it supports the suggested version
if !hdr.VersionFlag && !c.versionNegotiated {
c.versionNegotiated = true
err = c.versionNegotiateCallback()
if err != nil {
return err
}
}
if hdr.VersionFlag {
@ -151,11 +161,16 @@ func (c *Client) handlePacket(packet []byte) error {
utils.Infof("Switching to QUIC version %d", highestSupportedVersion)
c.version = highestSupportedVersion
c.versionNegotiated = true
c.session.Close(errCloseSessionForNewVersion)
err = c.createNewSession()
if err != nil {
return err
}
err = c.versionNegotiateCallback()
if err != nil {
return err
}
return nil // version negotiation packets have no payload
}

View file

@ -13,11 +13,20 @@ import (
)
var _ = Describe("Client", func() {
var client *Client
var session *mockSession
var (
client *Client
session *mockSession
versionNegotiateCallbackCalled bool
)
BeforeEach(func() {
client = &Client{}
versionNegotiateCallbackCalled = false
client = &Client{
versionNegotiateCallback: func() error {
versionNegotiateCallbackCalled = true
return nil
},
}
session = &mockSession{connectionID: 0x1337}
client.connectionID = 0x1337
client.session = session
@ -162,6 +171,7 @@ var _ = Describe("Client", func() {
err = client.handlePacket(b.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(client.versionNegotiated).To(BeTrue())
Expect(versionNegotiateCallbackCalled).To(BeTrue())
})
It("changes the version after receiving a version negotiation packet", func() {
@ -172,6 +182,7 @@ var _ = Describe("Client", func() {
err := client.handlePacket(getVersionNegotiation([]protocol.VersionNumber{newVersion}))
Expect(client.version).To(Equal(newVersion))
Expect(client.versionNegotiated).To(BeTrue())
Expect(versionNegotiateCallbackCalled).To(BeTrue())
// it swapped the sessions
Expect(client.session).ToNot(Equal(session))
Expect(err).ToNot(HaveOccurred())
@ -195,6 +206,7 @@ var _ = Describe("Client", func() {
Expect(err).ToNot(HaveOccurred())
Expect(client.versionNegotiated).To(BeTrue())
Expect(session.packetCount).To(BeZero())
Expect(versionNegotiateCallbackCalled).To(BeFalse())
})
It("errors if the server should have accepted the offered version", func() {

View file

@ -10,7 +10,7 @@ func main() {
utils.SetLogLevel(utils.LogLevelDebug)
client, err := quic.NewClient(addr)
client, err := quic.NewClient(addr, func() error { return nil })
if err != nil {
panic(err)
}