diff --git a/client.go b/client.go index 1906abdf..3bf95b13 100644 --- a/client.go +++ b/client.go @@ -44,7 +44,7 @@ type client struct { var ( // make it possible to mock connection ID generation in the tests - generateConnectionID = utils.GenerateConnectionID + generateConnectionID = protocol.GenerateConnectionID errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") ) @@ -107,7 +107,7 @@ func Dial( logger: utils.DefaultLogger, } - c.logger.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) + c.logger.Infof("Starting new connection to %s (%s -> %s), connectionID %s, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) if err := c.dial(); err != nil { return nil, err @@ -240,7 +240,7 @@ func (c *client) establishSecureConnection() error { go func() { runErr = c.session.run() // returns as soon as the session is closed close(errorChan) - c.logger.Infof("Connection %x closed.", c.connectionID) + c.logger.Infof("Connection %s closed.", c.connectionID) if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion { c.conn.Close() } @@ -304,11 +304,16 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { } hdr.Raw = packet[:len(packet)-r.Len()] + if hdr.IsLongHeader && !hdr.DestConnectionID.Equal(hdr.SrcConnectionID) { + c.logger.Errorf("receiving packets with different destination and source connection IDs not supported") + } + c.mutex.Lock() defer c.mutex.Unlock() // reject packets with the wrong connection ID - if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID { + // TODO(#1003): add support for server-chosen connection IDs + if !hdr.OmitConnectionID && !hdr.DestConnectionID.Equal(c.connectionID) { return } @@ -316,7 +321,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { cr := c.conn.RemoteAddr() // check if the remote address and the connection ID match // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection - if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID { + if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || !hdr.DestConnectionID.Equal(c.connectionID) { c.logger.Infof("Received a spoofed Public Reset. Ignoring.") return } @@ -384,11 +389,11 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { c.initialVersion = c.version c.version = newVersion var err error - c.connectionID, err = utils.GenerateConnectionID() + c.connectionID, err = protocol.GenerateConnectionID() if err != nil { return err } - c.logger.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) + c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.connectionID) c.session.Close(errCloseSessionForNewVersion) return nil } diff --git a/client_test.go b/client_test.go index 06665305..b816fb17 100644 --- a/client_test.go +++ b/client_test.go @@ -25,6 +25,7 @@ var _ = Describe("Client", func() { sess *mockSession packetConn *mockPacketConn addr net.Addr + connID protocol.ConnectionID originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, logger utils.Logger) (packetHandler, error) ) @@ -33,25 +34,27 @@ var _ = Describe("Client", func() { acceptClientVersionPacket := func(connID protocol.ConnectionID) []byte { b := &bytes.Buffer{} err := (&wire.Header{ - ConnectionID: connID, - PacketNumber: 1, - PacketNumberLen: 1, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 1, + PacketNumberLen: 1, }).Write(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) return b.Bytes() } BeforeEach(func() { + connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} originalClientSessConstructor = newClientSession Eventually(areSessionsRunning).Should(BeFalse()) - msess, _ := newMockSession(nil, 0, 0, nil, nil, nil, nil) + msess, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil) sess = msess.(*mockSession) addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} packetConn = newMockPacketConn() packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} packetConn.dataReadFrom = addr cl = &client{ - connectionID: 0x1337, + connectionID: connID, session: sess, version: protocol.SupportedVersions[0], conn: &conn{pconn: packetConn, currentAddr: addr}, @@ -291,9 +294,10 @@ var _ = Describe("Client", func() { It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() { ph := wire.Header{ - PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen2, - ConnectionID: 0x1337, + PacketNumber: 1, + PacketNumberLen: protocol.PacketNumberLen2, + DestConnectionID: connID, + SrcConnectionID: connID, } b := &bytes.Buffer{} err := ph.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever) @@ -388,15 +392,15 @@ var _ = Describe("Client", func() { go cl.dial() Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1)) cl.config = &Config{Versions: []protocol.VersionNumber{77, 78}} - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{77})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{77})) Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2)) - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{78})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{78})) Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2)) }) It("errors if no matching version is found", func() { cl.config = &Config{Versions: protocol.SupportedVersions} - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{1})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1})) Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) }) @@ -405,7 +409,7 @@ var _ = Describe("Client", func() { v := protocol.VersionNumber(1234) Expect(v).ToNot(Equal(cl.version)) cl.config = &Config{Versions: protocol.SupportedVersions} - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{v})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{v})) Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) }) @@ -413,7 +417,7 @@ var _ = Describe("Client", func() { It("changes to the version preferred by the quic.Config", func() { config := &Config{Versions: []protocol.VersionNumber{1234, 4321}} cl.config = config - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{4321, 1234})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{4321, 1234})) Expect(cl.version).To(Equal(protocol.VersionNumber(1234))) }) @@ -421,14 +425,14 @@ var _ = Describe("Client", func() { // if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test cl.versionNegotiated = true Expect(sess.packetCount).To(BeZero()) - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{1})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1})) Expect(cl.versionNegotiated).To(BeTrue()) Expect(sess.packetCount).To(BeZero()) }) It("drops version negotiation packets that contain the offered version", func() { ver := cl.version - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{ver})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver})) Expect(cl.version).To(Equal(ver)) }) }) @@ -455,10 +459,13 @@ var _ = Describe("Client", func() { It("ignores packets with the wrong connection ID", func() { buf := &bytes.Buffer{} + connID2 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} + Expect(connID).ToNot(Equal(connID2)) (&wire.Header{ - ConnectionID: cl.connectionID + 1, - PacketNumber: 1, - PacketNumberLen: 1, + DestConnectionID: connID2, + SrcConnectionID: connID2, + PacketNumber: 1, + PacketNumberLen: 1, }).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever) cl.handlePacket(addr, buf.Bytes()) Expect(sess.packetCount).To(BeZero()) @@ -585,9 +592,10 @@ var _ = Describe("Client", func() { Context("handling packets", func() { It("handles packets", func() { ph := wire.Header{ - PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen2, - ConnectionID: 0x1337, + PacketNumber: 1, + PacketNumberLen: protocol.PacketNumberLen2, + DestConnectionID: connID, + SrcConnectionID: connID, } b := &bytes.Buffer{} err := ph.Write(b, protocol.PerspectiveServer, cl.version) @@ -625,7 +633,9 @@ var _ = Describe("Client", func() { }) It("ignores Public Resets with the wrong connection ID", func() { - cl.handlePacket(addr, wire.WritePublicReset(cl.connectionID+1, 1, 0)) + connID2 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} + Expect(connID).ToNot(Equal(connID2)) + cl.handlePacket(addr, wire.WritePublicReset(connID2, 1, 0)) Expect(cl.session.(*mockSession).closed).To(BeFalse()) Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) }) diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index 3d2a1f78..f1a010ba 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -25,7 +25,8 @@ var _ = Describe("QUIC Proxy", func() { hdr := wire.Header{ PacketNumber: p, PacketNumberLen: protocol.PacketNumberLen6, - ConnectionID: 1337, + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}, OmitConnectionID: false, } hdr.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever) diff --git a/internal/crypto/key_derivation_quic_crypto.go b/internal/crypto/key_derivation_quic_crypto.go index 28f6c2cc..6c294178 100644 --- a/internal/crypto/key_derivation_quic_crypto.go +++ b/internal/crypto/key_derivation_quic_crypto.go @@ -6,7 +6,6 @@ import ( "io" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" "golang.org/x/crypto/hkdf" ) @@ -42,7 +41,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol } else { info.Write([]byte("QUIC key expansion\x00")) } - utils.BigEndian.WriteUint64(&info, uint64(connID)) + info.Write(connID) info.Write(chlo) info.Write(scfg) info.Write(cert) diff --git a/internal/crypto/key_derivation_quic_crypto_test.go b/internal/crypto/key_derivation_quic_crypto_test.go index 41ec20fa..d866121c 100644 --- a/internal/crypto/key_derivation_quic_crypto_test.go +++ b/internal/crypto/key_derivation_quic_crypto_test.go @@ -92,7 +92,7 @@ var _ = Describe("QUIC Crypto Key Derivation", func() { false, []byte("0123456789012345678901"), []byte("nonce"), - protocol.ConnectionID(0x2a00000000000000), // this was 42 before the connection ID was changed to big endian + protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}), []byte("chlo"), []byte("scfg"), []byte("cert"), @@ -111,7 +111,7 @@ var _ = Describe("QUIC Crypto Key Derivation", func() { false, []byte("0123456789012345678901"), []byte("nonce"), - protocol.ConnectionID(42), + protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}), []byte("chlo"), []byte("scfg"), []byte("cert"), @@ -123,7 +123,7 @@ var _ = Describe("QUIC Crypto Key Derivation", func() { false, []byte("0123456789012345678901"), []byte("nonce"), - protocol.ConnectionID(42), + protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}), []byte("chlo"), []byte("scfg"), []byte("cert"), @@ -142,7 +142,7 @@ var _ = Describe("QUIC Crypto Key Derivation", func() { false, []byte("0123456789012345678901"), []byte("nonce"), - protocol.ConnectionID(0x2a00000000000000), // this was 42 before the connection ID was changed to big endian + protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}), []byte("chlo"), []byte("scfg"), []byte("cert"), @@ -161,7 +161,7 @@ var _ = Describe("QUIC Crypto Key Derivation", func() { true, []byte("0123456789012345678901"), []byte("nonce"), - protocol.ConnectionID(0x2a00000000000000), // this was 42 before the connection ID was changed to big endian + protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}), []byte("chlo"), []byte("scfg"), []byte("cert"), @@ -180,7 +180,7 @@ var _ = Describe("QUIC Crypto Key Derivation", func() { true, []byte("0123456789012345678901"), []byte("nonce"), - protocol.ConnectionID(0x2a00000000000000), // this was 42 before the connection ID was changed to big endian + protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}), []byte("chlo"), []byte("scfg"), []byte("cert"), diff --git a/internal/crypto/null_aead_aesgcm.go b/internal/crypto/null_aead_aesgcm.go index 92c6362f..4abc6229 100644 --- a/internal/crypto/null_aead_aesgcm.go +++ b/internal/crypto/null_aead_aesgcm.go @@ -2,7 +2,6 @@ package crypto import ( "crypto" - "encoding/binary" "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -28,9 +27,7 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) } -func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) { - connID := make([]byte, 8) - binary.BigEndian.PutUint64(connID, uint64(connectionID)) +func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID) clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size()) serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size()) diff --git a/internal/crypto/null_aead_aesgcm_test.go b/internal/crypto/null_aead_aesgcm_test.go index d56641ed..8f45a956 100644 --- a/internal/crypto/null_aead_aesgcm_test.go +++ b/internal/crypto/null_aead_aesgcm_test.go @@ -9,7 +9,7 @@ import ( var _ = Describe("NullAEAD using AES-GCM", func() { // values taken from https://github.com/quicwg/base-drafts/wiki/Test-Vector-for-the-Clear-Text-AEAD-key-derivation Context("using the test vector from the QUIC WG Wiki", func() { - connID := protocol.ConnectionID(0x8394c8f03e515708) + connID := protocol.ConnectionID([]byte{0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08}) It("computes the secrets", func() { clientSecret, serverSecret := computeSecrets(connID) @@ -55,7 +55,7 @@ var _ = Describe("NullAEAD using AES-GCM", func() { }) It("seals and opens", func() { - connectionID := protocol.ConnectionID(0x1234567890) + connectionID := protocol.ConnectionID([]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef}) clientAEAD, err := newNullAEADAESGCM(connectionID, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) serverAEAD, err := newNullAEADAESGCM(connectionID, protocol.PerspectiveServer) @@ -72,9 +72,11 @@ var _ = Describe("NullAEAD using AES-GCM", func() { }) It("doesn't work if initialized with different connection IDs", func() { - clientAEAD, err := newNullAEADAESGCM(1, protocol.PerspectiveClient) + c1 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 1}) + c2 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 2}) + clientAEAD, err := newNullAEADAESGCM(c1, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) - serverAEAD, err := newNullAEADAESGCM(2, protocol.PerspectiveServer) + serverAEAD, err := newNullAEADAESGCM(c2, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) clientMessage := clientAEAD.Seal(nil, []byte("foobar"), 42, []byte("aad")) diff --git a/internal/crypto/null_aead_test.go b/internal/crypto/null_aead_test.go index 73de963c..ce3a12a0 100644 --- a/internal/crypto/null_aead_test.go +++ b/internal/crypto/null_aead_test.go @@ -8,7 +8,7 @@ import ( var _ = Describe("NullAEAD", func() { It("selects the right FVN variant", func() { - connID := protocol.ConnectionID(0x42) + connID := protocol.ConnectionID([]byte{0x42, 0, 0, 0, 0, 0, 0, 0}) Expect(NewNullAEAD(protocol.PerspectiveClient, connID, protocol.Version39)).To(Equal(&nullAEADFNV128a{ perspective: protocol.PerspectiveClient, })) diff --git a/internal/handshake/crypto_setup_client_test.go b/internal/handshake/crypto_setup_client_test.go index 0ab8beb8..c3e412cc 100644 --- a/internal/handshake/crypto_setup_client_test.go +++ b/internal/handshake/crypto_setup_client_test.go @@ -123,7 +123,7 @@ var _ = Describe("Client Crypto Setup", func() { csInt, dnc, err := NewCryptoSetupClient( stream, "hostname", - 0, + protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, version, nil, &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}, diff --git a/internal/handshake/crypto_setup_server_test.go b/internal/handshake/crypto_setup_server_test.go index 2fdf053b..15624f84 100644 --- a/internal/handshake/crypto_setup_server_test.go +++ b/internal/handshake/crypto_setup_server_test.go @@ -161,7 +161,7 @@ var _ = Describe("Server Crypto Setup", func() { supportedVersions = []protocol.VersionNumber{version, 98, 99} csInt, err := NewCryptoSetup( stream, - protocol.ConnectionID(42), + protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, remoteAddr, version, make([]byte, 32), // div nonce diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 67bc2fda..84375a7c 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -200,7 +200,7 @@ var _ = Describe("TLS Crypto Setup, for the client", func() { handshakeEvent = make(chan struct{}) csInt, err := NewCryptoSetupTLSClient( nil, - 0, + protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, "quic.clemente.io", handshakeEvent, nil, // mintTLS diff --git a/internal/handshake/tls_extension.go b/internal/handshake/tls_extension.go index 98ad3a57..1e6b18ef 100644 --- a/internal/handshake/tls_extension.go +++ b/internal/handshake/tls_extension.go @@ -13,7 +13,6 @@ const ( initialMaxDataParameterID transportParameterID = 0x1 initialMaxStreamsBiDiParameterID transportParameterID = 0x2 idleTimeoutParameterID transportParameterID = 0x3 - omitConnectionIDParameterID transportParameterID = 0x4 maxPacketSizeParameterID transportParameterID = 0x5 statelessResetTokenParameterID transportParameterID = 0x6 initialMaxStreamsUniParameterID transportParameterID = 0x8 diff --git a/internal/handshake/transport_parameter_test.go b/internal/handshake/transport_parameter_test.go index 44f38355..8e365bb1 100644 --- a/internal/handshake/transport_parameter_test.go +++ b/internal/handshake/transport_parameter_test.go @@ -116,10 +116,9 @@ var _ = Describe("Transport Parameters", func() { ConnectionFlowControlWindow: 0x4321, MaxBidiStreams: 1337, MaxUniStreams: 7331, - OmitConnectionID: true, IdleTimeout: 42 * time.Second, } - Expect(p.String()).To(Equal("&handshake.TransportParameters{StreamFlowControlWindow: 0x1234, ConnectionFlowControlWindow: 0x4321, MaxBidiStreams: 1337, MaxUniStreams: 7331, OmitConnectionID: true, IdleTimeout: 42s}")) + Expect(p.String()).To(Equal("&handshake.TransportParameters{StreamFlowControlWindow: 0x1234, ConnectionFlowControlWindow: 0x4321, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s}")) }) Context("parsing", func() { @@ -147,13 +146,6 @@ var _ = Describe("Transport Parameters", func() { Expect(params.MaxPacketSize).To(Equal(protocol.ByteCount(0x7331))) }) - It("saves if it should omit the connection ID", func() { - parameters[omitConnectionIDParameterID] = []byte{} - params, err := readTransportParameters(paramsMapToList(parameters)) - Expect(err).ToNot(HaveOccurred()) - Expect(params.OmitConnectionID).To(BeTrue()) - }) - It("rejects the parameters if the initial_max_stream_data is missing", func() { delete(parameters, initialMaxStreamDataParameterID) _, err := readTransportParameters(paramsMapToList(parameters)) @@ -211,12 +203,6 @@ var _ = Describe("Transport Parameters", func() { Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)")) }) - It("rejects the parameters if omit_connection_id is non-empty", func() { - parameters[omitConnectionIDParameterID] = []byte{0} // should be empty - _, err := readTransportParameters(paramsMapToList(parameters)) - Expect(err).To(MatchError("wrong length for omit_connection_id: 1 (expected empty)")) - }) - It("rejects the parameters if max_packet_size has the wrong length", func() { parameters[maxPacketSizeParameterID] = []byte{0x11} // should be 2 bytes _, err := readTransportParameters(paramsMapToList(parameters)) @@ -267,12 +253,6 @@ var _ = Describe("Transport Parameters", func() { Expect(values).To(HaveKeyWithValue(idleTimeoutParameterID, []byte{0xca, 0xfe})) Expect(values).To(HaveKeyWithValue(maxPacketSizeParameterID, []byte{0x5, 0xac})) // 1452 = 0x5ac }) - - It("request ommission of the connection ID", func() { - params.OmitConnectionID = true - values := paramsListToMap(params.getTransportParameters()) - Expect(values).To(HaveKeyWithValue(omitConnectionIDParameterID, []byte{})) - }) }) }) }) diff --git a/internal/handshake/transport_parameters.go b/internal/handshake/transport_parameters.go index fce1e3f2..7a224377 100644 --- a/internal/handshake/transport_parameters.go +++ b/internal/handshake/transport_parameters.go @@ -26,7 +26,7 @@ type TransportParameters struct { MaxBidiStreams uint16 // only used for IETF QUIC MaxStreams uint32 // only used for gQUIC - OmitConnectionID bool + OmitConnectionID bool // only used for gQUIC IdleTimeout time.Duration } @@ -132,11 +132,6 @@ func readTransportParameters(paramsList []transportParameter) (*TransportParamet return nil, fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value)) } params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(p.Value))*time.Second) - case omitConnectionIDParameterID: - if len(p.Value) != 0 { - return nil, fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value)) - } - params.OmitConnectionID = true case maxPacketSizeParameterID: if len(p.Value) != 2 { return nil, fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", len(p.Value)) @@ -178,14 +173,11 @@ func (p *TransportParameters) getTransportParameters() []transportParameter { {idleTimeoutParameterID, idleTimeout}, {maxPacketSizeParameterID, maxPacketSize}, } - if p.OmitConnectionID { - params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}}) - } return params } // String returns a string representation, intended for logging. // It should only used for IETF QUIC. func (p *TransportParameters) String() string { - return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, OmitConnectionID: %t, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.OmitConnectionID, p.IdleTimeout) + return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout) } diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go new file mode 100644 index 00000000..918631f9 --- /dev/null +++ b/internal/protocol/connection_id.go @@ -0,0 +1,56 @@ +package protocol + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" +) + +// A ConnectionID in QUIC +type ConnectionID []byte + +// GenerateConnectionID generates a connection ID using cryptographic random +func GenerateConnectionID() (ConnectionID, error) { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + return nil, err + } + return ConnectionID(b), nil +} + +// ReadConnectionID reads a connection ID of length len from the given io.Reader. +// It returns io.EOF if there are not enough bytes to read. +func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { + if len == 0 { + return nil, nil + } + c := make(ConnectionID, len) + _, err := io.ReadFull(r, c) + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return c, err +} + +// Equal says if two connection IDs are equal +func (c ConnectionID) Equal(other ConnectionID) bool { + return bytes.Equal(c, other) +} + +// Len returns the length of the connection ID in bytes +func (c ConnectionID) Len() int { + return len(c) +} + +// Bytes returns the byte representation +func (c ConnectionID) Bytes() []byte { + return []byte(c) +} + +func (c ConnectionID) String() string { + if c.Len() == 0 { + return "(empty)" + } + return fmt.Sprintf("%#x", c.Bytes()) +} diff --git a/internal/protocol/connection_id_test.go b/internal/protocol/connection_id_test.go new file mode 100644 index 00000000..9f7d17de --- /dev/null +++ b/internal/protocol/connection_id_test.go @@ -0,0 +1,84 @@ +package protocol + +import ( + "bytes" + "io" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Connection ID generation", func() { + It("generates random connection IDs", func() { + c1, err := GenerateConnectionID() + Expect(err).ToNot(HaveOccurred()) + Expect(c1).ToNot(BeZero()) + c2, err := GenerateConnectionID() + Expect(err).ToNot(HaveOccurred()) + Expect(c1).ToNot(Equal(c2)) + }) + + It("says if connection IDs are equal", func() { + c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + Expect(c1.Equal(c1)).To(BeTrue()) + Expect(c2.Equal(c2)).To(BeTrue()) + Expect(c1.Equal(c2)).To(BeFalse()) + Expect(c2.Equal(c1)).To(BeFalse()) + }) + + It("reads the connection ID", func() { + buf := bytes.NewBuffer([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) + c, err := ReadConnectionID(buf, 9) + Expect(err).ToNot(HaveOccurred()) + Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})) + }) + + It("returns io.EOF if there's not enough data to read", func() { + buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) + _, err := ReadConnectionID(buf, 5) + Expect(err).To(MatchError(io.EOF)) + }) + + It("returns nil for a 0 length connection ID", func() { + buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) + c, err := ReadConnectionID(buf, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(c).To(BeNil()) + }) + + It("returns the length", func() { + c := ConnectionID{1, 2, 3, 4, 5, 6, 7} + Expect(c.Len()).To(Equal(7)) + }) + + It("has 0 length for the default value", func() { + var c ConnectionID + Expect(c.Len()).To(BeZero()) + }) + + It("returns the bytes", func() { + c := ConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) + Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7})) + }) + + It("returns a nil byte slice for the default value", func() { + var c ConnectionID + Expect(c.Bytes()).To(BeNil()) + }) + + It("has a string representation", func() { + c := ConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) + Expect(c.String()).To(Equal("0xdeadbeef42")) + }) + + It("has a long string representation", func() { + c := ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad} + Expect(c.String()).To(Equal("0x13370000decafbad")) + }) + + It("has a string representation for the default value", func() { + var c ConnectionID + Expect(c.String()).To(Equal("(empty)")) + }) +}) diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 2821d2cd..e89b2227 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -52,9 +52,6 @@ func (t PacketType) String() string { } } -// A ConnectionID in QUIC -type ConnectionID uint64 - // A ByteCount in QUIC type ByteCount uint64 diff --git a/internal/utils/connection_id.go b/internal/utils/connection_id.go deleted file mode 100644 index b4af4e78..00000000 --- a/internal/utils/connection_id.go +++ /dev/null @@ -1,18 +0,0 @@ -package utils - -import ( - "crypto/rand" - "encoding/binary" - - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// GenerateConnectionID generates a connection ID using cryptographic random -func GenerateConnectionID() (protocol.ConnectionID, error) { - b := make([]byte, 8) - _, err := rand.Read(b) - if err != nil { - return 0, err - } - return protocol.ConnectionID(binary.LittleEndian.Uint64(b)), nil -} diff --git a/internal/utils/connection_id_test.go b/internal/utils/connection_id_test.go deleted file mode 100644 index cd3f9af1..00000000 --- a/internal/utils/connection_id_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package utils - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Connection ID generation", func() { - It("generates random connection IDs", func() { - c1, err := GenerateConnectionID() - Expect(err).ToNot(HaveOccurred()) - Expect(c1).ToNot(BeZero()) - c2, err := GenerateConnectionID() - Expect(err).ToNot(HaveOccurred()) - Expect(c1).ToNot(Equal(c2)) - }) -}) diff --git a/internal/wire/header.go b/internal/wire/header.go index fc346f3f..c4c6e813 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -10,12 +10,16 @@ import ( // Header is the header of a QUIC packet. // It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header. type Header struct { - Raw []byte - ConnectionID protocol.ConnectionID + Raw []byte + + Version protocol.VersionNumber + + DestConnectionID protocol.ConnectionID + SrcConnectionID protocol.ConnectionID OmitConnectionID bool - PacketNumberLen protocol.PacketNumberLen - PacketNumber protocol.PacketNumber - Version protocol.VersionNumber // VersionNumber sent by the client + + PacketNumberLen protocol.PacketNumberLen + PacketNumber protocol.PacketNumber IsVersionNegotiation bool SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 1533915f..9144da7d 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -23,10 +23,12 @@ var _ = Describe("Header", func() { buf := &bytes.Buffer{} // use a Short Header, which isn't distinguishable from the gQUIC Public Header when looking at the type byte err := (&Header{ - IsLongHeader: false, - KeyPhase: 1, - PacketNumber: 0x42, - PacketNumberLen: protocol.PacketNumberLen2, + IsLongHeader: false, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + KeyPhase: 1, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen2, }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) hdr, err := ParseHeaderSentByClient(bytes.NewReader(buf.Bytes())) @@ -39,10 +41,12 @@ var _ = Describe("Header", func() { It("parses an IETF draft header, when the version is not known, but it has Long Header format", func() { buf := &bytes.Buffer{} err := (&Header{ - IsLongHeader: true, - Type: protocol.PacketType0RTT, - PacketNumber: 0x42, - Version: 0x1234, + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Type: protocol.PacketType0RTT, + PacketNumber: 0x42, + Version: 0x1234, }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) hdr, err := ParseHeaderSentByClient(bytes.NewReader(buf.Bytes())) @@ -57,9 +61,11 @@ var _ = Describe("Header", func() { // make sure this packet could be mistaken for a Version Negotiation Packet, if we only look at the 0x1 bit buf := &bytes.Buffer{} err := (&Header{ - IsLongHeader: false, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, + IsLongHeader: false, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 0x42, }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) hdr, err := ParseHeaderSentByServer(bytes.NewReader(buf.Bytes()), versionIETFHeader) @@ -68,26 +74,32 @@ var _ = Describe("Header", func() { }) It("parses a gQUIC Public Header, when the version is not known", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} buf := &bytes.Buffer{} err := (&Header{ - VersionFlag: true, - Version: versionPublicHeader, - ConnectionID: 0x42, - PacketNumber: 0x1337, - PacketNumberLen: protocol.PacketNumberLen6, + VersionFlag: true, + Version: versionPublicHeader, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0x1337, + PacketNumberLen: protocol.PacketNumberLen6, }).writePublicHeader(buf, protocol.PerspectiveClient, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) hdr, err := ParseHeaderSentByClient(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) + Expect(hdr.DestConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(Equal(connID)) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) Expect(hdr.Version).To(Equal(versionPublicHeader)) Expect(hdr.isPublicHeader).To(BeTrue()) }) It("parses a gQUIC Public Header, when the version is known", func() { + connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} buf := &bytes.Buffer{} err := (&Header{ - ConnectionID: 0x42, + DestConnectionID: connID, + SrcConnectionID: connID, PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen6, DiversificationNonce: bytes.Repeat([]byte{'f'}, 32), @@ -95,6 +107,8 @@ var _ = Describe("Header", func() { Expect(err).ToNot(HaveOccurred()) hdr, err := ParseHeaderSentByServer(bytes.NewReader(buf.Bytes()), versionPublicHeader) Expect(err).ToNot(HaveOccurred()) + Expect(hdr.DestConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(Equal(connID)) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) Expect(hdr.DiversificationNonce).To(HaveLen(32)) Expect(hdr.isPublicHeader).To(BeTrue()) @@ -103,11 +117,12 @@ var _ = Describe("Header", func() { It("errors when parsing the gQUIC header fails", func() { buf := &bytes.Buffer{} err := (&Header{ - VersionFlag: true, - Version: versionPublicHeader, - ConnectionID: 0x42, - PacketNumber: 0x1337, - PacketNumberLen: protocol.PacketNumberLen6, + VersionFlag: true, + Version: versionPublicHeader, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + PacketNumber: 0x1337, + PacketNumberLen: protocol.PacketNumberLen6, }).writePublicHeader(buf, protocol.PerspectiveClient, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) _, err = ParseHeaderSentByClient(bytes.NewReader(buf.Bytes()[0:12])) @@ -122,12 +137,14 @@ var _ = Describe("Header", func() { }) It("parses a gQUIC Version Negotiation Packet", func() { + connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0xde, 0xca, 0xfb, 0xad} versions := []protocol.VersionNumber{0x13, 0x37} - data := ComposeGQUICVersionNegotiation(0x42, versions) + data := ComposeGQUICVersionNegotiation(connID, versions) hdr, err := ParseHeaderSentByServer(bytes.NewReader(data), protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.isPublicHeader).To(BeTrue()) - Expect(hdr.ConnectionID).To(Equal(protocol.ConnectionID(0x42))) + Expect(hdr.DestConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(Equal(connID)) // in addition to the versions, the supported versions might contain a reserved version number for _, version := range versions { Expect(hdr.SupportedVersions).To(ContainElement(version)) @@ -135,13 +152,17 @@ var _ = Describe("Header", func() { }) It("parses an IETF draft style Version Negotiation Packet", func() { + destConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x13, 0x37} - data := ComposeVersionNegotiation(0x42, versions) + data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) + Expect(err).ToNot(HaveOccurred()) hdr, err := ParseHeaderSentByServer(bytes.NewReader(data), protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.isPublicHeader).To(BeFalse()) Expect(hdr.IsVersionNegotiation).To(BeTrue()) - Expect(hdr.ConnectionID).To(Equal(protocol.ConnectionID(0x42))) + Expect(hdr.DestConnectionID).To(Equal(destConnID)) + Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) Expect(hdr.Version).To(BeZero()) // in addition to the versions, the supported versions might contain a reserved version number for _, version := range versions { @@ -154,9 +175,10 @@ var _ = Describe("Header", func() { It("writes a gQUIC Public Header", func() { buf := &bytes.Buffer{} hdr := &Header{ - ConnectionID: 0x1337, - PacketNumber: 0x42, - PacketNumberLen: protocol.PacketNumberLen2, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen2, } err := hdr.Write(buf, protocol.PerspectiveServer, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) @@ -168,10 +190,11 @@ var _ = Describe("Header", func() { It("writes a IETF draft header", func() { buf := &bytes.Buffer{} hdr := &Header{ - ConnectionID: 0x1337, - PacketNumber: 0x42, - PacketNumberLen: protocol.PacketNumberLen2, - KeyPhase: 1, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen2, + KeyPhase: 1, } err := hdr.Write(buf, protocol.PerspectiveServer, versionIETFHeader) Expect(err).ToNot(HaveOccurred()) @@ -185,7 +208,8 @@ var _ = Describe("Header", func() { It("get the length of a gQUIC Public Header", func() { buf := &bytes.Buffer{} hdr := &Header{ - ConnectionID: 0x1337, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, PacketNumber: 0x42, PacketNumberLen: protocol.PacketNumberLen2, DiversificationNonce: bytes.Repeat([]byte{'f'}, 32), @@ -205,11 +229,12 @@ var _ = Describe("Header", func() { It("get the length of a a IETF draft header", func() { buf := &bytes.Buffer{} hdr := &Header{ - IsLongHeader: true, - ConnectionID: 0x1337, - PacketNumber: 0x42, - PacketNumberLen: protocol.PacketNumberLen2, - KeyPhase: 1, + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen2, + KeyPhase: 1, } err := hdr.Write(buf, protocol.PerspectiveServer, versionIETFHeader) Expect(err).ToNot(HaveOccurred()) @@ -243,14 +268,18 @@ var _ = Describe("Header", func() { It("logs an IETF draft header", func() { (&Header{ - IsLongHeader: true, + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, }).Log(logger) Expect(buf.String()).To(ContainSubstring("Long Header")) }) It("logs a Public Header", func() { (&Header{ - isPublicHeader: true, + isPublicHeader: true, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, }).Log(logger) Expect(buf.String()).To(ContainSubstring("Public Header")) }) diff --git a/internal/wire/ietf_header.go b/internal/wire/ietf_header.go index 01bf0a26..811c7446 100644 --- a/internal/wire/ietf_header.go +++ b/internal/wire/ietf_header.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "io" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -24,18 +25,31 @@ func parseHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, e // parse long header and version negotiation packets func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte) (*Header, error) { - connID, err := utils.BigEndian.ReadUint64(b) - if err != nil { - return nil, err - } v, err := utils.BigEndian.ReadUint32(b) if err != nil { return nil, err } - h := &Header{ - ConnectionID: protocol.ConnectionID(connID), - Version: protocol.VersionNumber(v), + + connIDLenByte, err := b.ReadByte() + if err != nil { + return nil, err } + dcil, scil := decodeConnIDLen(connIDLenByte) + destConnID, err := protocol.ReadConnectionID(b, dcil) + if err != nil { + return nil, err + } + srcConnID, err := protocol.ReadConnectionID(b, scil) + if err != nil { + return nil, err + } + + h := &Header{ + Version: protocol.VersionNumber(v), + DestConnectionID: destConnID, + SrcConnectionID: srcConnID, + } + if v == 0 { // version negotiation packet if sentBy == protocol.PerspectiveClient { return nil, qerr.InvalidVersion @@ -54,6 +68,7 @@ func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte } return h, nil } + h.IsLongHeader = true pn, err := utils.BigEndian.ReadUint32(b) if err != nil { @@ -72,21 +87,19 @@ func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte } func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) { - omitConnID := typeByte&0x40 > 0 - var connID uint64 - if !omitConnID { - var err error - connID, err = utils.BigEndian.ReadUint64(b) - if err != nil { - return nil, err + connID := make(protocol.ConnectionID, 8) + if _, err := io.ReadFull(b, connID); err != nil { + if err == io.ErrUnexpectedEOF { + err = io.EOF } + return nil, err } - // bit 4 must be set, bit 5 must be unset - if typeByte&0x18 != 0x10 { - return nil, errors.New("invalid bit 4 and 5") + // bits 2 and 3 must be set, bit 4 must be unset + if typeByte&0x38 != 0x30 { + return nil, errors.New("invalid bits 3, 4 and 5") } var pnLen protocol.PacketNumberLen - switch typeByte & 0x7 { + switch typeByte & 0x3 { case 0x0: pnLen = protocol.PacketNumberLen1 case 0x1: @@ -101,9 +114,8 @@ func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) { return nil, err } return &Header{ - KeyPhase: int(typeByte&0x20) >> 5, - OmitConnectionID: omitConnID, - ConnectionID: protocol.ConnectionID(connID), + KeyPhase: int(typeByte&0x40) >> 6, + DestConnectionID: connID, PacketNumber: protocol.PacketNumber(pn), PacketNumberLen: pnLen, }, nil @@ -119,33 +131,40 @@ func (h *Header) writeHeader(b *bytes.Buffer) error { // TODO: add support for the key phase func (h *Header) writeLongHeader(b *bytes.Buffer) error { + if !h.DestConnectionID.Equal(h.SrcConnectionID) { + return errors.New("Header: can't write a header with different source and destination connection ID") + } + if h.SrcConnectionID.Len() != 8 { + return fmt.Errorf("Header: source connection ID must be 8 bytes, is %d", h.SrcConnectionID.Len()) + } b.WriteByte(byte(0x80 | h.Type)) - utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) utils.BigEndian.WriteUint32(b, uint32(h.Version)) + connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID) + if err != nil { + return err + } + b.WriteByte(connIDLen) + b.Write(h.DestConnectionID.Bytes()) + b.Write(h.SrcConnectionID.Bytes()) utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) return nil } func (h *Header) writeShortHeader(b *bytes.Buffer) error { - typeByte := byte(0x10) - typeByte ^= byte(h.KeyPhase << 5) - if h.OmitConnectionID { - typeByte ^= 0x40 - } + typeByte := byte(0x30) + typeByte |= byte(h.KeyPhase << 6) switch h.PacketNumberLen { case protocol.PacketNumberLen1: case protocol.PacketNumberLen2: - typeByte ^= 0x1 + typeByte |= 0x1 case protocol.PacketNumberLen4: - typeByte ^= 0x2 + typeByte |= 0x2 default: return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) } b.WriteByte(typeByte) - if !h.OmitConnectionID { - utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) - } + b.Write(h.DestConnectionID.Bytes()) switch h.PacketNumberLen { case protocol.PacketNumberLen1: b.WriteByte(uint8(h.PacketNumber)) @@ -160,13 +179,10 @@ func (h *Header) writeShortHeader(b *bytes.Buffer) error { // getHeaderLength gets the length of the Header in bytes. func (h *Header) getHeaderLength() (protocol.ByteCount, error) { if h.IsLongHeader { - return 1 + 8 + 4 + 4, nil + return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + 4 /* packet number */, nil } - length := protocol.ByteCount(1) // type byte - if !h.OmitConnectionID { - length += 8 - } + length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 { return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) } @@ -176,12 +192,42 @@ func (h *Header) getHeaderLength() (protocol.ByteCount, error) { func (h *Header) logHeader(logger utils.Logger) { if h.IsLongHeader { - logger.Debugf(" Long Header{Type: %s, ConnectionID: %#x, PacketNumber: %#x, Version: %s}", h.Type, h.ConnectionID, h.PacketNumber, h.Version) + logger.Debugf(" Long Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.Version) } else { - connID := "(omitted)" - if !h.OmitConnectionID { - connID = fmt.Sprintf("%#x", h.ConnectionID) - } - logger.Debugf(" Short Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", connID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + logger.Debugf(" Short Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) } } + +func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) { + dcil, err := encodeSingleConnIDLen(dest) + if err != nil { + return 0, err + } + scil, err := encodeSingleConnIDLen(src) + if err != nil { + return 0, err + } + return scil | dcil<<4, nil +} + +func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) { + len := id.Len() + if len == 0 { + return 0, nil + } + if len < 4 || len > 18 { + return 0, errors.New("invalid connection ID length") + } + return byte(len - 3), nil +} + +func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) { + return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf) +} + +func decodeSingleConnIDLen(enc uint8) int { + if enc == 0 { + return 0 + } + return int(enc) + 3 +} diff --git a/internal/wire/ietf_header_test.go b/internal/wire/ietf_header_test.go index 4097bc6a..ee32f3de 100644 --- a/internal/wire/ietf_header_test.go +++ b/internal/wire/ietf_header_test.go @@ -15,36 +15,43 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("IETF draft Header", func() { +var _ = Describe("IETF QUIC Header", func() { Context("parsing", func() { Context("Version Negotiation Packets", func() { It("parses", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x22334455, 0x33445566} - data := ComposeVersionNegotiation(0x1234567890, versions) + data, err := ComposeVersionNegotiation(connID, connID, versions) + Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) h, err := parseHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(h.IsVersionNegotiation).To(BeTrue()) Expect(h.Version).To(BeZero()) - Expect(h.ConnectionID).To(Equal(protocol.ConnectionID(0x1234567890))) + Expect(h.DestConnectionID).To(Equal(connID)) + Expect(h.SrcConnectionID).To(Equal(connID)) for _, v := range versions { Expect(h.SupportedVersions).To(ContainElement(v)) } }) It("errors if it contains versions of the wrong length", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x22334455, 0x33445566} - data := ComposeVersionNegotiation(0x1234567890, versions) + data, err := ComposeVersionNegotiation(connID, connID, versions) + Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data[:len(data)-2]) - _, err := parseHeader(b, protocol.PerspectiveServer) + _, err = parseHeader(b, protocol.PerspectiveServer) Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket)) }) It("errors if the version list is empty", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x22334455} - data := ComposeVersionNegotiation(0x1234567890, versions) + data, err := ComposeVersionNegotiation(connID, connID, versions) + Expect(err).ToNot(HaveOccurred()) // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number - _, err := parseHeader(bytes.NewReader(data[:len(data)-8]), protocol.PerspectiveServer) + _, err = parseHeader(bytes.NewReader(data[:len(data)-8]), protocol.PerspectiveServer) Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list")) }) }) @@ -53,8 +60,10 @@ var _ = Describe("IETF draft Header", func() { generatePacket := func(t protocol.PacketType) []byte { return []byte{ 0x80 ^ uint8(t), - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID 0x1, 0x2, 0x3, 0x4, // version number + 0x55, // connection ID lengths + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // destination connection ID + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // source connection ID 0xde, 0xca, 0xfb, 0xad, // packet number } } @@ -66,7 +75,8 @@ var _ = Describe("IETF draft Header", func() { Expect(h.Type).To(Equal(protocol.PacketTypeInitial)) Expect(h.IsLongHeader).To(BeTrue()) Expect(h.OmitConnectionID).To(BeFalse()) - Expect(h.ConnectionID).To(Equal(protocol.ConnectionID(0xdeadbeefcafe1337))) + Expect(h.DestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) + Expect(h.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0xdecafbad))) Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) Expect(h.Version).To(Equal(protocol.VersionNumber(0x1020304))) @@ -95,8 +105,10 @@ var _ = Describe("IETF draft Header", func() { It("rejects version 0 for packets sent by the client", func() { data := []byte{ 0x80 ^ uint8(protocol.PacketTypeInitial), - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID 0x0, 0x0, 0x0, 0x0, // version number + 0x55, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // dest connection ID + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // src connection ID 0xde, 0xca, 0xfb, 0xad, // packet number } _, err := parseHeader(bytes.NewReader(data), protocol.PerspectiveClient) @@ -115,7 +127,7 @@ var _ = Describe("IETF draft Header", func() { Context("short headers", func() { It("reads a short header with a connection ID", func() { data := []byte{ - 0x10, // 1 byte packet number + 0x30, // 1 byte packet number 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID 0x42, // packet number } @@ -125,7 +137,8 @@ var _ = Describe("IETF draft Header", func() { Expect(h.IsLongHeader).To(BeFalse()) Expect(h.KeyPhase).To(Equal(0)) Expect(h.OmitConnectionID).To(BeFalse()) - Expect(h.ConnectionID).To(Equal(protocol.ConnectionID(0xdeadbeefcafe1337))) + Expect(h.DestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) + Expect(h.SrcConnectionID).To(BeEmpty()) Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) Expect(h.IsVersionNegotiation).To(BeFalse()) Expect(b.Len()).To(BeZero()) @@ -133,7 +146,8 @@ var _ = Describe("IETF draft Header", func() { It("reads the Key Phase Bit", func() { data := []byte{ - 0x10 ^ 0x40 ^ 0x20, + 0x30 ^ 0x40, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID 0x11, } b := bytes.NewReader(data) @@ -144,24 +158,10 @@ var _ = Describe("IETF draft Header", func() { Expect(b.Len()).To(BeZero()) }) - It("reads a header with omitted connection ID", func() { - data := []byte{ - 0x10 ^ 0x40, - 0x21, // packet number - } - b := bytes.NewReader(data) - h, err := parseHeader(b, protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - Expect(h.IsLongHeader).To(BeFalse()) - Expect(h.OmitConnectionID).To(BeTrue()) - Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x21))) - Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) - Expect(b.Len()).To(BeZero()) - }) - It("reads a header with a 2 byte packet number", func() { data := []byte{ - 0x10 ^ 0x40 ^ 0x1, + 0x30 ^ 0x40 ^ 0x1, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID 0x13, 0x37, // packet number } b := bytes.NewReader(data) @@ -175,7 +175,8 @@ var _ = Describe("IETF draft Header", func() { It("reads a header with a 4 byte packet number", func() { data := []byte{ - 0x10 ^ 0x40 ^ 0x2, + 0x30 ^ 0x40 ^ 0x2, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID 0xde, 0xad, 0xbe, 0xef, // packet number } b := bytes.NewReader(data) @@ -187,9 +188,31 @@ var _ = Describe("IETF draft Header", func() { Expect(b.Len()).To(BeZero()) }) + It("rejects headers that have an invalid type", func() { + data := []byte{ + 0x30 ^ 0x40 ^ 0x3, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID + 0xde, 0xad, 0xbe, 0xef, // packet number + } + b := bytes.NewReader(data) + _, err := parseHeader(b, protocol.PerspectiveClient) + Expect(err).To(MatchError("invalid short header type")) + }) + + It("rejects headers that have bit 3,4 and 5 set incorrectly", func() { + data := []byte{ + 0x38 ^ 0x2, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID + 0xde, 0xca, 0xfb, 0xad, // packet number + } + b := bytes.NewReader(data) + _, err := parseHeader(b, protocol.PerspectiveClient) + Expect(err).To(MatchError("invalid bits 3, 4 and 5")) + }) + It("errors on EOF", func() { data := []byte{ - 0x10 ^ 0x2, + 0x30 ^ 0x2, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID 0xde, 0xca, 0xfb, 0xad, // packet number } @@ -211,32 +234,44 @@ var _ = Describe("IETF draft Header", func() { Context("long header", func() { It("writes", func() { err := (&Header{ - IsLongHeader: true, - Type: 0x5, - ConnectionID: 0xdeadbeefcafe1337, - PacketNumber: 0xdecafbad, - Version: 0x1020304, + IsLongHeader: true, + Type: 0x5, + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + PacketNumber: 0xdecafbad, + Version: 0x1020304, }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ 0x80 ^ 0x5, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID 0x1, 0x2, 0x3, 0x4, // version number + 0x55, // connection ID lengths + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID 0xde, 0xca, 0xfb, 0xad, // packet number })) }) + + It("refuses to write headers with unequal destination and source connection IDs", func() { + err := (&Header{ + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, + }).writeHeader(buf) + Expect(err).To(MatchError("Header: can't write a header with different source and destination connection ID")) + }) }) Context("short header", func() { It("writes a header with connection ID", func() { err := (&Header{ - ConnectionID: 0xdeadbeefcafe1337, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 0x42, }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ - 0x10, + 0x30, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID 0x42, // packet number })) @@ -244,13 +279,12 @@ var _ = Describe("IETF draft Header", func() { It("writes a header without connection ID", func() { err := (&Header{ - OmitConnectionID: true, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 0x42, }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ - 0x10 ^ 0x40, + 0x30, 0x42, // packet number })) }) @@ -263,7 +297,7 @@ var _ = Describe("IETF draft Header", func() { }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ - 0x10 ^ 0x40 ^ 0x1, + 0x30 | 0x1, 0x13, 0x37, // packet number })) }) @@ -276,7 +310,7 @@ var _ = Describe("IETF draft Header", func() { }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ - 0x10 ^ 0x40 ^ 0x2, + 0x30 | 0x2, 0xde, 0xca, 0xfb, 0xad, // packet number })) }) @@ -299,7 +333,7 @@ var _ = Describe("IETF draft Header", func() { }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ - 0x10 ^ 0x40 ^ 0x20, + 0x30 | 0x40, 0x42, // packet number })) }) @@ -314,16 +348,22 @@ var _ = Describe("IETF draft Header", func() { }) It("has the right length for the long header", func() { - h := &Header{IsLongHeader: true} - Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(17))) + h := &Header{ + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 8 /* src conn id */ + 4 /* packet number */ + Expect(h.getHeaderLength()).To(BeEquivalentTo(expectedLen)) err := h.writeHeader(buf) Expect(err).ToNot(HaveOccurred()) - Expect(buf.Len()).To(Equal(17)) + Expect(buf.Len()).To(Equal(expectedLen)) }) It("has the right length for a short header containing a connection ID", func() { h := &Header{ - PacketNumberLen: protocol.PacketNumberLen1, + PacketNumberLen: protocol.PacketNumberLen1, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, } Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(1 + 8 + 1))) err := h.writeHeader(buf) @@ -390,32 +430,24 @@ var _ = Describe("IETF draft Header", func() { It("logs Long Headers", func() { (&Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - PacketNumber: 0x1337, - ConnectionID: 0xdeadbeef, - Version: 0xfeed, + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + PacketNumber: 0x1337, + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37}, + Version: 0xfeed, }).logHeader(logger) - Expect(buf.String()).To(ContainSubstring("Long Header{Type: Handshake, ConnectionID: 0xdeadbeef, PacketNumber: 0x1337, Version: 0xfeed}")) + Expect(buf.String()).To(ContainSubstring("Long Header{Type: Handshake, DestConnectionID: 0xdeadbeefcafe1337, SrcConnectionID: 0xdecafbad13371337, PacketNumber: 0x1337, Version: 0xfeed}")) }) It("logs Short Headers containing a connection ID", func() { (&Header{ - KeyPhase: 1, - PacketNumber: 0x1337, - PacketNumberLen: 4, - ConnectionID: 0xdeadbeef, + KeyPhase: 1, + PacketNumber: 0x1337, + PacketNumberLen: 4, + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, }).logHeader(logger) - Expect(buf.String()).To(ContainSubstring("Short Header{ConnectionID: 0xdeadbeef, PacketNumber: 0x1337, PacketNumberLen: 4, KeyPhase: 1}")) - }) - - It("logs Short Headers with omitted connection ID", func() { - (&Header{ - PacketNumber: 0x12, - PacketNumberLen: 1, - OmitConnectionID: true, - }).logHeader(logger) - Expect(buf.String()).To(ContainSubstring("Short Header{ConnectionID: (omitted), PacketNumber: 0x12, PacketNumberLen: 1, KeyPhase: 0}")) + Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: 0xdeadbeefcafe1337, PacketNumber: 0x1337, PacketNumberLen: 4, KeyPhase: 1}")) }) }) }) diff --git a/internal/wire/public_header.go b/internal/wire/public_header.go index af996b29..286b4841 100644 --- a/internal/wire/public_header.go +++ b/internal/wire/public_header.go @@ -26,6 +26,12 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ if h.VersionFlag && h.ResetFlag { return errResetAndVersionFlagSet } + if !h.DestConnectionID.Equal(h.SrcConnectionID) { + return fmt.Errorf("PublicHeader: SrcConnectionID must be equal to DestConnectionID") + } + if len(h.DestConnectionID) != 8 { + return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID)) + } publicFlagByte := uint8(0x00) if h.VersionFlag { @@ -59,7 +65,7 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ b.WriteByte(publicFlagByte) if !h.OmitConnectionID { - utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) + b.Write(h.DestConnectionID) } if h.VersionFlag && pers == protocol.PerspectiveClient { utils.BigEndian.WriteUint32(b, uint32(h.Version)) @@ -126,15 +132,18 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea // Connection ID if !header.OmitConnectionID { - var connID uint64 - connID, err = utils.BigEndian.ReadUint64(b) - if err != nil { + connID := make(protocol.ConnectionID, 8) + if _, err := io.ReadFull(b, connID); err != nil { + if err == io.ErrUnexpectedEOF { + err = io.EOF + } return nil, err } - header.ConnectionID = protocol.ConnectionID(connID) - if header.ConnectionID == 0 { + if connID[0] == 0 && connID[1] == 0 && connID[2] == 0 && connID[3] == 0 && connID[4] == 0 && connID[5] == 0 && connID[6] == 0 && connID[7] == 0 { return nil, errInvalidConnectionID } + header.DestConnectionID = connID + header.SrcConnectionID = connID } // Contrary to what the gQUIC wire spec says, the 0x4 bit only indicates the presence of the diversification nonce for packets sent by the server. @@ -232,13 +241,9 @@ func (h *Header) hasPacketNumber(packetSentBy protocol.Perspective) bool { } func (h *Header) logPublicHeader(logger utils.Logger) { - connID := "(omitted)" - if !h.OmitConnectionID { - connID = fmt.Sprintf("%#x", h.ConnectionID) - } ver := "(unset)" if h.Version != 0 { ver = h.Version.String() } - logger.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", connID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce) + logger.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce) } diff --git a/internal/wire/public_header_test.go b/internal/wire/public_header_test.go index f0de8bf5..f1685320 100644 --- a/internal/wire/public_header_test.go +++ b/internal/wire/public_header_test.go @@ -14,6 +14,8 @@ import ( ) var _ = Describe("Public Header", func() { + connID := protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} + Context("when parsing", func() { It("accepts a sample client header", func() { ver := make([]byte, 4) @@ -24,7 +26,9 @@ var _ = Describe("Public Header", func() { Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.IsVersionNegotiation).To(BeFalse()) Expect(hdr.ResetFlag).To(BeFalse()) - Expect(hdr.ConnectionID).To(Equal(protocol.ConnectionID(0x4cfa9f9b668619f6))) + connID := protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} + Expect(hdr.DestConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(Equal(connID)) Expect(hdr.Version).To(Equal(protocol.SupportedVersions[0])) Expect(hdr.SupportedVersions).To(BeEmpty()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(1))) @@ -37,12 +41,13 @@ var _ = Describe("Public Header", func() { Expect(err).To(MatchError(errReceivedOmittedConnectionID)) }) - It("accepts aan d connection ID as a client", func() { + It("accepts an omitted connection ID as a client", func() { b := bytes.NewReader([]byte{0x00, 0x01}) hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.OmitConnectionID).To(BeTrue()) - Expect(hdr.ConnectionID).To(BeZero()) + Expect(hdr.DestConnectionID).To(BeEmpty()) + Expect(hdr.SrcConnectionID).To(BeEmpty()) Expect(b.Len()).To(BeZero()) }) @@ -52,22 +57,16 @@ var _ = Describe("Public Header", func() { Expect(err).To(MatchError(errInvalidConnectionID)) }) - It("reads a PublicReset packet", func() { + It("parses a PUBLIC_RESET packet", func() { b := bytes.NewReader([]byte{0xa, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}) hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.ResetFlag).To(BeTrue()) - Expect(hdr.ConnectionID).ToNot(BeZero()) - }) - - It("parses a public reset packet", func() { - b := bytes.NewReader([]byte{0xa, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}) - hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.ResetFlag).To(BeTrue()) Expect(hdr.VersionFlag).To(BeFalse()) Expect(hdr.IsVersionNegotiation).To(BeFalse()) - Expect(hdr.ConnectionID).To(Equal(protocol.ConnectionID(0x0102030405060708))) + connID := protocol.ConnectionID{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8} + Expect(hdr.SrcConnectionID).To(Equal(connID)) + Expect(hdr.DestConnectionID).To(Equal(connID)) }) It("reads a diversification nonce sent by the server", func() { @@ -76,7 +75,8 @@ var _ = Describe("Public Header", func() { b := bytes.NewReader(append(append([]byte{0x0c, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}, divNonce...), 0x37)) hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.ConnectionID).To(Not(BeZero())) + Expect(hdr.DestConnectionID).ToNot(BeEmpty()) + Expect(hdr.SrcConnectionID).ToNot(BeEmpty()) Expect(hdr.DiversificationNonce).To(Equal(divNonce)) Expect(b.Len()).To(BeZero()) }) @@ -89,10 +89,13 @@ var _ = Describe("Public Header", func() { } It("parses", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x13, 0x37} - b := bytes.NewReader(ComposeGQUICVersionNegotiation(0x1337, versions)) + b := bytes.NewReader(ComposeGQUICVersionNegotiation(connID, versions)) hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) + Expect(hdr.DestConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(Equal(connID)) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.Version).To(BeZero()) // unitialized Expect(hdr.IsVersionNegotiation).To(BeTrue()) @@ -124,7 +127,7 @@ var _ = Describe("Public Header", func() { }) It("errors on invalid version tags", func() { - data := ComposeGQUICVersionNegotiation(0x1337, protocol.SupportedVersions) + data := ComposeGQUICVersionNegotiation(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, protocol.SupportedVersions) data = append(data, []byte{0x13, 0x37}...) b := bytes.NewReader(data) _, err := parsePublicHeader(b, protocol.PerspectiveServer) @@ -175,9 +178,10 @@ var _ = Describe("Public Header", func() { It("writes a sample header as a server", func() { b := &bytes.Buffer{} hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 2, - PacketNumberLen: protocol.PacketNumberLen6, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 2, + PacketNumberLen: protocol.PacketNumberLen6, } err := hdr.writePublicHeader(b, protocol.PerspectiveServer, versionBigEndian) Expect(err).ToNot(HaveOccurred()) @@ -187,19 +191,47 @@ var _ = Describe("Public Header", func() { It("writes a sample header as a client", func() { b := &bytes.Buffer{} hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0x1337, - PacketNumberLen: protocol.PacketNumberLen6, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0x1337, + PacketNumberLen: protocol.PacketNumberLen6, } err := hdr.writePublicHeader(b, protocol.PerspectiveClient, versionBigEndian) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x38, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x0, 0x0, 0x0, 0x0, 0x13, 0x37})) }) - It("refuses to write a Public Header if the PacketNumberLen is not set", func() { + It("refuses to write a Public Header if the source and destination connection IDs are not matching", func() { + b := &bytes.Buffer{} hdr := Header{ - ConnectionID: 1, - PacketNumber: 2, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, + PacketNumber: 0x1337, + PacketNumberLen: protocol.PacketNumberLen6, + } + err := hdr.writePublicHeader(b, protocol.PerspectiveClient, versionBigEndian) + Expect(err).To(MatchError("PublicHeader: SrcConnectionID must be equal to DestConnectionID")) + }) + + It("refuses to write a Public Header if the connection ID has the wrong length", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} + hdr := Header{ + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 2, + PacketNumberLen: protocol.PacketNumberLen2, + } + b := &bytes.Buffer{} + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) + Expect(err).To(MatchError("PublicHeader: wrong length for Connection ID: 7 (expected 8)")) + }) + + It("refuses to write a Public Header if the PacketNumberLen is not set", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + hdr := Header{ + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 2, } b := &bytes.Buffer{} err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) @@ -207,9 +239,11 @@ var _ = Describe("Public Header", func() { }) It("omits the connection ID", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} b := &bytes.Buffer{} hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, + DestConnectionID: connID, + SrcConnectionID: connID, OmitConnectionID: true, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 1, @@ -222,7 +256,8 @@ var _ = Describe("Public Header", func() { It("writes diversification nonces", func() { b := &bytes.Buffer{} hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, + DestConnectionID: connID, + SrcConnectionID: connID, PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen1, DiversificationNonce: bytes.Repeat([]byte{1}, 32), @@ -249,10 +284,11 @@ var _ = Describe("Public Header", func() { It("doesn't write Version Negotiation Packets", func() { b := &bytes.Buffer{} hdr := Header{ - VersionFlag: true, - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 2, - PacketNumberLen: protocol.PacketNumberLen6, + VersionFlag: true, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 2, + PacketNumberLen: protocol.PacketNumberLen6, } err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).To(MatchError("PublicHeader: Writing of Version Negotiation Packets not supported")) @@ -261,11 +297,12 @@ var _ = Describe("Public Header", func() { It("writes packets with Version Flag, as a client", func() { b := &bytes.Buffer{} hdr := Header{ - VersionFlag: true, - Version: protocol.Version39, - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0x42, - PacketNumberLen: protocol.PacketNumberLen1, + VersionFlag: true, + Version: protocol.Version39, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen1, } err := hdr.writePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) @@ -281,9 +318,11 @@ var _ = Describe("Public Header", func() { Context("PublicReset packets", func() { It("sets the Reset Flag", func() { b := &bytes.Buffer{} + connID := protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} hdr := Header{ - ResetFlag: true, - ConnectionID: 0x4cfa9f9b668619f6, + ResetFlag: true, + DestConnectionID: connID, + SrcConnectionID: connID, } err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) @@ -295,11 +334,13 @@ var _ = Describe("Public Header", func() { It("doesn't add a packet number for headers with Reset Flag sent as a client", func() { b := &bytes.Buffer{} + connID := protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} hdr := Header{ - ResetFlag: true, - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 2, - PacketNumberLen: protocol.PacketNumberLen6, + ResetFlag: true, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 2, + PacketNumberLen: protocol.PacketNumberLen6, } err := hdr.writePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) @@ -326,8 +367,9 @@ var _ = Describe("Public Header", func() { It("errors when PacketNumberLen is not set", func() { hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0xdecafbad, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0xdecafbad, } _, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).To(MatchError(errPacketNumberLenNotSet)) @@ -335,9 +377,10 @@ var _ = Describe("Public Header", func() { It("gets the length of a packet with longest packet number length and connectionID", func() { hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen6, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen6, } length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) @@ -346,7 +389,8 @@ var _ = Describe("Public Header", func() { It("gets the lengths of a packet sent by the client with the VersionFlag set", func() { hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, + DestConnectionID: connID, + SrcConnectionID: connID, OmitConnectionID: true, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen6, @@ -360,7 +404,8 @@ var _ = Describe("Public Header", func() { It("gets the length of a packet with longest packet number length and omitted connectionID", func() { hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, + DestConnectionID: connID, + SrcConnectionID: connID, OmitConnectionID: true, PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen6, @@ -372,9 +417,10 @@ var _ = Describe("Public Header", func() { It("gets the length of a packet 2 byte packet number length ", func() { hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0xDECAFBAD, - PacketNumberLen: protocol.PacketNumberLen2, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0xDECAFBAD, + PacketNumberLen: protocol.PacketNumberLen2, } length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) @@ -393,8 +439,9 @@ var _ = Describe("Public Header", func() { It("gets the length of a PublicReset", func() { hdr := Header{ - ResetFlag: true, - ConnectionID: 0x4cfa9f9b668619f6, + ResetFlag: true, + DestConnectionID: connID, + SrcConnectionID: connID, } length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).NotTo(HaveOccurred()) @@ -406,8 +453,9 @@ var _ = Describe("Public Header", func() { It("doesn't write a header if the packet number length is not set", func() { b := &bytes.Buffer{} hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0xDECAFBAD, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0xDECAFBAD, } err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).To(MatchError("PublicHeader: PacketNumberLen not set")) @@ -419,9 +467,10 @@ var _ = Describe("Public Header", func() { It("writes a header with a 1-byte packet number", func() { b := &bytes.Buffer{} hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen1, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen1, } err := hdr.writePublicHeader(b, protocol.PerspectiveServer, version) Expect(err).ToNot(HaveOccurred()) @@ -431,9 +480,10 @@ var _ = Describe("Public Header", func() { It("writes a header with a 2-byte packet number", func() { b := &bytes.Buffer{} hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen2, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen2, } err := hdr.writePublicHeader(b, protocol.PerspectiveServer, version) Expect(err).ToNot(HaveOccurred()) @@ -443,9 +493,10 @@ var _ = Describe("Public Header", func() { It("writes a header with a 4-byte packet number", func() { b := &bytes.Buffer{} hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0x13decafbad, - PacketNumberLen: protocol.PacketNumberLen4, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0x13decafbad, + PacketNumberLen: protocol.PacketNumberLen4, } err := hdr.writePublicHeader(b, protocol.PerspectiveServer, version) Expect(err).ToNot(HaveOccurred()) @@ -455,9 +506,10 @@ var _ = Describe("Public Header", func() { It("writes a header with a 6-byte packet number", func() { b := &bytes.Buffer{} hdr := Header{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0xbe1337decafbad, - PacketNumberLen: protocol.PacketNumberLen6, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0xbe1337decafbad, + PacketNumberLen: protocol.PacketNumberLen6, } err := hdr.writePublicHeader(b, protocol.PerspectiveServer, version) Expect(err).ToNot(HaveOccurred()) @@ -486,12 +538,13 @@ var _ = Describe("Public Header", func() { It("logs a Public Header containing a connection ID", func() { (&Header{ - ConnectionID: 0xdecafbad, - PacketNumber: 0x1337, - PacketNumberLen: 6, - Version: protocol.Version39, + DestConnectionID: protocol.ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad}, + SrcConnectionID: protocol.ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad}, + PacketNumber: 0x1337, + PacketNumberLen: 6, + Version: protocol.Version39, }).logPublicHeader(logger) - Expect(buf.String()).To(ContainSubstring("Public Header{ConnectionID: 0xdecafbad, PacketNumber: 0x1337, PacketNumberLen: 6, Version: gQUIC 39")) + Expect(buf.String()).To(ContainSubstring("Public Header{ConnectionID: 0x13370000decafbad, PacketNumber: 0x1337, PacketNumberLen: 6, Version: gQUIC 39")) }) It("logs a Public Header with omitted connection ID", func() { @@ -501,7 +554,7 @@ var _ = Describe("Public Header", func() { PacketNumberLen: 6, Version: protocol.Version39, }).logPublicHeader(logger) - Expect(buf.String()).To(ContainSubstring("Public Header{ConnectionID: (omitted)")) + Expect(buf.String()).To(ContainSubstring("Public Header{ConnectionID: (empty)")) }) It("logs a Public Header without a version", func() { @@ -515,7 +568,8 @@ var _ = Describe("Public Header", func() { It("logs diversification nonces", func() { (&Header{ - ConnectionID: 0xdecafbad, + DestConnectionID: []byte{0x13, 0x13, 0, 0, 0xde, 0xca, 0xfb, 0xad}, + SrcConnectionID: []byte{0x13, 0x13, 0, 0, 0xde, 0xca, 0xfb, 0xad}, DiversificationNonce: []byte{0xba, 0xdf, 0x00, 0x0d}, }).logPublicHeader(logger) Expect(buf.String()).To(ContainSubstring("DiversificationNonce: []byte{0xba, 0xdf, 0x0, 0xd}")) diff --git a/internal/wire/public_reset.go b/internal/wire/public_reset.go index 6adc9f69..b57ea7ad 100644 --- a/internal/wire/public_reset.go +++ b/internal/wire/public_reset.go @@ -16,11 +16,11 @@ type PublicReset struct { Nonce uint64 } -// WritePublicReset writes a Public Reset +// WritePublicReset writes a PUBLIC_RESET func WritePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte { b := &bytes.Buffer{} b.WriteByte(0x0a) - utils.BigEndian.WriteUint64(b, uint64(connectionID)) + b.Write(connectionID) utils.LittleEndian.WriteUint32(b, uint32(handshake.TagPRST)) utils.LittleEndian.WriteUint32(b, 2) utils.LittleEndian.WriteUint32(b, uint32(handshake.TagRNON)) @@ -32,7 +32,7 @@ func WritePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber p return b.Bytes() } -// ParsePublicReset parses a Public Reset +// ParsePublicReset parses a PUBLIC_RESET func ParsePublicReset(r *bytes.Reader) (*PublicReset, error) { pr := PublicReset{} msg, err := handshake.ParseHandshakeMessage(r) @@ -44,7 +44,7 @@ func ParsePublicReset(r *bytes.Reader) (*PublicReset, error) { } // The RSEQ tag is mandatory according to the gQUIC wire spec. - // However, Google doesn't send RSEQ in their Public Resets. + // However, Google doesn't send RSEQ in their PUBLIC_RESETs. // Therefore, we'll treat RSEQ as an optional field. if rseq, ok := msg.Data[handshake.TagRSEQ]; ok { if len(rseq) != 8 { diff --git a/internal/wire/public_reset_test.go b/internal/wire/public_reset_test.go index 311ab6b7..45347df6 100644 --- a/internal/wire/public_reset_test.go +++ b/internal/wire/public_reset_test.go @@ -13,7 +13,7 @@ import ( var _ = Describe("public reset", func() { Context("writing", func() { It("writes public reset packets", func() { - Expect(WritePublicReset(0xdeadbeef, 0x8badf00d, 0xdecafbad)).To(Equal([]byte{ + Expect(WritePublicReset(protocol.ConnectionID{0, 0, 0, 0, 0xde, 0xad, 0xbe, 0xef}, 0x8badf00d, 0xdecafbad)).To(Equal([]byte{ 0x0a, 0x0, 0x0, 0x0, 0x0, 0xde, 0xad, 0xbe, 0xef, 'P', 'R', 'S', 'T', @@ -36,7 +36,7 @@ var _ = Describe("public reset", func() { }) It("parses a public reset", func() { - packet := WritePublicReset(0xdeadbeef, 0x8badf00d, 0xdecafbad) + packet := WritePublicReset(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, 0x8badf00d, 0xdecafbad) pr, err := ParsePublicReset(bytes.NewReader(packet[9:])) // 1 byte Public Flag, 8 bytes connection ID Expect(err).ToNot(HaveOccurred()) Expect(pr.Nonce).To(Equal(uint64(0xdecafbad))) diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go index cf72fc2e..a19f2767 100644 --- a/internal/wire/version_negotiation.go +++ b/internal/wire/version_negotiation.go @@ -12,7 +12,7 @@ import ( func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { buf := bytes.NewBuffer(make([]byte, 0, 1+8+len(versions)*4)) buf.Write([]byte{0x1 | 0x8}) // type byte - utils.BigEndian.WriteUint64(buf, uint64(connID)) + buf.Write(connID) for _, v := range versions { utils.BigEndian.WriteUint32(buf, uint32(v)) } @@ -20,19 +20,22 @@ func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []pro } // ComposeVersionNegotiation composes a Version Negotiation according to the IETF draft -func ComposeVersionNegotiation( - connID protocol.ConnectionID, - versions []protocol.VersionNumber, -) []byte { +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) ([]byte, error) { greasedVersions := protocol.GetGreasedVersions(versions) buf := bytes.NewBuffer(make([]byte, 0, 1+8+4+len(greasedVersions)*4)) r := make([]byte, 1) _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. buf.WriteByte(r[0] | 0x80) - utils.BigEndian.WriteUint64(buf, uint64(connID)) utils.BigEndian.WriteUint32(buf, 0) // version 0 + connIDLen, err := encodeConnIDLen(destConnID, srcConnID) + if err != nil { + return nil, err + } + buf.WriteByte(connIDLen) + buf.Write(destConnID) + buf.Write(srcConnID) for _, v := range greasedVersions { utils.BigEndian.WriteUint32(buf, uint32(v)) } - return buf.Bytes() + return buf.Bytes(), nil } diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index 98783d4c..cff07895 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -10,23 +10,29 @@ import ( var _ = Describe("Version Negotiation Packets", func() { It("writes for gQUIC", func() { + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} versions := []protocol.VersionNumber{1001, 1003} - data := ComposeGQUICVersionNegotiation(0x1337, versions) + data := ComposeGQUICVersionNegotiation(connID, versions) hdr, err := parsePublicHeader(bytes.NewReader(data), protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) - Expect(hdr.ConnectionID).To(Equal(protocol.ConnectionID(0x1337))) + Expect(hdr.DestConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(Equal(connID)) Expect(hdr.SupportedVersions).To(Equal(versions)) }) It("writes in IETF draft style", func() { + srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} + destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{1001, 1003} - data := ComposeVersionNegotiation(0x1337, versions) + data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) + Expect(err).ToNot(HaveOccurred()) Expect(data[0] & 0x80).ToNot(BeZero()) hdr, err := parseHeader(bytes.NewReader(data), protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsVersionNegotiation).To(BeTrue()) - Expect(hdr.ConnectionID).To(Equal(protocol.ConnectionID(0x1337))) + Expect(hdr.DestConnectionID).To(Equal(destConnID)) + Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) Expect(hdr.Version).To(BeZero()) // the supported versions should include one reserved version number Expect(hdr.SupportedVersions).To(HaveLen(len(versions) + 1)) diff --git a/mint_utils.go b/mint_utils.go index b32a0905..36af76d0 100644 --- a/mint_utils.go +++ b/mint_utils.go @@ -137,7 +137,7 @@ func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, logger return nil, errors.New("received stream data with non-zero offset") } if logger.Debug() { - logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) + logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.DestConnectionID) hdr.Log(logger) wire.LogFrame(logger, frame, false) } @@ -160,7 +160,7 @@ func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, per _ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex]) raw = raw[0 : buffer.Len()+aead.Overhead()] if logger.Debug() { - logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted) + logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.SrcConnectionID, protocol.EncryptionUnencrypted) hdr.Log(logger) wire.LogFrame(logger, f, true) } diff --git a/mint_utils_test.go b/mint_utils_test.go index 5424277d..0239a9a4 100644 --- a/mint_utils_test.go +++ b/mint_utils_test.go @@ -17,14 +17,15 @@ import ( var _ = Describe("Packing and unpacking Initial packets", func() { var aead crypto.AEAD - connID := protocol.ConnectionID(0x1337) + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} ver := protocol.VersionTLS hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - PacketNumber: 0x42, - ConnectionID: connID, - Version: ver, + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + PacketNumber: 0x42, + DestConnectionID: connID, + SrcConnectionID: connID, + Version: ver, } BeforeEach(func() { diff --git a/packet_packer.go b/packet_packer.go index e4fa653a..a4dd434f 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -446,9 +446,10 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header packetNumberLen := p.getPacketNumberLen(pnum) header := &wire.Header{ - ConnectionID: p.connectionID, - PacketNumber: pnum, - PacketNumberLen: packetNumberLen, + DestConnectionID: p.connectionID, + SrcConnectionID: p.connectionID, + PacketNumber: pnum, + PacketNumberLen: packetNumberLen, } if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure { diff --git a/packet_packer_test.go b/packet_packer_test.go index f38f5c58..5b2b719a 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -66,7 +66,7 @@ var _ = Describe("Packet packer", func() { divNonce = bytes.Repeat([]byte{'e'}, 32) packer = newPacketPacker( - 0x1337, + protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}, 1, func(protocol.PacketNumber) protocol.PacketNumberLen { return protocol.PacketNumberLen2 }, &net.TCPAddr{}, @@ -84,22 +84,23 @@ var _ = Describe("Packet packer", func() { }) Context("determining the maximum packet size", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} It("uses the minimum initial size, if it can't determine if the remote address is IPv4 or IPv6", func() { remoteAddr := &net.TCPAddr{} - packer = newPacketPacker(0x1337, 1, nil, remoteAddr, nil, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) + packer = newPacketPacker(connID, 1, nil, remoteAddr, nil, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(packer.maxPacketSize).To(BeEquivalentTo(protocol.MinInitialPacketSize)) }) It("uses the maximum IPv4 packet size, if the remote address is IPv4", func() { remoteAddr := &net.UDPAddr{IP: net.IPv4(11, 12, 13, 14), Port: 1337} - packer = newPacketPacker(0x1337, 1, nil, remoteAddr, nil, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) + packer = newPacketPacker(connID, 1, nil, remoteAddr, nil, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(packer.maxPacketSize).To(BeEquivalentTo(protocol.MaxPacketSizeIPv4)) }) It("uses the maximum IPv6 packet size, if the remote address is IPv6", func() { ip := net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334") remoteAddr := &net.UDPAddr{IP: ip, Port: 1337} - packer = newPacketPacker(0x1337, 1, nil, remoteAddr, nil, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) + packer = newPacketPacker(connID, 1, nil, remoteAddr, nil, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(packer.maxPacketSize).To(BeEquivalentTo(protocol.MaxPacketSizeIPv6)) }) }) diff --git a/server.go b/server.go index 1e56f0b9..f3ef1882 100644 --- a/server.go +++ b/server.go @@ -42,7 +42,7 @@ type server struct { scfg *handshake.ServerConfig sessionsMutex sync.RWMutex - sessions map[protocol.ConnectionID]packetHandler + sessions map[string] /* string(ConnectionID)*/ packetHandler closed bool serverError error @@ -106,7 +106,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, config: config, certChain: certChain, scfg: scfg, - sessions: map[protocol.ConnectionID]packetHandler{}, + sessions: map[string]packetHandler{}, newSession: newSession, deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, sessionQueue: make(chan Session, 5), @@ -144,11 +144,11 @@ func (s *server) setupTLS() error { connID := tlsSession.connID sess := tlsSession.sess s.sessionsMutex.Lock() - if _, ok := s.sessions[connID]; ok { // drop this session if it already exists + if _, ok := s.sessions[string(connID)]; ok { // drop this session if it already exists s.sessionsMutex.Unlock() continue } - s.sessions[connID] = sess + s.sessions[string(connID)] = sess s.sessionsMutex.Unlock() s.runHandshakeAndSession(sess, connID) } @@ -307,7 +307,10 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet } hdr.Raw = packet[:len(packet)-r.Len()] packetData := packet[len(packet)-r.Len():] - connID := hdr.ConnectionID + + if hdr.IsLongHeader && !hdr.DestConnectionID.Equal(hdr.SrcConnectionID) { + return errors.New("receiving packets with different destination and source connection IDs not supported") + } if hdr.Type == protocol.PacketTypeInitial { if s.supportsTLS { @@ -317,7 +320,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet } s.sessionsMutex.RLock() - session, sessionKnown := s.sessions[connID] + session, sessionKnown := s.sessions[string(hdr.DestConnectionID)] s.sessionsMutex.RUnlock() if sessionKnown && session == nil { @@ -331,12 +334,12 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet var pr *wire.PublicReset pr, err = wire.ParsePublicReset(r) if err != nil { - s.logger.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.", hdr.ConnectionID) + s.logger.Infof("Received a Public Reset for connection %s. An error occurred parsing the packet.", hdr.DestConnectionID) } else { - s.logger.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.RejectedPacketNumber) + s.logger.Infof("Received a Public Reset for connection %s, rejected packet number: 0x%x.", hdr.DestConnectionID, pr.RejectedPacketNumber) } } else { - s.logger.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID) + s.logger.Infof("Received Public Reset for unknown connection %s.", hdr.DestConnectionID) } return nil } @@ -345,7 +348,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet // This should only happen after a server restart, when we still receive packets for connections that we lost the state for. // TODO(#943): implement sending of IETF draft style stateless resets if !sessionKnown && (!hdr.VersionFlag && hdr.Type != protocol.PacketTypeInitial) { - _, err = pconn.WriteTo(wire.WritePublicReset(connID, 0, 0), remoteAddr) + _, err = pconn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), remoteAddr) return err } @@ -364,7 +367,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return errors.New("dropping small packet with unknown version") } s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version) - _, err := pconn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr) + _, err := pconn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.SrcConnectionID, s.config.Versions), remoteAddr) return err } @@ -380,11 +383,11 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return errors.New("Server BUG: negotiated version not supported") } - s.logger.Infof("Serving new connection: %x, version %s from %v", hdr.ConnectionID, version, remoteAddr) + s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, version, remoteAddr) session, err = s.newSession( &conn{pconn: pconn, currentAddr: remoteAddr}, version, - hdr.ConnectionID, + hdr.DestConnectionID, s.scfg, s.tlsConf, s.config, @@ -394,10 +397,10 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return err } s.sessionsMutex.Lock() - s.sessions[connID] = session + s.sessions[string(hdr.DestConnectionID)] = session s.sessionsMutex.Unlock() - s.runHandshakeAndSession(session, connID) + s.runHandshakeAndSession(session, hdr.DestConnectionID) } session.handlePacket(&receivedPacket{ remoteAddr: remoteAddr, @@ -425,12 +428,12 @@ func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.C func (s *server) removeConnection(id protocol.ConnectionID) { s.sessionsMutex.Lock() - s.sessions[id] = nil + s.sessions[string(id)] = nil s.sessionsMutex.Unlock() time.AfterFunc(s.deleteClosedSessionsAfter, func() { s.sessionsMutex.Lock() - delete(s.sessions, id) + delete(s.sessions, string(id)) s.sessionsMutex.Unlock() }) } diff --git a/server_test.go b/server_test.go index ac2b8671..75331015 100644 --- a/server_test.go +++ b/server_test.go @@ -106,12 +106,12 @@ var _ = Describe("Server", func() { var ( serv *server firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID) - connID = protocol.ConnectionID(0x4cfa9f9b668619f6) + connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} ) BeforeEach(func() { serv = &server{ - sessions: make(map[protocol.ConnectionID]packetHandler), + sessions: make(map[string]packetHandler), newSession: newMockSession, conn: conn, config: config, @@ -174,13 +174,13 @@ var _ = Describe("Server", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - sess := serv.sessions[connID].(*mockSession) + sess := serv.sessions[string(connID)].(*mockSession) Expect(sess.connectionID).To(Equal(connID)) Expect(sess.packetCount).To(Equal(1)) }) It("accepts new TLS sessions", func() { - connID := protocol.ConnectionID(0x12345) + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} sess, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) err = serv.setupTLS() @@ -192,12 +192,12 @@ var _ = Describe("Server", func() { Eventually(func() packetHandler { serv.sessionsMutex.Lock() defer serv.sessionsMutex.Unlock() - return serv.sessions[connID] + return serv.sessions[string(connID)] }).Should(Equal(sess)) }) It("only accepts one new TLS sessions for one connection ID", func() { - connID := protocol.ConnectionID(0x12345) + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} sess1, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) sess2, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil) @@ -211,7 +211,7 @@ var _ = Describe("Server", func() { Eventually(func() packetHandler { serv.sessionsMutex.Lock() defer serv.sessionsMutex.Unlock() - return serv.sessions[connID] + return serv.sessions[string(connID)] }).Should(Equal(sess1)) serv.serverTLS.sessionChan <- tlsSession{ connID: connID, @@ -220,7 +220,7 @@ var _ = Describe("Server", func() { Eventually(func() packetHandler { serv.sessionsMutex.Lock() defer serv.sessionsMutex.Unlock() - return serv.sessions[connID] + return serv.sessions[string(connID)] }).Should(Equal(sess1)) }) @@ -235,7 +235,7 @@ var _ = Describe("Server", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - sess := serv.sessions[connID].(*mockSession) + sess := serv.sessions[string(connID)].(*mockSession) Consistently(func() Session { return acceptedSess }).Should(BeNil()) close(sess.handshakeChan) Eventually(func() Session { return acceptedSess }).Should(Equal(sess)) @@ -252,7 +252,7 @@ var _ = Describe("Server", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - sess := serv.sessions[connID].(*mockSession) + sess := serv.sessions[string(connID)].(*mockSession) sess.handshakeChan <- errors.New("handshake failed") Consistently(func() bool { return accepted }).Should(BeFalse()) close(done) @@ -264,8 +264,8 @@ var _ = Describe("Server", func() { err = serv.handlePacket(nil, nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[connID].(*mockSession).connectionID).To(Equal(connID)) - Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(2)) + Expect(serv.sessions[string(connID)].(*mockSession).connectionID).To(Equal(connID)) + Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(2)) }) It("closes and deletes sessions", func() { @@ -275,12 +275,12 @@ var _ = Describe("Server", func() { err = serv.handlePacket(nil, nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...)) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[connID]).ToNot(BeNil()) + Expect(serv.sessions[string(connID)]).ToNot(BeNil()) // make session.run() return - serv.sessions[connID].(*mockSession).stopRunLoop <- struct{}{} + serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{} // The server should now have closed the session, leaving a nil value in the sessions map - Consistently(func() map[protocol.ConnectionID]packetHandler { return serv.sessions }).Should(HaveLen(1)) - Expect(serv.sessions[connID]).To(BeNil()) + Consistently(func() map[string]packetHandler { return serv.sessions }).Should(HaveLen(1)) + Expect(serv.sessions[string(connID)]).To(BeNil()) }) It("deletes nil session entries after a wait time", func() { @@ -290,12 +290,12 @@ var _ = Describe("Server", func() { err = serv.handlePacket(nil, nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...)) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions).To(HaveKey(connID)) + Expect(serv.sessions).To(HaveKey(string(connID))) // make session.run() return - serv.sessions[connID].(*mockSession).stopRunLoop <- struct{}{} + serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{} Eventually(func() bool { serv.sessionsMutex.Lock() - _, ok := serv.sessions[connID] + _, ok := serv.sessions[string(connID)] serv.sessionsMutex.Unlock() return ok }).Should(BeFalse()) @@ -303,8 +303,8 @@ var _ = Describe("Server", func() { It("closes sessions and the connection when Close is called", func() { go serv.serve() - session, _ := newMockSession(nil, 0, 0, nil, nil, nil, nil) - serv.sessions[1] = session + session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil) + serv.sessions[string(connID)] = session err := serv.Close() Expect(err).NotTo(HaveOccurred()) Expect(session.(*mockSession).closed).To(BeTrue()) @@ -312,11 +312,11 @@ var _ = Describe("Server", func() { }) It("ignores packets for closed sessions", func() { - serv.sessions[connID] = nil + serv.sessions[string(connID)] = nil err := serv.handlePacket(nil, nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[connID]).To(BeNil()) + Expect(serv.sessions[string(connID)]).To(BeNil()) }) It("works if no quic.Config is given", func(done Done) { @@ -353,9 +353,9 @@ var _ = Describe("Server", func() { }, 0.5) It("closes all sessions when encountering a connection error", func() { - session, _ := newMockSession(nil, 0, 0, nil, nil, nil, nil) - serv.sessions[0x12345] = session - Expect(serv.sessions[0x12345].(*mockSession).closed).To(BeFalse()) + session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil) + serv.sessions[string(connID)] = session + Expect(serv.sessions[string(connID)].(*mockSession).closed).To(BeFalse()) testErr := errors.New("connection error") conn.readErr = testErr go serv.serve() @@ -366,7 +366,7 @@ var _ = Describe("Server", func() { It("ignores delayed packets with mismatching versions", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1)) + Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1)) b := &bytes.Buffer{} // add an unsupported version data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} @@ -377,7 +377,7 @@ var _ = Describe("Server", func() { // if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn Expect(conn.dataWritten.Bytes()).To(BeEmpty()) // make sure the packet was *not* passed to session.handlePacket() - Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1)) + Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1)) }) It("errors on invalid public header", func() { @@ -386,7 +386,7 @@ var _ = Describe("Server", func() { }) It("ignores public resets for unknown connections", func() { - err := serv.handlePacket(nil, nil, wire.WritePublicReset(999, 1, 1337)) + err := serv.handlePacket(nil, nil, wire.WritePublicReset([]byte{9, 9, 9, 9, 9, 9, 9, 9}, 1, 1337)) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(BeEmpty()) }) @@ -395,33 +395,34 @@ var _ = Describe("Server", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1)) + Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1)) err = serv.handlePacket(nil, nil, wire.WritePublicReset(connID, 1, 1337)) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1)) + Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1)) }) It("ignores invalid public resets for known connections", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1)) + Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1)) data := wire.WritePublicReset(connID, 1, 1337) err = serv.handlePacket(nil, nil, data[:len(data)-2]) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1)) + Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1)) }) It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() { config.Versions = []protocol.VersionNumber{99} b := &bytes.Buffer{} hdr := wire.Header{ - VersionFlag: true, - ConnectionID: 0x1337, - PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen2, + VersionFlag: true, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 1, + PacketNumberLen: protocol.PacketNumberLen2, } hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */) b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO @@ -433,10 +434,11 @@ var _ = Describe("Server", func() { It("doesn't respond with a version negotiation packet if the first packet is too small", func() { b := &bytes.Buffer{} hdr := wire.Header{ - VersionFlag: true, - ConnectionID: 0x1337, - PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen2, + VersionFlag: true, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 1, + PacketNumberLen: protocol.PacketNumberLen2, } hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */) b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small @@ -507,12 +509,14 @@ var _ = Describe("Server", func() { }) It("sends a gQUIC Version Negotaion Packet, if the client sent a gQUIC Public Header", func() { + connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} b := &bytes.Buffer{} hdr := wire.Header{ - VersionFlag: true, - ConnectionID: 0x1337, - PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen2, + VersionFlag: true, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 1, + PacketNumberLen: protocol.PacketNumberLen2, } hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */) b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO @@ -534,7 +538,8 @@ var _ = Describe("Server", func() { packet, err := wire.ParseHeaderSentByServer(r, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(packet.VersionFlag).To(BeTrue()) - Expect(packet.ConnectionID).To(Equal(protocol.ConnectionID(0x1337))) + Expect(packet.DestConnectionID).To(Equal(connID)) + Expect(packet.SrcConnectionID).To(Equal(connID)) Expect(r.Len()).To(BeZero()) Consistently(done).ShouldNot(BeClosed()) // make the go routine return @@ -543,14 +548,16 @@ var _ = Describe("Server", func() { }) It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() { + connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} config.Versions = append(config.Versions, protocol.VersionTLS) b := &bytes.Buffer{} hdr := wire.Header{ - Type: protocol.PacketTypeInitial, - IsLongHeader: true, - ConnectionID: 0x1337, - PacketNumber: 0x55, - Version: 0x1234, + Type: protocol.PacketTypeInitial, + IsLongHeader: true, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0x55, + Version: 0x1234, } err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS) Expect(err).ToNot(HaveOccurred()) @@ -573,7 +580,8 @@ var _ = Describe("Server", func() { packet, err := wire.ParseHeaderSentByServer(r, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(packet.IsVersionNegotiation).To(BeTrue()) - Expect(packet.ConnectionID).To(Equal(protocol.ConnectionID(0x1337))) + Expect(packet.DestConnectionID).To(Equal(connID)) + Expect(packet.SrcConnectionID).To(Equal(connID)) Expect(r.Len()).To(BeZero()) Consistently(done).ShouldNot(BeClosed()) // make the go routine return @@ -582,13 +590,15 @@ var _ = Describe("Server", func() { }) It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() { + connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} b := &bytes.Buffer{} hdr := wire.Header{ - Type: protocol.PacketTypeInitial, - IsLongHeader: true, - ConnectionID: 0x1337, - PacketNumber: 0x55, - Version: protocol.VersionTLS, + Type: protocol.PacketTypeInitial, + IsLongHeader: true, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 0x55, + Version: protocol.VersionTLS, } err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS) Expect(err).ToNot(HaveOccurred()) diff --git a/server_tls.go b/server_tls.go index 9f387409..011d520a 100644 --- a/server_tls.go +++ b/server_tls.go @@ -87,6 +87,7 @@ func newServerTLS( } func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) { + // TODO: add a check that DestConnID == SrcConnID s.logger.Debugf("Received a Packet. Handling it statelessly.") sess, err := s.handleInitialImpl(remoteAddr, hdr, data) if err != nil { @@ -97,7 +98,7 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data [] return } s.sessionChan <- tlsSession{ - connID: hdr.ConnectionID, + connID: hdr.DestConnectionID, sess: sess, } } @@ -116,11 +117,12 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea ReasonPhrase: closeErr.Error(), } replyHdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - ConnectionID: clientHdr.ConnectionID, // echo the client's connection ID - PacketNumber: 1, // random packet number - Version: clientHdr.Version, + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + SrcConnectionID: clientHdr.SrcConnectionID, + DestConnectionID: clientHdr.DestConnectionID, + PacketNumber: 1, // random packet number + Version: clientHdr.Version, } data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer, s.logger) if err != nil { @@ -137,12 +139,16 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat // check version, if not matching send VNP if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) { s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) - _, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, s.supportedVersions), remoteAddr) + vnp, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.supportedVersions) + if err != nil { + return nil, err + } + _, err = s.conn.WriteTo(vnp, remoteAddr) return nil, err } // unpack packet and check stream frame contents - aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, hdr.Version) + aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.DestConnectionID, protocol.VersionTLS) if err != nil { return nil, err } @@ -170,15 +176,17 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, return nil, err } alert := tls.Handshake() + s.logger.Debugf("%#v\n", hdr) if alert == mint.AlertStatelessRetry { // the HelloRetryRequest was written to the bufferConn // Take that data and write send a Retry packet replyHdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - ConnectionID: hdr.ConnectionID, // echo the client's connection ID - PacketNumber: hdr.PacketNumber, // echo the client's packet number - Version: version, + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + DestConnectionID: hdr.DestConnectionID, + SrcConnectionID: hdr.SrcConnectionID, + PacketNumber: hdr.PacketNumber, // echo the client's packet number + Version: version, } f := &wire.StreamFrame{ StreamID: version.CryptoStreamID(), @@ -206,7 +214,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, params := <-paramsChan sess, err := newTLSServerSession( &conn{pconn: s.conn, currentAddr: remoteAddr}, - hdr.ConnectionID, // TODO: we can use a server-chosen connection ID here + hdr.DestConnectionID, // TODO(#1003): we can use a server-chosen connection ID here protocol.PacketNumber(1), // TODO: use a random packet number here s.config, tls, diff --git a/server_tls_test.go b/server_tls_test.go index 633dcf9a..d38e56a0 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -46,16 +46,19 @@ var _ = Describe("Stateless TLS handling", func() { }) getPacket := func(f wire.Frame) (*wire.Header, []byte) { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} hdrBuf := &bytes.Buffer{} hdr := &wire.Header{ - IsLongHeader: true, - PacketNumber: 1, - Version: protocol.VersionTLS, + IsLongHeader: true, + DestConnectionID: connID, + SrcConnectionID: connID, + PacketNumber: 1, + Version: protocol.VersionTLS, } err := hdr.Write(hdrBuf, protocol.PerspectiveClient, protocol.VersionTLS) Expect(err).ToNot(HaveOccurred()) hdr.Raw = hdrBuf.Bytes() - aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, 0, protocol.VersionTLS) + aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, protocol.VersionTLS) Expect(err).ToNot(HaveOccurred()) buf := &bytes.Buffer{} err = f.Write(buf, protocol.VersionTLS) @@ -72,7 +75,7 @@ var _ = Describe("Stateless TLS handling", func() { hdr, err := wire.ParseHeaderSentByServer(r, protocol.VersionTLS) Expect(err).ToNot(HaveOccurred()) hdr.Raw = data[:len(data)-r.Len()] - aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.ConnectionID, protocol.VersionTLS) + aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.DestConnectionID, protocol.VersionTLS) Expect(err).ToNot(HaveOccurred()) payload, err := aead.Open(nil, data[len(data)-r.Len():], hdr.PacketNumber, hdr.Raw) Expect(err).ToNot(HaveOccurred()) @@ -80,7 +83,12 @@ var _ = Describe("Stateless TLS handling", func() { } It("sends a version negotiation packet if it doesn't support the version", func() { - server.HandleInitial(nil, &wire.Header{Version: 0x1337}, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) + hdr := &wire.Header{ + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Version: 0x1337, + } + server.HandleInitial(nil, hdr, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) Expect(conn.dataWritten.Len()).ToNot(BeZero()) hdr, err := wire.ParseHeaderSentByServer(bytes.NewReader(conn.dataWritten.Bytes()), protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) diff --git a/session.go b/session.go index ed73637c..11b5d54c 100644 --- a/session.go +++ b/session.go @@ -598,9 +598,9 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) if s.logger.Debug() { if err != nil { - s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) + s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.DestConnectionID) } else { - s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID, packet.encryptionLevel) + s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s, %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.DestConnectionID, packet.encryptionLevel) } hdr.Log(s.logger) } @@ -800,7 +800,7 @@ func (s *session) handleCloseError(closeErr closeError) error { } // Don't log 'normal' reasons if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout { - s.logger.Infof("Closing connection %x", s.connectionID) + s.logger.Infof("Closing connection %s", s.connectionID) } else { s.logger.Errorf("Closing session with error: %s", closeErr.err.Error()) } @@ -1031,7 +1031,7 @@ func (s *session) logPacket(packet *packedPacket) { // We don't need to allocate the slices for calling the format functions return } - s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", packet.header.PacketNumber, len(packet.raw), s.connectionID, packet.encryptionLevel) + s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.connectionID, packet.encryptionLevel) packet.header.Log(s.logger) for _, frame := range packet.frames { wire.LogFrame(s.logger, frame, true) @@ -1117,7 +1117,7 @@ func (s *session) newCryptoStream() cryptoStreamI { } func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { - s.logger.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber) + s.logger.Infof("Sending public reset for connection %s, packet number %d", s.connectionID, rejectedPacketNumber) return s.conn.Write(wire.WritePublicReset(s.connectionID, rejectedPacketNumber, 0)) } diff --git a/session_test.go b/session_test.go index fd295660..6e66e2d2 100644 --- a/session_test.go +++ b/session_test.go @@ -107,7 +107,7 @@ var _ = Describe("Session", func() { pSess, err = newSession( mconn, protocol.Version39, - 0, + protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, scfg, nil, populateServerConfig(&Config{}), @@ -160,7 +160,7 @@ var _ = Describe("Session", func() { pSess, err := newSession( mconn, protocol.Version39, - 0, + protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, scfg, nil, conf, @@ -1732,7 +1732,7 @@ var _ = Describe("Client Session", func() { mconn, "hostname", protocol.Version39, - 0, + protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, nil, populateClientConfig(&Config{}), protocol.VersionWhatever,