diff --git a/client.go b/client.go index 4f371c8c..28e31dbe 100644 --- a/client.go +++ b/client.go @@ -16,8 +16,9 @@ import ( // A Client of QUIC type Client struct { - addr *net.UDPAddr - conn *net.UDPConn + addr *net.UDPAddr + conn *net.UDPConn + hostname string connectionID protocol.ConnectionID version protocol.VersionNumber @@ -27,6 +28,11 @@ type Client struct { var errHostname = errors.New("Invalid hostname") +var ( + errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") + errInvalidVersionNegotiation = qerr.Error(qerr.InvalidVersionNegotiationPacket, "Server already supports client's version and should have accepted the connection.") +) + // NewClient makes a new client func NewClient(addr string) (*Client, error) { hostname, err := utils.HostnameFromAddr(addr) @@ -54,18 +60,17 @@ func NewClient(addr string) (*Client, error) { rand.Seed(time.Now().UTC().UnixNano()) connectionID := protocol.ConnectionID(rand.Int63()) - utils.Infof("Starting new connection to %s (%s), connectionID %x", host, udpAddr.String(), connectionID) - client := &Client{ addr: udpAddr, conn: conn, - version: protocol.Version36, + hostname: hostname, + version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default connectionID: connectionID, } - streamCallback := func(session *Session, stream utils.Stream) {} + utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version) - client.session, err = newClientSession(conn, udpAddr, hostname, client.version, client.connectionID, streamCallback, client.closeCallback) + err = client.createNewSession() if err != nil { return nil, err } @@ -75,8 +80,6 @@ func NewClient(addr string) (*Client, error) { // Listen listens func (c *Client) Listen() error { - go c.session.run() - for { data := getPacketBuffer() data = data[:protocol.MaxPacketSize] @@ -120,6 +123,32 @@ func (c *Client) handlePacket(packet []byte) error { } hdr.Raw = packet[:len(packet)-r.Len()] + // TODO: ignore delayed / duplicated version negotiation packets + + if hdr.VersionFlag { + // check if the server sent the offered version in supported versions + for _, v := range hdr.SupportedVersions { + if v == c.version { + return errInvalidVersionNegotiation + } + } + + ok, highestSupportedVersion := protocol.HighestSupportedVersion(hdr.SupportedVersions) + if !ok { + return qerr.VersionNegotiationMismatch + } + + utils.Infof("Switching to QUIC version %d", highestSupportedVersion) + c.version = highestSupportedVersion + c.session.Close(errCloseSessionForNewVersion) + err = c.createNewSession() + if err != nil { + return err + } + + return nil // version negotiation packets have no payload + } + c.session.handlePacket(&receivedPacket{ remoteAddr: c.addr, publicHeader: hdr, @@ -129,6 +158,19 @@ func (c *Client) handlePacket(packet []byte) error { return nil } +func (c *Client) createNewSession() error { + var err error + c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.streamCallback, c.closeCallback) + if err != nil { + return err + } + + go c.session.run() + return nil +} + +func (c *Client) streamCallback(session *Session, stream utils.Stream) {} + func (c *Client) closeCallback(id protocol.ConnectionID) { utils.Infof("Connection %x closed.", id) } diff --git a/client_test.go b/client_test.go index 0db467d1..6ba4b288 100644 --- a/client_test.go +++ b/client_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "encoding/binary" "net" "github.com/lucas-clemente/quic-go/protocol" @@ -17,10 +18,20 @@ var _ = Describe("Client", func() { BeforeEach(func() { client = &Client{} - session = &mockSession{} + session = &mockSession{connectionID: 0x1337} + client.connectionID = 0x1337 client.session = session + client.version = protocol.Version36 }) + startUDPConn := func() { + var err error + client.addr, err = net.ResolveUDPAddr("udp", "127.0.0.1:0") + Expect(err).ToNot(HaveOccurred()) + client.conn, err = net.ListenUDP("udp", client.addr) + Expect(err).NotTo(HaveOccurred()) + } + It("errors on invalid public header", func() { err := client.handlePacket(nil) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader)) @@ -42,6 +53,17 @@ var _ = Describe("Client", func() { Expect(session.closeReason).To(BeNil()) }) + It("creates new sessions with the right parameters", func() { + startUDPConn() + client.session = nil + client.hostname = "hostname" + err := client.createNewSession() + Expect(err).ToNot(HaveOccurred()) + Expect(client.session).ToNot(BeNil()) + Expect(client.session.(*Session).connectionID).To(Equal(client.connectionID)) + Expect(client.session.(*Session).version).To(Equal(client.version)) + }) + Context("handling packets", func() { It("errors on too large packets", func() { err := client.handlePacket(bytes.Repeat([]byte{'f'}, int(protocol.MaxPacketSize+1))) @@ -49,11 +71,7 @@ var _ = Describe("Client", func() { }) It("handles packets", func(done Done) { - var err error - client.addr, err = net.ResolveUDPAddr("udp", "127.0.0.1:0") - Expect(err).ToNot(HaveOccurred()) - client.conn, err = net.ListenUDP("udp", client.addr) - Expect(err).NotTo(HaveOccurred()) + startUDPConn() serverConn, err := net.DialUDP("udp", nil, client.conn.LocalAddr().(*net.UDPAddr)) Expect(err).NotTo(HaveOccurred()) @@ -84,11 +102,7 @@ var _ = Describe("Client", func() { }) It("closes the session when encountering an error while handling a packet", func(done Done) { - var err error - client.addr, err = net.ResolveUDPAddr("udp", "127.0.0.1:0") - Expect(err).ToNot(HaveOccurred()) - client.conn, err = net.ListenUDP("udp", client.addr) - Expect(err).NotTo(HaveOccurred()) + startUDPConn() serverConn, err := net.DialUDP("udp", nil, client.conn.LocalAddr().(*net.UDPAddr)) Expect(err).NotTo(HaveOccurred()) @@ -111,4 +125,49 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) }) }) + + Context("version negotiation", func() { + getVersionNegotiation := func(versions []protocol.VersionNumber) []byte { + oldVersionNegotiationPacket := composeVersionNegotiation(0x1337) + oldSupportVersionTags := protocol.SupportedVersionsAsTags + var b bytes.Buffer + for _, v := range versions { + s := make([]byte, 4) + binary.LittleEndian.PutUint32(s, protocol.VersionNumberToTag(v)) + b.Write(s) + } + protocol.SupportedVersionsAsTags = b.Bytes() + packet := composeVersionNegotiation(client.connectionID) + protocol.SupportedVersionsAsTags = oldSupportVersionTags + Expect(composeVersionNegotiation(0x1337)).To(Equal(oldVersionNegotiationPacket)) + return packet + } + + It("changes the version after receiving a version negotiation packet", func() { + startUDPConn() + newVersion := protocol.Version35 + Expect(newVersion).ToNot(Equal(client.version)) + Expect(session.packetCount).To(BeZero()) + err := client.handlePacket(getVersionNegotiation([]protocol.VersionNumber{newVersion})) + Expect(client.version).To(Equal(newVersion)) + // it swapped the sessions + Expect(client.session).ToNot(Equal(session)) + Expect(err).ToNot(HaveOccurred()) + // it didn't pass the version negoation packet to the session (since it has no payload) + Expect(session.packetCount).To(BeZero()) + + err = client.Close() + Expect(err).ToNot(HaveOccurred()) + }) + + It("errors if no matching version is found", func() { + err := client.handlePacket(getVersionNegotiation([]protocol.VersionNumber{1})) + Expect(err).To(MatchError(qerr.VersionNegotiationMismatch)) + }) + + It("errors if the server should have accepted the offered version", func() { + err := client.handlePacket(getVersionNegotiation([]protocol.VersionNumber{client.version})) + Expect(err).To(MatchError(errInvalidVersionNegotiation)) + }) + }) }) diff --git a/session.go b/session.go index 8633eb52..434067d4 100644 --- a/session.go +++ b/session.go @@ -480,6 +480,15 @@ func (s *Session) closeImpl(e error, remoteClose bool) error { return errSessionAlreadyClosed } + if e == errCloseSessionForNewVersion { + s.closeStreamsWithError(e) + // when the run loop exits, it will call the closeCallback + // replace it with an noop function to make sure this doesn't have any effect + s.closeCallback = func(protocol.ConnectionID) {} + s.closeChan <- nil + return nil + } + if e == nil { e = qerr.PeerGoingAway } diff --git a/session_test.go b/session_test.go index 0e9d2d34..fd511651 100644 --- a/session_test.go +++ b/session_test.go @@ -590,6 +590,14 @@ var _ = Describe("Session", func() { Expect(err.Error()).To(ContainSubstring(testErr.Error())) Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close() }) + + It("closes the session in order to replace it with another QUIC version", func() { + session.Close(errCloseSessionForNewVersion) + Expect(closeCallbackCalled).To(BeFalse()) + Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) + Expect(atomic.LoadUint32(&session.closed) != 0).To(BeTrue()) + Expect(conn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent + }) }) Context("receiving packets", func() {