diff --git a/client.go b/client.go index dc962c83..7101a67f 100644 --- a/client.go +++ b/client.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" @@ -23,14 +24,18 @@ type client struct { hostname string versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version - versionNegotiated bool // has version negotiation completed yet + versionNegotiated bool // has the server accepted our version receivedVersionNegotiationPacket bool + negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet tlsConf *tls.Config config *Config + tls handshake.MintTLS // only used when using TLS connectionID protocol.ConnectionID - version protocol.VersionNumber + + initialVersion protocol.VersionNumber + version protocol.VersionNumber session packetHandler } @@ -91,7 +96,6 @@ func DialNonFWSecure( if tlsConf != nil { hostname = tlsConf.ServerName } - if hostname == "" { hostname, _, err = net.SplitHostPort(host) if err != nil { @@ -111,8 +115,9 @@ func DialNonFWSecure( } utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) + go c.listen() - if err := c.establishSecureConnection(); err != nil { + if err := c.dial(); err != nil { return nil, err } return c.session.(NonFWSession), nil @@ -177,25 +182,79 @@ func populateClientConfig(config *Config) *Config { } } -// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure) -func (c *client) establishSecureConnection() error { - if err := c.createNewSession(c.version, nil); err != nil { +func (c *client) dial() error { + var err error + if c.version.UsesTLS() { + err = c.dialTLS() + } else { + err = c.dialGQUIC() + } + if err == errCloseSessionForNewVersion { + return c.dial() + } + return err +} + +func (c *client) dialGQUIC() error { + if err := c.createNewGQUICSession(); err != nil { return err } - go c.listen() + return c.establishSecureConnection() +} +func (c *client) dialTLS() error { + csc := handshake.NewCryptoStreamConn(nil) + mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient) + if err != nil { + return err + } + mintConf.ServerName = c.hostname + c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient) + params := &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + MaxStreams: protocol.MaxIncomingStreams, + IdleTimeout: c.config.IdleTimeout, + OmitConnectionID: c.config.RequestConnectionIDOmission, + } + eh := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version) + if err := c.tls.SetExtensionHandler(eh); err != nil { + return err + } + if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil { + return err + } + if err := c.establishSecureConnection(); err != nil { + if err != handshake.ErrCloseSessionForRetry { + return err + } + utils.Infof("Received a Retry packet. Recreating session.") + if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil { + return err + } + if err := c.establishSecureConnection(); err != nil { + return err + } + } + return nil +} + +// establishSecureConnection runs the session, and tries to establish a secure connection +// It returns: +// - errCloseSessionForNewVersion when the server sends a version negotiation packet +// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC) +// - any other error that might occur +// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC) +func (c *client) establishSecureConnection() error { var runErr error errorChan := make(chan struct{}) go func() { - // session.run() returns as soon as the session is closed - runErr = c.session.run() - if runErr == errCloseSessionForNewVersion { - // run the new session - runErr = c.session.run() - } + runErr = c.session.run() // returns as soon as the session is closed close(errorChan) utils.Infof("Connection %x closed.", c.connectionID) - c.conn.Close() + if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion { + c.conn.Close() + } }() // wait until the server accepts the QUIC version (or an error occurs) @@ -219,7 +278,8 @@ func (c *client) establishSecureConnection() error { } } -// Listen listens +// Listen listens on the underlying connection and passes packets on for handling. +// It returns when the connection is closed. func (c *client) listen() { var err error @@ -233,13 +293,15 @@ func (c *client) listen() { n, addr, err = c.conn.Read(data) if err != nil { if !strings.HasSuffix(err.Error(), "use of closed network connection") { - c.session.Close(err) + c.mutex.Lock() + if c.session != nil { + c.session.Close(err) + } + c.mutex.Unlock() } break } - data = data[:n] - - c.handlePacket(addr, data) + c.handlePacket(addr, data[:n]) } } @@ -257,15 +319,16 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission { return } - // reject packets with the wrong connection ID - if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID { - return - } hdr.Raw = packet[:len(packet)-r.Len()] c.mutex.Lock() defer c.mutex.Unlock() + // reject packets with the wrong connection ID + if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID { + return + } + if hdr.ResetFlag { cr := c.conn.RemoteAddr() // check if the remote address and the connection ID match @@ -305,6 +368,8 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { close(c.versionNegotiationChan) } + // TODO: validate packet number and connection ID on Retry packets (for IETF QUIC) + c.session.handlePacket(&receivedPacket{ remoteAddr: remoteAddr, header: hdr, @@ -323,15 +388,15 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { } } - c.receivedVersionNegotiationPacket = true - newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) if !ok { return qerr.InvalidVersion } + c.receivedVersionNegotiationPacket = true + c.negotiatedVersions = hdr.SupportedVersions // switch to negotiated version - initialVersion := c.version + c.initialVersion = c.version c.version = newVersion var err error c.connectionID, err = utils.GenerateConnectionID() @@ -339,17 +404,13 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { return err } utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) - - // create a new session and close the old one - // the new session must be created first to update client member variables - oldSession := c.session - defer oldSession.Close(errCloseSessionForNewVersion) - return c.createNewSession(initialVersion, hdr.SupportedVersions) + c.session.Close(errCloseSessionForNewVersion) + return nil } -func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error { - var err error - utils.Debugf("createNewSession with initial version %s", initialVersion) +func (c *client) createNewGQUICSession() (err error) { + c.mutex.Lock() + defer c.mutex.Unlock() c.session, err = newClientSession( c.conn, c.hostname, @@ -357,8 +418,27 @@ func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotia c.connectionID, c.tlsConf, c.config, - initialVersion, - negotiatedVersions, + c.initialVersion, + c.negotiatedVersions, + ) + return err +} + +func (c *client) createNewTLSSession( + paramsChan <-chan handshake.TransportParameters, + version protocol.VersionNumber, +) (err error) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.session, err = newTLSClientSession( + c.conn, + c.hostname, + c.version, + c.connectionID, + c.config, + c.tls, + paramsChan, + 1, ) return err } diff --git a/client_test.go b/client_test.go index 0b2438fa..ac7ae6dd 100644 --- a/client_test.go +++ b/client_test.go @@ -8,6 +8,7 @@ import ( "sync/atomic" "time" + "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" @@ -45,10 +46,9 @@ var _ = Describe("Client", func() { msess, _ := newMockSession(nil, 0, 0, nil, nil, nil) sess = msess.(*mockSession) addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - packetConn = &mockPacketConn{ - addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}, - dataReadFrom: addr, - } + packetConn = newMockPacketConn() + packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} + packetConn.dataReadFrom = addr config = &Config{ Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78}, } @@ -87,7 +87,7 @@ var _ = Describe("Client", func() { _ protocol.VersionNumber, _ []protocol.VersionNumber, ) (packetHandler, error) { - Expect(conn.Write([]byte("fake CHLO"))).To(Succeed()) + Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed()) return sess, nil } origGenerateConnectionID = generateConnectionID @@ -101,7 +101,7 @@ var _ = Describe("Client", func() { }) It("dials non-forward-secure", func() { - packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID) + packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID) dialed := make(chan struct{}) go func() { defer GinkgoRecover() @@ -151,7 +151,7 @@ var _ = Describe("Client", func() { }) It("Dial only returns after the handshake is complete", func() { - packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID) + packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID) dialed := make(chan struct{}) go func() { defer GinkgoRecover() @@ -237,7 +237,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs while waiting for the connection to become secure", func() { testErr := errors.New("early handshake error") - packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID) + packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -251,7 +251,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs while waiting for the handshake to complete", func() { testErr := errors.New("late handshake error") - packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID) + packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -330,10 +330,6 @@ var _ = Describe("Client", func() { newVersion := protocol.VersionNumber(77) Expect(newVersion).ToNot(Equal(cl.version)) Expect(config.Versions).To(ContainElement(newVersion)) - packetConn.dataToRead = wire.ComposeGQUICVersionNegotiation( - cl.connectionID, - []protocol.VersionNumber{newVersion}, - ) sessionChan := make(chan *mockSession) handshakeChan := make(chan handshakeEvent) newClientSession = func( @@ -348,10 +344,7 @@ var _ = Describe("Client", func() { ) (packetHandler, error) { initialVersion = initialVersionP negotiatedVersions = negotiatedVersionsP - // make the server accept the new version - if len(negotiatedVersionsP) > 0 { - packetConn.dataToRead = acceptClientVersionPacket(connectionID) - } + sess := &mockSession{ connectionID: connectionID, stopRunLoop: make(chan struct{}), @@ -364,18 +357,26 @@ var _ = Describe("Client", func() { established := make(chan struct{}) go func() { defer GinkgoRecover() - err := cl.establishSecureConnection() + err := cl.dial() Expect(err).ToNot(HaveOccurred()) close(established) }() + go cl.listen() + actualInitialVersion := cl.version var firstSession, secondSession *mockSession Eventually(sessionChan).Should(Receive(&firstSession)) - Eventually(sessionChan).Should(Receive(&secondSession)) + packetConn.dataToRead <- wire.ComposeGQUICVersionNegotiation( + cl.connectionID, + []protocol.VersionNumber{newVersion}, + ) // it didn't pass the version negoation packet to the old session (since it has no payload) - Expect(firstSession.packetCount).To(BeZero()) Eventually(func() bool { return firstSession.closed }).Should(BeTrue()) Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion)) + Expect(firstSession.packetCount).To(BeZero()) + Eventually(sessionChan).Should(Receive(&secondSession)) + // make the server accept the new version + packetConn.dataToRead <- acceptClientVersionPacket(secondSession.connectionID) Consistently(func() bool { return secondSession.closed }).Should(BeFalse()) Expect(cl.connectionID).ToNot(BeEquivalentTo(0x1337)) Expect(negotiatedVersions).To(ContainElement(newVersion)) @@ -398,20 +399,23 @@ var _ = Describe("Client", func() { _ []protocol.VersionNumber, ) (packetHandler, error) { atomic.AddUint32(&sessionCounter, 1) - return sess, nil + return &mockSession{ + connectionID: connectionID, + stopRunLoop: make(chan struct{}), + }, nil } - go cl.establishSecureConnection() + go cl.dial() Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1)) newVersion := protocol.VersionNumber(77) Expect(newVersion).ToNot(Equal(cl.version)) Expect(config.Versions).To(ContainElement(newVersion)) cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) - Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2)) + Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2)) newVersion = protocol.VersionNumber(78) Expect(newVersion).ToNot(Equal(cl.version)) Expect(config.Versions).To(ContainElement(newVersion)) cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) - Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2)) + Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2)) }) It("errors if no matching version is found", func() { @@ -482,7 +486,7 @@ var _ = Describe("Client", func() { Expect(sess.closed).To(BeFalse()) }) - It("creates new sessions with the right parameters", func() { + It("creates new GQUIC sessions with the right parameters", func() { closeErr := errors.New("peer doesn't reply") c := make(chan struct{}) var cconn connection @@ -516,12 +520,84 @@ var _ = Describe("Client", func() { Eventually(c).Should(BeClosed()) Expect(cconn.(*conn).pconn).To(Equal(packetConn)) Expect(hostname).To(Equal("quic.clemente.io")) - Expect(version).To(Equal(cl.version)) + Expect(version).To(Equal(config.Versions[0])) Expect(conf.Versions).To(Equal(config.Versions)) sess.Close(closeErr) Eventually(dialed).Should(BeClosed()) }) + It("creates new TLS sessions with the right parameters", func() { + config.Versions = []protocol.VersionNumber{protocol.VersionTLS} + c := make(chan struct{}) + var cconn connection + var hostname string + var version protocol.VersionNumber + var conf *Config + newTLSClientSession = func( + connP connection, + hostnameP string, + versionP protocol.VersionNumber, + _ protocol.ConnectionID, + configP *Config, + tls handshake.MintTLS, + paramsChan <-chan handshake.TransportParameters, + _ protocol.PacketNumber, + ) (packetHandler, error) { + cconn = connP + hostname = hostnameP + version = versionP + conf = configP + close(c) + return sess, nil + } + dialed := make(chan struct{}) + go func() { + defer GinkgoRecover() + Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + close(dialed) + }() + Eventually(c).Should(BeClosed()) + Expect(cconn.(*conn).pconn).To(Equal(packetConn)) + Expect(hostname).To(Equal("quic.clemente.io")) + Expect(version).To(Equal(config.Versions[0])) + Expect(conf.Versions).To(Equal(config.Versions)) + sess.Close(errors.New("peer doesn't reply")) + Eventually(dialed).Should(BeClosed()) + }) + + It("creates a new session when the server performs a retry", func() { + config.Versions = []protocol.VersionNumber{protocol.VersionTLS} + sessionChan := make(chan *mockSession) + newTLSClientSession = func( + connP connection, + hostnameP string, + versionP protocol.VersionNumber, + _ protocol.ConnectionID, + configP *Config, + tls handshake.MintTLS, + paramsChan <-chan handshake.TransportParameters, + _ protocol.PacketNumber, + ) (packetHandler, error) { + sess := &mockSession{ + stopRunLoop: make(chan struct{}), + } + sessionChan <- sess + return sess, nil + } + dialed := make(chan struct{}) + go func() { + defer GinkgoRecover() + Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + close(dialed) + }() + var firstSession, secondSession *mockSession + Eventually(sessionChan).Should(Receive(&firstSession)) + firstSession.Close(handshake.ErrCloseSessionForRetry) + Eventually(sessionChan).Should(Receive(&secondSession)) + secondSession.Close(errors.New("stop test")) + Eventually(dialed).Should(BeClosed()) + }) + Context("handling packets", func() { It("handles packets", func() { ph := wire.Header{ @@ -532,7 +608,7 @@ var _ = Describe("Client", func() { b := &bytes.Buffer{} err := ph.Write(b, protocol.PerspectiveServer, cl.version) Expect(err).ToNot(HaveOccurred()) - packetConn.dataToRead = b.Bytes() + packetConn.dataToRead <- b.Bytes() Expect(sess.packetCount).To(BeZero()) stoppedListening := make(chan struct{}) diff --git a/conn_test.go b/conn_test.go index 764ca5fd..f7e00cac 100644 --- a/conn_test.go +++ b/conn_test.go @@ -2,7 +2,7 @@ package quic import ( "bytes" - "io" + "errors" "net" "time" @@ -12,7 +12,7 @@ import ( type mockPacketConn struct { addr net.Addr - dataToRead []byte + dataToRead chan []byte dataReadFrom net.Addr readErr error dataWritten bytes.Buffer @@ -20,23 +20,34 @@ type mockPacketConn struct { closed bool } +func newMockPacketConn() *mockPacketConn { + return &mockPacketConn{ + dataToRead: make(chan []byte, 1000), + } +} + func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { if c.readErr != nil { return 0, nil, c.readErr } - if c.dataToRead == nil { // block if there's no data - time.Sleep(time.Hour) - return 0, nil, io.EOF + data, ok := <-c.dataToRead + if !ok { + return 0, nil, errors.New("connection closed") } - n := copy(b, c.dataToRead) - c.dataToRead = nil + n := copy(b, data) return n, c.dataReadFrom, nil } func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { c.dataWrittenTo = addr return c.dataWritten.Write(b) } -func (c *mockPacketConn) Close() error { c.closed = true; return nil } +func (c *mockPacketConn) Close() error { + if !c.closed { + close(c.dataToRead) + } + c.closed = true + return nil +} func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr } func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") } func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") } @@ -53,7 +64,7 @@ var _ = Describe("Connection", func() { IP: net.IPv4(192, 168, 100, 200), Port: 1337, } - packetConn = &mockPacketConn{} + packetConn = newMockPacketConn() c = &conn{ currentAddr: addr, pconn: packetConn, @@ -68,7 +79,7 @@ var _ = Describe("Connection", func() { }) It("reads", func() { - packetConn.dataToRead = []byte("foo") + packetConn.dataToRead <- []byte("foo") packetConn.dataReadFrom = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1336} p := make([]byte, 10) n, raddr, err := c.Read(p) diff --git a/internal/handshake/cookie_handler.go b/internal/handshake/cookie_handler.go index 317f6e50..1d3052c4 100644 --- a/internal/handshake/cookie_handler.go +++ b/internal/handshake/cookie_handler.go @@ -7,33 +7,33 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -type cookieHandler struct { +type CookieHandler struct { callback func(net.Addr, *Cookie) bool cookieGenerator *CookieGenerator } -var _ mint.CookieHandler = &cookieHandler{} +var _ mint.CookieHandler = &CookieHandler{} -func newCookieHandler(callback func(net.Addr, *Cookie) bool) (*cookieHandler, error) { +func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) { cookieGenerator, err := NewCookieGenerator() if err != nil { return nil, err } - return &cookieHandler{ + return &CookieHandler{ callback: callback, cookieGenerator: cookieGenerator, }, nil } -func (h *cookieHandler) Generate(conn *mint.Conn) ([]byte, error) { +func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) { if h.callback(conn.RemoteAddr(), nil) { return nil, nil } return h.cookieGenerator.NewToken(conn.RemoteAddr()) } -func (h *cookieHandler) Validate(conn *mint.Conn, token []byte) bool { +func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool { data, err := h.cookieGenerator.DecodeToken(token) if err != nil { utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error()) diff --git a/internal/handshake/cookie_handler_test.go b/internal/handshake/cookie_handler_test.go index 9c95ea61..16b9207e 100644 --- a/internal/handshake/cookie_handler_test.go +++ b/internal/handshake/cookie_handler_test.go @@ -2,6 +2,7 @@ package handshake import ( "net" + "time" "github.com/bifurcation/mint" @@ -9,22 +10,37 @@ import ( . "github.com/onsi/gomega" ) +type mockConn struct { + remoteAddr net.Addr +} + +var _ net.Conn = &mockConn{} + +func (c *mockConn) Read([]byte) (int, error) { panic("not implemented") } +func (c *mockConn) Write([]byte) (int, error) { panic("not implemented") } +func (c *mockConn) Close() error { panic("not implemented") } +func (c *mockConn) LocalAddr() net.Addr { panic("not implemented") } +func (c *mockConn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *mockConn) SetReadDeadline(time.Time) error { panic("not implemented") } +func (c *mockConn) SetWriteDeadline(time.Time) error { panic("not implemented") } +func (c *mockConn) SetDeadline(time.Time) error { panic("not implemented") } + var callbackReturn bool var mockCallback = func(net.Addr, *Cookie) bool { return callbackReturn } var _ = Describe("Cookie Handler", func() { - var ch *cookieHandler + var ch *CookieHandler var conn *mint.Conn BeforeEach(func() { callbackReturn = false var err error - ch, err = newCookieHandler(mockCallback) + ch, err = NewCookieHandler(mockCallback) Expect(err).ToNot(HaveOccurred()) addr := &net.UDPAddr{IP: net.IPv4(42, 43, 44, 45), Port: 46} - conn = mint.NewConn(&fakeConn{remoteAddr: addr}, &mint.Config{}, false) + conn = mint.NewConn(&mockConn{remoteAddr: addr}, &mint.Config{}, false) }) It("generates and validates a token", func() { diff --git a/internal/handshake/crypto_setup_client.go b/internal/handshake/crypto_setup_client.go index 78ec12b9..921c005f 100644 --- a/internal/handshake/crypto_setup_client.go +++ b/internal/handshake/crypto_setup_client.go @@ -381,10 +381,6 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) { h.divNonceChan <- data } -func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType { - panic("not needed for cryptoSetupServer") -} - func (h *cryptoSetupClient) sendCHLO() error { h.clientHelloCounter++ if h.clientHelloCounter > protocol.MaxClientHellos { diff --git a/internal/handshake/crypto_setup_server.go b/internal/handshake/crypto_setup_server.go index ddaf75e9..5df110ff 100644 --- a/internal/handshake/crypto_setup_server.go +++ b/internal/handshake/crypto_setup_server.go @@ -458,10 +458,6 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) { panic("not needed for cryptoSetupServer") } -func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType { - panic("not needed for cryptoSetupServer") -} - func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error { if len(nonce) != 32 { return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length") diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index abaee16b..041c0b42 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -1,10 +1,9 @@ package handshake import ( - "crypto/tls" + "errors" "fmt" "io" - "net" "sync" "github.com/bifurcation/mint" @@ -12,6 +11,9 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" ) +// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry +var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry") + // KeyDerivationFunction is used for key derivation type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) @@ -20,68 +22,31 @@ type cryptoSetupTLS struct { perspective protocol.Perspective - tls mintTLS - conn *fakeConn - - nextPacketType protocol.PacketType - keyDerivation KeyDerivationFunction nullAEAD crypto.AEAD aead crypto.AEAD - aeadChanged chan<- protocol.EncryptionLevel + tls MintTLS + cryptoStream *CryptoStreamConn + aeadChanged chan<- protocol.EncryptionLevel } // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server func NewCryptoSetupTLSServer( - cryptoStream io.ReadWriter, - connID protocol.ConnectionID, - tlsConfig *tls.Config, - remoteAddr net.Addr, - params *TransportParameters, - paramsChan chan<- TransportParameters, + tls MintTLS, + cryptoStream *CryptoStreamConn, + nullAEAD crypto.AEAD, aeadChanged chan<- protocol.EncryptionLevel, - checkCookie func(net.Addr, *Cookie) bool, - supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, -) (CryptoSetup, error) { - mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer) - if err != nil { - return nil, err - } - mintConf.RequireCookie = true - mintConf.CookieHandler, err = newCookieHandler(checkCookie) - if err != nil { - return nil, err - } - mintConf.CookieProtector, err = mint.NewDefaultCookieProtector() - if err != nil { - return nil, err - } - conn := &fakeConn{ - stream: cryptoStream, - pers: protocol.PerspectiveServer, - remoteAddr: remoteAddr, - } - mintConn := mint.Server(conn, mintConf) - eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version) - if err := mintConn.SetExtensionHandler(eh); err != nil { - return nil, err - } - - nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) - if err != nil { - return nil, err - } - +) CryptoSetup { return &cryptoSetupTLS{ - perspective: protocol.PerspectiveServer, - tls: &mintController{mintConn}, - conn: conn, + tls: tls, + cryptoStream: cryptoStream, nullAEAD: nullAEAD, + perspective: protocol.PerspectiveServer, keyDerivation: crypto.DeriveAESKeys, aeadChanged: aeadChanged, - }, nil + } } // NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client @@ -89,60 +54,44 @@ func NewCryptoSetupTLSClient( cryptoStream io.ReadWriter, connID protocol.ConnectionID, hostname string, - tlsConfig *tls.Config, - params *TransportParameters, - paramsChan chan<- TransportParameters, aeadChanged chan<- protocol.EncryptionLevel, - initialVersion protocol.VersionNumber, - supportedVersions []protocol.VersionNumber, + tls MintTLS, version protocol.VersionNumber, ) (CryptoSetup, error) { - mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient) - if err != nil { - return nil, err - } - mintConf.ServerName = hostname - conn := &fakeConn{ - stream: cryptoStream, - pers: protocol.PerspectiveClient, - } - mintConn := mint.Client(conn, mintConf) - eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version) - if err := mintConn.SetExtensionHandler(eh); err != nil { - return nil, err - } - nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) if err != nil { return nil, err } return &cryptoSetupTLS{ - conn: conn, - perspective: protocol.PerspectiveClient, - tls: &mintController{mintConn}, - nullAEAD: nullAEAD, - keyDerivation: crypto.DeriveAESKeys, - aeadChanged: aeadChanged, - nextPacketType: protocol.PacketTypeInitial, + perspective: protocol.PerspectiveClient, + tls: tls, + nullAEAD: nullAEAD, + keyDerivation: crypto.DeriveAESKeys, + aeadChanged: aeadChanged, }, nil } func (h *cryptoSetupTLS) HandleCryptoStream() error { + if h.perspective == protocol.PerspectiveServer { + // mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer + // send out that data now + if _, err := h.cryptoStream.Flush(); err != nil { + return err + } + } + handshakeLoop: for { - switch alert := h.tls.Handshake(); alert { - case mint.AlertStatelessRetry: - case mint.AlertNoAlert: // handshake complete - break handshakeLoop - case mint.AlertWouldBlock: - h.determineNextPacketType() - if err := h.conn.Continue(); err != nil { - return err - } - default: + if alert := h.tls.Handshake(); alert != mint.AlertNoAlert { return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) } + switch h.tls.State() { + case mint.StateClientStart: // this happens if a stateless retry is performed + return ErrCloseSessionForRetry + case mint.StateClientConnected, mint.StateServerConnected: + break handshakeLoop + } } aead, err := h.keyDerivation(h.tls, h.perspective) @@ -209,35 +158,6 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S return protocol.EncryptionUnencrypted, h.nullAEAD } -func (h *cryptoSetupTLS) determineNextPacketType() error { - h.mutex.Lock() - defer h.mutex.Unlock() - state := h.tls.State().HandshakeState - if h.perspective == protocol.PerspectiveServer { - switch state { - case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest - h.nextPacketType = protocol.PacketTypeRetry - case "ServerStateWaitFinished": - h.nextPacketType = protocol.PacketTypeHandshake - default: - // TODO: accept 0-RTT data - return fmt.Errorf("Unexpected handshake state: %s", state) - } - return nil - } - // client - if state != "ClientStateWaitSH" { - h.nextPacketType = protocol.PacketTypeHandshake - } - return nil -} - -func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.nextPacketType -} - func (h *cryptoSetupTLS) DiversificationNonce() []byte { panic("diversification nonce not needed for TLS") } diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index d1ee39ba..03b486e6 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -1,7 +1,6 @@ package handshake import ( - "bytes" "errors" "fmt" @@ -10,7 +9,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/mocks/crypto" "github.com/lucas-clemente/quic-go/internal/mocks/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/testdata" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -23,52 +21,33 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e var _ = Describe("TLS Crypto Setup", func() { var ( cs *cryptoSetupTLS - paramsChan chan TransportParameters aeadChanged chan protocol.EncryptionLevel ) BeforeEach(func() { - paramsChan = make(chan TransportParameters) aeadChanged = make(chan protocol.EncryptionLevel, 2) - csInt, err := NewCryptoSetupTLSServer( + cs = NewCryptoSetupTLSServer( nil, - 1, - testdata.GetTLSConfig(), - nil, - &TransportParameters{}, - paramsChan, + NewCryptoStreamConn(nil), + nil, // AEAD aeadChanged, - nil, - nil, protocol.VersionTLS, - ) - Expect(err).ToNot(HaveOccurred()) - cs = csInt.(*cryptoSetupTLS) + ).(*cryptoSetupTLS) cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl) }) It("errors when the handshake fails", func() { alert := mint.AlertBadRecordMAC - cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) - cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(alert) + cs.tls = mockhandshake.NewMockMintTLS(mockCtrl) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(alert) err := cs.HandleCryptoStream() Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert))) }) - It("continues shaking hands when mint says that it would block", func() { - cs.conn.stream = &bytes.Buffer{} - cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) - cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertWouldBlock) - cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{}) - cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) - cs.keyDerivation = mockKeyDerivation - err := cs.HandleCryptoStream() - Expect(err).ToNot(HaveOccurred()) - }) - It("derives keys", func() { - cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) - cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) + cs.tls = mockhandshake.NewMockMintTLS(mockCtrl) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected) cs.keyDerivation = mockKeyDerivation err := cs.HandleCryptoStream() Expect(err).ToNot(HaveOccurred()) @@ -76,64 +55,22 @@ var _ = Describe("TLS Crypto Setup", func() { Expect(aeadChanged).To(BeClosed()) }) - Context("determining the packet type", func() { - Context("for the client", func() { - var csClient *cryptoSetupTLS - - BeforeEach(func() { - csInt, err := NewCryptoSetupTLSClient( - nil, - 1, - "quic.clemente.io", - testdata.GetTLSConfig(), - &TransportParameters{}, - paramsChan, - aeadChanged, - protocol.VersionTLS, - []protocol.VersionNumber{protocol.VersionTLS}, - protocol.VersionTLS, - ) - Expect(err).ToNot(HaveOccurred()) - csClient = csInt.(*cryptoSetupTLS) - csClient.tls = mockhandshake.NewMockmintTLS(mockCtrl) - }) - - It("sends a Client Initial first", func() { - Expect(csClient.GetNextPacketType()).To(Equal(protocol.PacketTypeInitial)) - }) - - It("sends a Handshake packet after the server sent a Server Hello", func() { - csClient.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ClientStateWaitEE"}) - err := csClient.determineNextPacketType() - Expect(err).ToNot(HaveOccurred()) - }) - }) - - Context("for the server", func() { - BeforeEach(func() { - cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) - }) - - It("sends a Stateless Retry packet", func() { - cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateStart"}) - err := cs.determineNextPacketType() - Expect(err).ToNot(HaveOccurred()) - Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeRetry)) - }) - - It("sends Handshake packet", func() { - cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateWaitFinished"}) - err := cs.determineNextPacketType() - Expect(err).ToNot(HaveOccurred()) - Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeHandshake)) - }) - }) + It("handshakes until it is connected", func() { + cs.tls = mockhandshake.NewMockMintTLS(mockCtrl) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert).Times(10) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerNegotiated).Times(9) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected) + cs.keyDerivation = mockKeyDerivation + err := cs.HandleCryptoStream() + Expect(err).ToNot(HaveOccurred()) + Expect(aeadChanged).To(Receive()) }) Context("escalating crypto", func() { doHandshake := func() { - cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) - cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) + cs.tls = mockhandshake.NewMockMintTLS(mockCtrl) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected) cs.keyDerivation = mockKeyDerivation err := cs.HandleCryptoStream() Expect(err).ToNot(HaveOccurred()) @@ -240,3 +177,33 @@ var _ = Describe("TLS Crypto Setup", func() { }) }) }) + +var _ = Describe("TLS Crypto Setup, for the client", func() { + var ( + cs *cryptoSetupTLS + aeadChanged chan protocol.EncryptionLevel + ) + + BeforeEach(func() { + aeadChanged = make(chan protocol.EncryptionLevel, 2) + csInt, err := NewCryptoSetupTLSClient( + nil, + 0, + "quic.clemente.io", + aeadChanged, + nil, // mintTLS + protocol.VersionTLS, + ) + Expect(err).ToNot(HaveOccurred()) + cs = csInt.(*cryptoSetupTLS) + }) + + It("returns when a retry is performed", func() { + cs.tls = mockhandshake.NewMockMintTLS(mockCtrl) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateClientStart) + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(ErrCloseSessionForRetry)) + }) + +}) diff --git a/internal/handshake/crypto_stream_conn.go b/internal/handshake/crypto_stream_conn.go new file mode 100644 index 00000000..03825c41 --- /dev/null +++ b/internal/handshake/crypto_stream_conn.go @@ -0,0 +1,101 @@ +package handshake + +import ( + "bytes" + "io" + "net" + "time" +) + +// The CryptoStreamConn is used as the net.Conn passed to mint. +// It has two operating modes: +// 1. It can read and write to bytes.Buffers. +// 2. It can use a quic.Stream for reading and writing. +// The buffer-mode is only used by the server, in order to statelessly handle retries. +type CryptoStreamConn struct { + remoteAddr net.Addr + + // the buffers are used before the session is initialized + readBuf bytes.Buffer + writeBuf bytes.Buffer + + // stream will be set once the session is initialized + stream io.ReadWriter +} + +var _ net.Conn = &CryptoStreamConn{} + +// NewCryptoStreamConn creates a new CryptoStreamConn +func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn { + return &CryptoStreamConn{remoteAddr: remoteAddr} +} + +func (c *CryptoStreamConn) Read(b []byte) (int, error) { + if c.stream != nil { + return c.stream.Read(b) + } + return c.readBuf.Read(b) +} + +// AddDataForReading adds data to the read buffer. +// This data will ONLY be read when the stream has not been set. +func (c *CryptoStreamConn) AddDataForReading(data []byte) { + c.readBuf.Write(data) +} + +func (c *CryptoStreamConn) Write(p []byte) (int, error) { + if c.stream != nil { + return c.stream.Write(p) + } + return c.writeBuf.Write(p) +} + +// GetDataForWriting returns all data currently in the write buffer, and resets this buffer. +func (c *CryptoStreamConn) GetDataForWriting() []byte { + defer c.writeBuf.Reset() + data := make([]byte, c.writeBuf.Len()) + copy(data, c.writeBuf.Bytes()) + return data +} + +// SetStream sets the stream. +// After setting the stream, the read and write buffer won't be used any more. +func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) { + c.stream = stream +} + +// Flush copies the contents of the write buffer to the stream +func (c *CryptoStreamConn) Flush() (int, error) { + n, err := io.Copy(c.stream, &c.writeBuf) + return int(n), err +} + +// Close is not implemented +func (c *CryptoStreamConn) Close() error { + return nil +} + +// LocalAddr is not implemented +func (c *CryptoStreamConn) LocalAddr() net.Addr { + return nil +} + +// RemoteAddr returns the remote address +func (c *CryptoStreamConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +// SetReadDeadline is not implemented +func (c *CryptoStreamConn) SetReadDeadline(time.Time) error { + return nil +} + +// SetWriteDeadline is not implemented +func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error { + return nil +} + +// SetDeadline is not implemented +func (c *CryptoStreamConn) SetDeadline(time.Time) error { + return nil +} diff --git a/internal/handshake/crypto_stream_conn_test.go b/internal/handshake/crypto_stream_conn_test.go new file mode 100644 index 00000000..c4aa1dba --- /dev/null +++ b/internal/handshake/crypto_stream_conn_test.go @@ -0,0 +1,67 @@ +package handshake + +import ( + "bytes" + "net" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("CryptoStreamConn", func() { + var ( + csc *CryptoStreamConn + remoteAddr net.Addr + ) + + BeforeEach(func() { + remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + csc = NewCryptoStreamConn(remoteAddr) + }) + + It("reads from the read buffer, when no stream is set", func() { + csc.AddDataForReading([]byte("foobar")) + data := make([]byte, 4) + n, err := csc.Read(data) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(data).To(Equal([]byte("foob"))) + }) + + It("writes to the write buffer, when no stream is set", func() { + csc.Write([]byte("foo")) + Expect(csc.GetDataForWriting()).To(Equal([]byte("foo"))) + csc.Write([]byte("bar")) + Expect(csc.GetDataForWriting()).To(Equal([]byte("bar"))) + }) + + It("reads from the stream, if available", func() { + csc.stream = &bytes.Buffer{} + csc.stream.Write([]byte("foobar")) + data := make([]byte, 3) + n, err := csc.Read(data) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(data).To(Equal([]byte("foo"))) + }) + + It("writes to the stream, if available", func() { + stream := &bytes.Buffer{} + csc.SetStream(stream) + csc.Write([]byte("foobar")) + Expect(stream.Bytes()).To(Equal([]byte("foobar"))) + }) + + It("returns the remote address", func() { + Expect(csc.RemoteAddr()).To(Equal(remoteAddr)) + }) + + It("has unimplemented methods", func() { + Expect(csc.Close()).ToNot(HaveOccurred()) + Expect(csc.SetDeadline(time.Time{})).ToNot(HaveOccurred()) + Expect(csc.SetReadDeadline(time.Time{})).ToNot(HaveOccurred()) + Expect(csc.SetWriteDeadline(time.Time{})).ToNot(HaveOccurred()) + Expect(csc.LocalAddr()).To(BeNil()) + }) +}) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index c34c8f1f..fbb70062 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -1,6 +1,10 @@ package handshake import ( + "io" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/protocol" ) @@ -10,14 +14,33 @@ type Sealer interface { Overhead() int } +// A TLSExtensionHandler sends and received the QUIC TLS extension. +// It provides the parameters sent by the peer on a channel. +type TLSExtensionHandler interface { + Send(mint.HandshakeType, *mint.ExtensionList) error + Receive(mint.HandshakeType, *mint.ExtensionList) error + GetPeerParams() <-chan TransportParameters +} + +// MintTLS combines some methods needed to interact with mint. +type MintTLS interface { + crypto.TLSExporter + + // additional methods + Handshake() mint.Alert + State() mint.State + + SetCryptoStream(io.ReadWriter) + SetExtensionHandler(mint.AppExtensionHandler) error +} + // CryptoSetup is a crypto setup type CryptoSetup interface { Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) HandleCryptoStream() error // TODO: clean up this interface - DiversificationNonce() []byte // only needed for cryptoSetupServer - SetDiversificationNonce([]byte) // only needed for cryptoSetupClient - GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer + DiversificationNonce() []byte // only needed for cryptoSetupServer + SetDiversificationNonce([]byte) // only needed for cryptoSetupClient GetSealer() (protocol.EncryptionLevel, Sealer) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) diff --git a/internal/handshake/mint_utils.go b/internal/handshake/mint_utils.go deleted file mode 100644 index 8c3a83bd..00000000 --- a/internal/handshake/mint_utils.go +++ /dev/null @@ -1,127 +0,0 @@ -package handshake - -import ( - "bytes" - gocrypto "crypto" - "crypto/tls" - "crypto/x509" - "io" - "net" - "time" - - "github.com/bifurcation/mint" - "github.com/lucas-clemente/quic-go/internal/crypto" - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) { - mconf := &mint.Config{ - NonBlocking: true, - CipherSuites: []mint.CipherSuite{ - mint.TLS_AES_128_GCM_SHA256, - mint.TLS_AES_256_GCM_SHA384, - }, - } - if tlsConf != nil { - mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates)) - for i, certChain := range tlsConf.Certificates { - mconf.Certificates[i] = &mint.Certificate{ - Chain: make([]*x509.Certificate, len(certChain.Certificate)), - PrivateKey: certChain.PrivateKey.(gocrypto.Signer), - } - for j, cert := range certChain.Certificate { - c, err := x509.ParseCertificate(cert) - if err != nil { - return nil, err - } - mconf.Certificates[i].Chain[j] = c - } - } - } - if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil { - return nil, err - } - return mconf, nil -} - -type mintTLS interface { - // These two methods are the same as the crypto.TLSExporter interface. - // Cannot use embedding here, because mockgen source mode refuses to generate mocks then. - GetCipherSuite() mint.CipherSuiteParams - ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) - // additional methods - Handshake() mint.Alert - State() mint.ConnectionState -} - -var _ crypto.TLSExporter = (mintTLS)(nil) - -type mintController struct { - conn *mint.Conn -} - -var _ mintTLS = &mintController{} - -func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams { - return mc.conn.State().CipherSuite -} - -func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { - return mc.conn.ComputeExporter(label, context, keyLength) -} - -func (mc *mintController) Handshake() mint.Alert { - return mc.conn.Handshake() -} - -func (mc *mintController) State() mint.ConnectionState { - return mc.conn.State() -} - -// mint expects a net.Conn, but we're doing the handshake on a stream -// so we wrap a stream such that implements a net.Conn -type fakeConn struct { - stream io.ReadWriter - pers protocol.Perspective - remoteAddr net.Addr - - blockRead bool - writeBuffer bytes.Buffer -} - -var _ net.Conn = &fakeConn{} - -func (c *fakeConn) Read(b []byte) (int, error) { - if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock - return 0, nil - } - c.blockRead = true // block the next Read call - return c.stream.Read(b) -} - -func (c *fakeConn) Write(p []byte) (int, error) { - if c.pers == protocol.PerspectiveClient { - return c.stream.Write(p) - } - // Buffer all writes by the server. - // Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data. - return c.writeBuffer.Write(p) -} - -func (c *fakeConn) Continue() error { - c.blockRead = false - if c.pers == protocol.PerspectiveClient { - return nil - } - // write all contents of the write buffer to the stream. - _, err := c.stream.Write(c.writeBuffer.Bytes()) - c.writeBuffer.Reset() - return err -} - -func (c *fakeConn) Close() error { return nil } -func (c *fakeConn) LocalAddr() net.Addr { return nil } -func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr } -func (c *fakeConn) SetReadDeadline(time.Time) error { return nil } -func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil } -func (c *fakeConn) SetDeadline(time.Time) error { return nil } diff --git a/internal/handshake/mint_utils_test.go b/internal/handshake/mint_utils_test.go deleted file mode 100644 index 1bb1858a..00000000 --- a/internal/handshake/mint_utils_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package handshake - -import ( - "bytes" - "net" - - "github.com/lucas-clemente/quic-go/internal/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Fake Conn", func() { - var ( - c *fakeConn - stream *bytes.Buffer - ) - - BeforeEach(func() { - stream = &bytes.Buffer{} - c = &fakeConn{stream: stream} - }) - - Context("Reading", func() { - It("doesn't return any new data after one Read call", func() { - stream.Write([]byte("foobar")) - b := make([]byte, 3) - _, err := c.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(b).To(Equal([]byte("foo"))) - n, err := c.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(BeZero()) - }) - - It("allows more Read calls after unblocking", func() { - stream.Write([]byte("foobar")) - b := make([]byte, 3) - _, err := c.Read(b) - Expect(err).ToNot(HaveOccurred()) - err = c.Continue() - Expect(err).ToNot(HaveOccurred()) - _, err = c.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(b).To(Equal([]byte("bar"))) - }) - }) - - Context("Writing", func() { - It("writes directly when acting as a client", func() { - c.pers = protocol.PerspectiveClient - _, err := c.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(stream.Bytes()).To(Equal([]byte("foobar"))) - }) - - It("only writes after flushing when acting as a server", func() { - c.pers = protocol.PerspectiveServer - _, err := c.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(stream.Bytes()).To(BeEmpty()) - err = c.Continue() - Expect(err).ToNot(HaveOccurred()) - Expect(stream.Bytes()).To(Equal([]byte("foobar"))) - }) - }) - - It("returns its remote address", func() { - addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - c.remoteAddr = addr - Expect(c.RemoteAddr()).To(Equal(addr)) - }) -}) diff --git a/internal/handshake/tls_extension_handler_client.go b/internal/handshake/tls_extension_handler_client.go index 506b0db2..f2dfa958 100644 --- a/internal/handshake/tls_extension_handler_client.go +++ b/internal/handshake/tls_extension_handler_client.go @@ -13,8 +13,8 @@ import ( ) type extensionHandlerClient struct { - params *TransportParameters - paramsChan chan<- TransportParameters + ourParams *TransportParameters + paramsChan chan TransportParameters initialVersion protocol.VersionNumber supportedVersions []protocol.VersionNumber @@ -22,16 +22,17 @@ type extensionHandlerClient struct { } var _ mint.AppExtensionHandler = &extensionHandlerClient{} +var _ TLSExtensionHandler = &extensionHandlerClient{} -func newExtensionHandlerClient( +func NewExtensionHandlerClient( params *TransportParameters, - paramsChan chan<- TransportParameters, initialVersion protocol.VersionNumber, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, -) *extensionHandlerClient { +) TLSExtensionHandler { + paramsChan := make(chan TransportParameters, 1) return &extensionHandlerClient{ - params: params, + ourParams: params, paramsChan: paramsChan, initialVersion: initialVersion, supportedVersions: supportedVersions, @@ -46,7 +47,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi data, err := syntax.Marshal(clientHelloTransportParameters{ InitialVersion: uint32(h.initialVersion), - Parameters: h.params.getTransportParameters(), + Parameters: h.ourParams.getTransportParameters(), }) if err != nil { return err @@ -123,3 +124,7 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte h.paramsChan <- *params return nil } + +func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters { + return h.paramsChan +} diff --git a/internal/handshake/tls_extension_handler_client_test.go b/internal/handshake/tls_extension_handler_client_test.go index 22ee1eb4..52822b85 100644 --- a/internal/handshake/tls_extension_handler_client_test.go +++ b/internal/handshake/tls_extension_handler_client_test.go @@ -13,15 +13,12 @@ import ( var _ = Describe("TLS Extension Handler, for the client", func() { var ( - handler *extensionHandlerClient - el mint.ExtensionList - paramsChan chan TransportParameters + handler *extensionHandlerClient + el mint.ExtensionList ) BeforeEach(func() { - // use a buffered channel here, so that we don't have to receive concurrently when parsing a message - paramsChan = make(chan TransportParameters, 1) - handler = newExtensionHandlerClient(&TransportParameters{}, paramsChan, protocol.VersionWhatever, nil, protocol.VersionWhatever) + handler = NewExtensionHandlerClient(&TransportParameters{}, protocol.VersionWhatever, nil, protocol.VersionWhatever).(*extensionHandlerClient) el = make(mint.ExtensionList, 0) }) @@ -81,7 +78,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() { err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) Expect(err).ToNot(HaveOccurred()) var params TransportParameters - Expect(paramsChan).To(Receive(¶ms)) + Expect(handler.GetPeerParams()).To(Receive(¶ms)) Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344)) }) diff --git a/internal/handshake/tls_extension_handler_server.go b/internal/handshake/tls_extension_handler_server.go index d8fda2a5..941f5115 100644 --- a/internal/handshake/tls_extension_handler_server.go +++ b/internal/handshake/tls_extension_handler_server.go @@ -14,26 +14,27 @@ import ( ) type extensionHandlerServer struct { - params *TransportParameters - paramsChan chan<- TransportParameters + ourParams *TransportParameters + paramsChan chan TransportParameters version protocol.VersionNumber supportedVersions []protocol.VersionNumber } var _ mint.AppExtensionHandler = &extensionHandlerServer{} +var _ TLSExtensionHandler = &extensionHandlerServer{} -func newExtensionHandlerServer( +func NewExtensionHandlerServer( params *TransportParameters, - paramsChan chan<- TransportParameters, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, -) *extensionHandlerServer { +) TLSExtensionHandler { + paramsChan := make(chan TransportParameters, 1) return &extensionHandlerServer{ - params: params, + ourParams: params, paramsChan: paramsChan, - version: version, supportedVersions: supportedVersions, + version: version, } } @@ -43,7 +44,7 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi } transportParams := append( - h.params.getTransportParameters(), + h.ourParams.getTransportParameters(), // TODO(#855): generate a real token transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)}, ) @@ -105,3 +106,7 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte h.paramsChan <- *params return nil } + +func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters { + return h.paramsChan +} diff --git a/internal/handshake/tls_extension_handler_server_test.go b/internal/handshake/tls_extension_handler_server_test.go index 41a65161..ceab29b1 100644 --- a/internal/handshake/tls_extension_handler_server_test.go +++ b/internal/handshake/tls_extension_handler_server_test.go @@ -20,15 +20,12 @@ func parameterMapToList(paramMap map[transportParameterID][]byte) []transportPar var _ = Describe("TLS Extension Handler, for the server", func() { var ( - handler *extensionHandlerServer - el mint.ExtensionList - paramsChan chan TransportParameters + handler *extensionHandlerServer + el mint.ExtensionList ) BeforeEach(func() { - // use a buffered channel here, so that we don't have to receive concurrently when parsing a message - paramsChan = make(chan TransportParameters, 1) - handler = newExtensionHandlerServer(&TransportParameters{}, paramsChan, nil, protocol.VersionWhatever) + handler = NewExtensionHandlerServer(&TransportParameters{}, nil, protocol.VersionWhatever).(*extensionHandlerServer) el = make(mint.ExtensionList, 0) }) @@ -91,7 +88,7 @@ var _ = Describe("TLS Extension Handler, for the server", func() { err := handler.Receive(mint.HandshakeTypeClientHello, &el) Expect(err).ToNot(HaveOccurred()) var params TransportParameters - Expect(paramsChan).To(Receive(¶ms)) + Expect(handler.GetPeerParams()).To(Receive(¶ms)) Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344)) }) diff --git a/internal/mocks/gen.go b/internal/mocks/gen.go index e44c04d0..0f53b7a2 100644 --- a/internal/mocks/gen.go +++ b/internal/mocks/gen.go @@ -1,6 +1,7 @@ package mocks -//go:generate sh -c "mockgen -source=../handshake/mint_utils.go -package mockhandshake -destination handshake/mint_tls.go" +//go:generate sh -c "./mockgen_internal.sh mockhandshake handshake/mint_tls.go github.com/lucas-clemente/quic-go/internal/handshake MintTLS" +//go:generate sh -c "./mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler" //go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" //go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController" //go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD" diff --git a/internal/mocks/handshake/mint_tls.go b/internal/mocks/handshake/mint_tls.go index b03c3d4f..b0da05b6 100644 --- a/internal/mocks/handshake/mint_tls.go +++ b/internal/mocks/handshake/mint_tls.go @@ -1,83 +1,106 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: ../handshake/mint_utils.go +// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: MintTLS) package mockhandshake import ( + io "io" reflect "reflect" mint "github.com/bifurcation/mint" gomock "github.com/golang/mock/gomock" ) -// MockmintTLS is a mock of mintTLS interface -type MockmintTLS struct { +// MockMintTLS is a mock of MintTLS interface +type MockMintTLS struct { ctrl *gomock.Controller - recorder *MockmintTLSMockRecorder + recorder *MockMintTLSMockRecorder } -// MockmintTLSMockRecorder is the mock recorder for MockmintTLS -type MockmintTLSMockRecorder struct { - mock *MockmintTLS +// MockMintTLSMockRecorder is the mock recorder for MockMintTLS +type MockMintTLSMockRecorder struct { + mock *MockMintTLS } -// NewMockmintTLS creates a new mock instance -func NewMockmintTLS(ctrl *gomock.Controller) *MockmintTLS { - mock := &MockmintTLS{ctrl: ctrl} - mock.recorder = &MockmintTLSMockRecorder{mock} +// NewMockMintTLS creates a new mock instance +func NewMockMintTLS(ctrl *gomock.Controller) *MockMintTLS { + mock := &MockMintTLS{ctrl: ctrl} + mock.recorder = &MockMintTLSMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use -func (_m *MockmintTLS) EXPECT() *MockmintTLSMockRecorder { +func (_m *MockMintTLS) EXPECT() *MockMintTLSMockRecorder { return _m.recorder } -// GetCipherSuite mocks base method -func (_m *MockmintTLS) GetCipherSuite() mint.CipherSuiteParams { - ret := _m.ctrl.Call(_m, "GetCipherSuite") - ret0, _ := ret[0].(mint.CipherSuiteParams) - return ret0 -} - -// GetCipherSuite indicates an expected call of GetCipherSuite -func (_mr *MockmintTLSMockRecorder) GetCipherSuite() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockmintTLS)(nil).GetCipherSuite)) -} - // ComputeExporter mocks base method -func (_m *MockmintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { - ret := _m.ctrl.Call(_m, "ComputeExporter", label, context, keyLength) +func (_m *MockMintTLS) ComputeExporter(_param0 string, _param1 []byte, _param2 int) ([]byte, error) { + ret := _m.ctrl.Call(_m, "ComputeExporter", _param0, _param1, _param2) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // ComputeExporter indicates an expected call of ComputeExporter -func (_mr *MockmintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockmintTLS)(nil).ComputeExporter), arg0, arg1, arg2) +func (_mr *MockMintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockMintTLS)(nil).ComputeExporter), arg0, arg1, arg2) +} + +// GetCipherSuite mocks base method +func (_m *MockMintTLS) GetCipherSuite() mint.CipherSuiteParams { + ret := _m.ctrl.Call(_m, "GetCipherSuite") + ret0, _ := ret[0].(mint.CipherSuiteParams) + return ret0 +} + +// GetCipherSuite indicates an expected call of GetCipherSuite +func (_mr *MockMintTLSMockRecorder) GetCipherSuite() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockMintTLS)(nil).GetCipherSuite)) } // Handshake mocks base method -func (_m *MockmintTLS) Handshake() mint.Alert { +func (_m *MockMintTLS) Handshake() mint.Alert { ret := _m.ctrl.Call(_m, "Handshake") ret0, _ := ret[0].(mint.Alert) return ret0 } // Handshake indicates an expected call of Handshake -func (_mr *MockmintTLSMockRecorder) Handshake() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockmintTLS)(nil).Handshake)) +func (_mr *MockMintTLSMockRecorder) Handshake() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockMintTLS)(nil).Handshake)) +} + +// SetCryptoStream mocks base method +func (_m *MockMintTLS) SetCryptoStream(_param0 io.ReadWriter) { + _m.ctrl.Call(_m, "SetCryptoStream", _param0) +} + +// SetCryptoStream indicates an expected call of SetCryptoStream +func (_mr *MockMintTLSMockRecorder) SetCryptoStream(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetCryptoStream", reflect.TypeOf((*MockMintTLS)(nil).SetCryptoStream), arg0) +} + +// SetExtensionHandler mocks base method +func (_m *MockMintTLS) SetExtensionHandler(_param0 mint.AppExtensionHandler) error { + ret := _m.ctrl.Call(_m, "SetExtensionHandler", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetExtensionHandler indicates an expected call of SetExtensionHandler +func (_mr *MockMintTLSMockRecorder) SetExtensionHandler(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetExtensionHandler", reflect.TypeOf((*MockMintTLS)(nil).SetExtensionHandler), arg0) } // State mocks base method -func (_m *MockmintTLS) State() mint.ConnectionState { +func (_m *MockMintTLS) State() mint.State { ret := _m.ctrl.Call(_m, "State") - ret0, _ := ret[0].(mint.ConnectionState) + ret0, _ := ret[0].(mint.State) return ret0 } // State indicates an expected call of State -func (_mr *MockmintTLSMockRecorder) State() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockmintTLS)(nil).State)) +func (_mr *MockMintTLSMockRecorder) State() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockMintTLS)(nil).State)) } diff --git a/internal/mocks/tls_extension_handler.go b/internal/mocks/tls_extension_handler.go new file mode 100644 index 00000000..3941db98 --- /dev/null +++ b/internal/mocks/tls_extension_handler.go @@ -0,0 +1,71 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: TLSExtensionHandler) + +package mocks + +import ( + reflect "reflect" + + mint "github.com/bifurcation/mint" + gomock "github.com/golang/mock/gomock" + handshake "github.com/lucas-clemente/quic-go/internal/handshake" +) + +// MockTLSExtensionHandler is a mock of TLSExtensionHandler interface +type MockTLSExtensionHandler struct { + ctrl *gomock.Controller + recorder *MockTLSExtensionHandlerMockRecorder +} + +// MockTLSExtensionHandlerMockRecorder is the mock recorder for MockTLSExtensionHandler +type MockTLSExtensionHandlerMockRecorder struct { + mock *MockTLSExtensionHandler +} + +// NewMockTLSExtensionHandler creates a new mock instance +func NewMockTLSExtensionHandler(ctrl *gomock.Controller) *MockTLSExtensionHandler { + mock := &MockTLSExtensionHandler{ctrl: ctrl} + mock.recorder = &MockTLSExtensionHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (_m *MockTLSExtensionHandler) EXPECT() *MockTLSExtensionHandlerMockRecorder { + return _m.recorder +} + +// GetPeerParams mocks base method +func (_m *MockTLSExtensionHandler) GetPeerParams() <-chan handshake.TransportParameters { + ret := _m.ctrl.Call(_m, "GetPeerParams") + ret0, _ := ret[0].(<-chan handshake.TransportParameters) + return ret0 +} + +// GetPeerParams indicates an expected call of GetPeerParams +func (_mr *MockTLSExtensionHandlerMockRecorder) GetPeerParams() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetPeerParams", reflect.TypeOf((*MockTLSExtensionHandler)(nil).GetPeerParams)) +} + +// Receive mocks base method +func (_m *MockTLSExtensionHandler) Receive(_param0 mint.HandshakeType, _param1 *mint.ExtensionList) error { + ret := _m.ctrl.Call(_m, "Receive", _param0, _param1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Receive indicates an expected call of Receive +func (_mr *MockTLSExtensionHandlerMockRecorder) Receive(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Receive", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Receive), arg0, arg1) +} + +// Send mocks base method +func (_m *MockTLSExtensionHandler) Send(_param0 mint.HandshakeType, _param1 *mint.ExtensionList) error { + ret := _m.ctrl.Call(_m, "Send", _param0, _param1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Send indicates an expected call of Send +func (_mr *MockTLSExtensionHandlerMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Send", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Send), arg0, arg1) +} diff --git a/mint_utils.go b/mint_utils.go new file mode 100644 index 00000000..02bb3aef --- /dev/null +++ b/mint_utils.go @@ -0,0 +1,150 @@ +package quic + +import ( + "bytes" + gocrypto "crypto" + "crypto/tls" + "crypto/x509" + "errors" + "io" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type mintController struct { + csc *handshake.CryptoStreamConn + conn *mint.Conn +} + +var _ handshake.MintTLS = &mintController{} + +func newMintController( + csc *handshake.CryptoStreamConn, + mconf *mint.Config, + pers protocol.Perspective, +) handshake.MintTLS { + var conn *mint.Conn + if pers == protocol.PerspectiveClient { + conn = mint.Client(csc, mconf) + } else { + conn = mint.Server(csc, mconf) + } + return &mintController{ + csc: csc, + conn: conn, + } +} + +func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams { + return mc.conn.State().CipherSuite +} + +func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { + return mc.conn.ComputeExporter(label, context, keyLength) +} + +func (mc *mintController) Handshake() mint.Alert { + return mc.conn.Handshake() +} + +func (mc *mintController) State() mint.State { + return mc.conn.State().HandshakeState +} + +func (mc *mintController) SetCryptoStream(stream io.ReadWriter) { + mc.csc.SetStream(stream) +} + +func (mc *mintController) SetExtensionHandler(h mint.AppExtensionHandler) error { + return mc.conn.SetExtensionHandler(h) +} + +func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) { + mconf := &mint.Config{ + NonBlocking: true, + CipherSuites: []mint.CipherSuite{ + mint.TLS_AES_128_GCM_SHA256, + mint.TLS_AES_256_GCM_SHA384, + }, + } + if tlsConf != nil { + mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates)) + for i, certChain := range tlsConf.Certificates { + mconf.Certificates[i] = &mint.Certificate{ + Chain: make([]*x509.Certificate, len(certChain.Certificate)), + PrivateKey: certChain.PrivateKey.(gocrypto.Signer), + } + for j, cert := range certChain.Certificate { + c, err := x509.ParseCertificate(cert) + if err != nil { + return nil, err + } + mconf.Certificates[i].Chain[j] = c + } + } + } + if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil { + return nil, err + } + return mconf, nil +} + +// unpackInitialOrRetryPacket unpacks packets Initial and Retry packets +// These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0. +func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.StreamFrame, error) { + unpacker := &packetUnpacker{aead: &nullAEAD{aead}, version: version} + packet, err := unpacker.Unpack(hdr.Raw, hdr, data) + if err != nil { + return nil, err + } + var frame *wire.StreamFrame + for _, f := range packet.frames { + var ok bool + frame, ok = f.(*wire.StreamFrame) + if ok { + break + } + } + if frame == nil { + return nil, errors.New("Packet doesn't contain a STREAM_FRAME") + } + // We don't need a check for the stream ID here. + // The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream. + if frame.Offset != 0 { + return nil, errors.New("received stream data with non-zero offset") + } + if utils.Debug() { + utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) + hdr.Log() + wire.LogFrame(frame, false) + } + return frame, nil +} + +// packUnencryptedPacket provides a low-overhead way to pack a packet. +// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available. +func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFrame, pers protocol.Perspective) ([]byte, error) { + raw := getPacketBuffer() + buffer := bytes.NewBuffer(raw) + if err := hdr.Write(buffer, pers, hdr.Version); err != nil { + return nil, err + } + payloadStartIndex := buffer.Len() + if err := sf.Write(buffer, hdr.Version); err != nil { + return nil, err + } + raw = raw[0:buffer.Len()] + _ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex]) + raw = raw[0 : buffer.Len()+aead.Overhead()] + if utils.Debug() { + utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted) + hdr.Log() + wire.LogFrame(sf, true) + } + return raw, nil +} diff --git a/mint_utils_test.go b/mint_utils_test.go new file mode 100644 index 00000000..e538cad2 --- /dev/null +++ b/mint_utils_test.go @@ -0,0 +1,111 @@ +package quic + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Packing and unpacking Initial packets", func() { + var aead crypto.AEAD + connID := protocol.ConnectionID(0x1337) + ver := protocol.VersionTLS + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + PacketNumber: 0x42, + ConnectionID: connID, + Version: ver, + } + + BeforeEach(func() { + var err error + aead, err = crypto.NewNullAEAD(protocol.PerspectiveServer, connID, ver) + Expect(err).ToNot(HaveOccurred()) + // set hdr.Raw + buf := &bytes.Buffer{} + err = hdr.Write(buf, protocol.PerspectiveServer, ver) + Expect(err).ToNot(HaveOccurred()) + hdr.Raw = buf.Bytes() + }) + + Context("unpacking", func() { + packPacket := func(frames []wire.Frame) []byte { + buf := &bytes.Buffer{} + err := hdr.Write(buf, protocol.PerspectiveClient, ver) + Expect(err).ToNot(HaveOccurred()) + payloadStartIndex := buf.Len() + aeadCl, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, ver) + for _, f := range frames { + err := f.Write(buf, ver) + Expect(err).ToNot(HaveOccurred()) + } + raw := buf.Bytes() + return aeadCl.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex]) + } + + It("unpacks a packet", func() { + f := &wire.StreamFrame{ + StreamID: 0, + Data: []byte("foobar"), + } + p := packPacket([]wire.Frame{f}) + frame, err := unpackInitialPacket(aead, hdr, p, ver) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("rejects a packet that doesn't contain a STREAM_FRAME", func() { + p := packPacket([]wire.Frame{&wire.PingFrame{}}) + _, err := unpackInitialPacket(aead, hdr, p, ver) + Expect(err).To(MatchError("Packet doesn't contain a STREAM_FRAME")) + }) + + It("rejects a packet that has a STREAM_FRAME for the wrong stream", func() { + f := &wire.StreamFrame{ + StreamID: 42, + Data: []byte("foobar"), + } + p := packPacket([]wire.Frame{f}) + _, err := unpackInitialPacket(aead, hdr, p, ver) + Expect(err).To(MatchError("UnencryptedStreamData: received unencrypted stream data on stream 42")) + }) + + It("rejects a packet that has a STREAM_FRAME with a non-zero offset", func() { + f := &wire.StreamFrame{ + StreamID: 0, + Offset: 10, + Data: []byte("foobar"), + } + p := packPacket([]wire.Frame{f}) + _, err := unpackInitialPacket(aead, hdr, p, ver) + Expect(err).To(MatchError("received stream data with non-zero offset")) + }) + }) + + Context("packing", func() { + var unpacker *packetUnpacker + + BeforeEach(func() { + aeadCl, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, ver) + Expect(err).ToNot(HaveOccurred()) + unpacker = &packetUnpacker{aead: &nullAEAD{aeadCl}, version: ver} + }) + + It("packs a packet", func() { + f := &wire.StreamFrame{ + Data: []byte("foobar"), + FinBit: true, + } + data, err := packUnencryptedPacket(aead, hdr, f, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + packet, err := unpacker.Unpack(hdr.Raw, hdr, data[len(hdr.Raw):]) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.frames).To(Equal([]wire.Frame{f})) + }) + }) +}) diff --git a/packet_number_generator.go b/packet_number_generator.go index 8ece95ac..ac635776 100644 --- a/packet_number_generator.go +++ b/packet_number_generator.go @@ -17,9 +17,9 @@ type packetNumberGenerator struct { nextToSkip protocol.PacketNumber } -func newPacketNumberGenerator(averagePeriod protocol.PacketNumber) *packetNumberGenerator { +func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator { return &packetNumberGenerator{ - next: 1, + next: initial, averagePeriod: averagePeriod, } } diff --git a/packet_number_generator_test.go b/packet_number_generator_test.go index 1f00cb75..ece74a96 100644 --- a/packet_number_generator_test.go +++ b/packet_number_generator_test.go @@ -12,7 +12,12 @@ var _ = Describe("Packet Number Generator", func() { var png packetNumberGenerator BeforeEach(func() { - png = *newPacketNumberGenerator(100) + png = *newPacketNumberGenerator(1, 100) + }) + + It("can be initialized to return any first packet number", func() { + png = *newPacketNumberGenerator(12345, 100) + Expect(png.Pop()).To(Equal(protocol.PacketNumber(12345))) }) It("gets 1 as the first packet number", func() { diff --git a/packet_packer.go b/packet_packer.go index 1a637158..603f692f 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -32,9 +32,11 @@ type packetPacker struct { ackFrame *wire.AckFrame leastUnacked protocol.PacketNumber omitConnectionID bool + hasSentPacket bool // has the packetPacker already sent a packet } func newPacketPacker(connectionID protocol.ConnectionID, + initialPacketNumber protocol.PacketNumber, cryptoSetup handshake.CryptoSetup, streamFramer *streamFramer, perspective protocol.Perspective, @@ -46,7 +48,7 @@ func newPacketPacker(connectionID protocol.ConnectionID, perspective: perspective, version: version, streamFramer: streamFramer, - packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength), + packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), } } @@ -116,7 +118,12 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (* // PackPacket packs a new packet // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise func (p *packetPacker) PackPacket() (*packedPacket, error) { - if p.streamFramer.HasCryptoStreamFrame() { + hasCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame() + // if this is the first packet to be send, make sure it contains stream data + if !p.hasSentPacket && !hasCryptoStreamFrame { + return nil, nil + } + if hasCryptoStreamFrame { return p.packCryptoPacket() } @@ -266,18 +273,21 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header pnum := p.packetNumberGenerator.Peek() packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked) - var isLongHeader bool - if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure { - // TODO: set the Long Header type - packetNumberLen = protocol.PacketNumberLen4 - isLongHeader = true - } - header := &wire.Header{ ConnectionID: p.connectionID, PacketNumber: pnum, PacketNumberLen: packetNumberLen, - IsLongHeader: isLongHeader, + } + + if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure { + header.PacketNumberLen = protocol.PacketNumberLen4 + header.IsLongHeader = true + if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient { + header.Type = protocol.PacketTypeInitial + // TODO(#886): add padding + } else { + header.Type = protocol.PacketTypeHandshake + } } if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure { @@ -292,7 +302,6 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header header.Version = p.version } } else { - header.Type = p.cryptoSetup.GetNextPacketType() if encLevel != protocol.EncryptionForwardSecure { header.Version = p.version } @@ -330,7 +339,7 @@ func (p *packetPacker) writeAndSealPacket( if num != header.PacketNumber { return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } - + p.hasSentPacket = true return raw, nil } diff --git a/packet_packer_test.go b/packet_packer_test.go index 76e1e6ed..d2f8ab17 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -28,7 +28,6 @@ type mockCryptoSetup struct { divNonce []byte encLevelSeal protocol.EncryptionLevel encLevelSealCrypto protocol.EncryptionLevel - nextPacketType protocol.PacketType } var _ handshake.CryptoSetup = &mockCryptoSetup{} @@ -50,7 +49,6 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) } func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce } func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce } -func (m *mockCryptoSetup) GetNextPacketType() protocol.PacketType { return m.nextPacketType } var _ = Describe("Packet packer", func() { var ( @@ -69,13 +67,14 @@ var _ = Describe("Packet packer", func() { packer = &packetPacker{ cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, connectionID: 0x1337, - packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength), + packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength), streamFramer: streamFramer, perspective: protocol.PerspectiveServer, } publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen packer.version = protocol.VersionWhatever + packer.hasSentPacket = true }) It("returns nil when no packet is queued", func() { @@ -191,13 +190,6 @@ var _ = Describe("Packet packer", func() { Expect(h.Version).To(Equal(versionIETFHeader)) }) - It("sets the packet type based on the state of the handshake", func() { - packer.cryptoSetup.(*mockCryptoSetup).nextPacketType = 5 - h := packer.getHeader(protocol.EncryptionSecure) - Expect(h.IsLongHeader).To(BeTrue()) - Expect(h.Type).To(Equal(protocol.PacketType(5))) - }) - It("uses the Short Header format for forward-secure packets", func() { h := packer.getHeader(protocol.EncryptionForwardSecure) Expect(h.IsLongHeader).To(BeFalse()) @@ -269,7 +261,7 @@ var _ = Describe("Packet packer", func() { Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber)) }) - It("packs a StopWaitingFrame first", func() { + It("packs a STOP_WAITING frame first", func() { packer.packetNumberGenerator.next = 15 swf := &wire.StopWaitingFrame{LeastUnacked: 10} packer.QueueControlFrame(&wire.RstStreamFrame{}) @@ -281,7 +273,7 @@ var _ = Describe("Packet packer", func() { Expect(p.frames[0]).To(Equal(swf)) }) - It("sets the LeastUnackedDelta length of a StopWaitingFrame", func() { + It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() { packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number packer.packetNumberGenerator.next = packetNumber swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100} @@ -292,7 +284,7 @@ var _ = Describe("Packet packer", func() { Expect(p.frames[0].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) }) - It("does not pack a packet containing only a StopWaitingFrame", func() { + It("does not pack a packet containing only a STOP_WAITING frame", func() { swf := &wire.StopWaitingFrame{LeastUnacked: 10} packer.QueueControlFrame(swf) p, err := packer.PackPacket() @@ -307,6 +299,14 @@ var _ = Describe("Packet packer", func() { Expect(p).ToNot(BeNil()) }) + It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() { + packer.hasSentPacket = false + packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + }) + It("packs many control frames into 1 packets", func() { f := &wire.AckFrame{LargestAcked: 1} b := &bytes.Buffer{} @@ -602,7 +602,7 @@ var _ = Describe("Packet packer", func() { }) }) - Context("Blocked frames", func() { + Context("BLOCKED frames", func() { It("queues a BLOCKED frame", func() { length := 100 streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}} @@ -750,7 +750,7 @@ var _ = Describe("Packet packer", func() { Expect(err).To(MatchError("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment")) }) - It("refuses to retransmit packets without a StopWaitingFrame", func() { + It("refuses to retransmit packets without a STOP_WAITING Frame", func() { packer.stopWaiting = nil _, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{ EncryptionLevel: protocol.EncryptionSecure, diff --git a/server.go b/server.go index 85303f3c..f2b82a3d 100644 --- a/server.go +++ b/server.go @@ -19,6 +19,7 @@ import ( // packetHandler handles packets type packetHandler interface { Session + getCryptoStream() cryptoStream handshakeStatus() <-chan handshakeEvent handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber @@ -33,6 +34,9 @@ type server struct { conn net.PacketConn + supportsTLS bool + serverTLS *serverTLS + certChain crypto.CertChain scfg *handshake.ServerConfig @@ -77,11 +81,21 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, if err != nil { return nil, err } + config = populateServerConfig(config) + + // check if any of the supported versions supports TLS + var supportsTLS bool + for _, v := range config.Versions { + if v.UsesTLS() { + supportsTLS = true + break + } + } s := &server{ conn: conn, tlsConf: tlsConf, - config: populateServerConfig(config), + config: config, certChain: certChain, scfg: scfg, sessions: map[protocol.ConnectionID]packetHandler{}, @@ -89,12 +103,47 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, sessionQueue: make(chan Session, 5), errorChan: make(chan struct{}), + supportsTLS: supportsTLS, + } + if supportsTLS { + if err := s.setupTLS(); err != nil { + return nil, err + } } go s.serve() utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } +func (s *server) setupTLS() error { + cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie) + if err != nil { + return err + } + serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf) + if err != nil { + return err + } + s.serverTLS = serverTLS + // handle TLS connection establishment statelessly + go func() { + for { + select { + case <-s.errorChan: + return + case sess := <-sessionChan: + // TODO: think about what to do with connection ID collisions + connID := sess.(*session).connectionID + s.sessionsMutex.Lock() + s.sessions[connID] = sess + s.sessionsMutex.Unlock() + s.runHandshakeAndSession(sess, connID) + } + } + }() + return nil +} + var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool { if cookie == nil { return false @@ -225,8 +274,16 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return qerr.Error(qerr.InvalidPacketHeader, err.Error()) } hdr.Raw = packet[:len(packet)-r.Len()] + packetData := packet[len(packet)-r.Len():] connID := hdr.ConnectionID + if hdr.Type == protocol.PacketTypeInitial { + if s.supportsTLS { + go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData) + } + return nil + } + s.sessionsMutex.RLock() session, sessionKnown := s.sessions[connID] s.sessionsMutex.RUnlock() @@ -279,11 +336,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return err } } - // send an IETF draft style Version Negotiation Packet, if the client sent an unsupported version with an IETF draft style header - if hdr.Type == protocol.PacketTypeInitial && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - _, err := pconn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.config.Versions), remoteAddr) - return err - } if !sessionKnown { version := hdr.Version @@ -307,34 +359,38 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet s.sessions[connID] = session s.sessionsMutex.Unlock() - go func() { - // session.run() returns as soon as the session is closed - _ = session.run() - s.removeConnection(connID) - }() - - go func() { - for { - ev := <-session.handshakeStatus() - if ev.err != nil { - return - } - if ev.encLevel == protocol.EncryptionForwardSecure { - break - } - } - s.sessionQueue <- session - }() + s.runHandshakeAndSession(session, connID) } session.handlePacket(&receivedPacket{ remoteAddr: remoteAddr, header: hdr, - data: packet[len(packet)-r.Len():], + data: packetData, rcvTime: rcvTime, }) return nil } +func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) { + go func() { + _ = session.run() + // session.run() returns as soon as the session is closed + s.removeConnection(connID) + }() + + go func() { + for { + ev := <-session.handshakeStatus() + if ev.err != nil { + return + } + if ev.encLevel == protocol.EncryptionForwardSecure { + break + } + } + s.sessionQueue <- session + }() +} + func (s *server) removeConnection(id protocol.ConnectionID) { s.sessionsMutex.Lock() s.sessions[id] = nil diff --git a/server_test.go b/server_test.go index 189d3653..77fc88de 100644 --- a/server_test.go +++ b/server_test.go @@ -12,6 +12,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" @@ -67,6 +68,7 @@ func (s *mockSession) RemoteAddr() net.Addr { panic("not imple func (*mockSession) Context() context.Context { panic("not implemented") } func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan } +func (*mockSession) getCryptoStream() cryptoStream { panic("not implemented") } var _ Session = &mockSession{} var _ NonFWSession = &mockSession{} @@ -96,7 +98,8 @@ var _ = Describe("Server", func() { ) BeforeEach(func() { - conn = &mockPacketConn{addr: &net.UDPAddr{}} + conn = newMockPacketConn() + conn.addr = &net.UDPAddr{} config = &Config{Versions: protocol.SupportedVersions} }) @@ -235,14 +238,14 @@ var _ = Describe("Server", func() { }) It("works if no quic.Config is given", func(done Done) { - ln, err := ListenAddr("127.0.0.1:0", nil, config) + ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), nil) Expect(err).ToNot(HaveOccurred()) Expect(ln.Close()).To(Succeed()) close(done) }, 1) It("closes properly", func() { - ln, err := ListenAddr("127.0.0.1:0", nil, config) + ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), config) Expect(err).ToNot(HaveOccurred()) var returned bool @@ -409,7 +412,7 @@ var _ = Describe("Server", func() { } hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */) b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO - conn.dataToRead = b.Bytes() + conn.dataToRead <- b.Bytes() conn.dataReadFrom = udpAddr ln, err := Listen(conn, nil, config) Expect(err).ToNot(HaveOccurred()) @@ -432,7 +435,7 @@ var _ = Describe("Server", func() { }) It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() { - config.Versions = []protocol.VersionNumber{99} + config.Versions = []protocol.VersionNumber{99, protocol.VersionTLS} b := &bytes.Buffer{} hdr := wire.Header{ Type: protocol.PacketTypeInitial, @@ -441,11 +444,12 @@ var _ = Describe("Server", func() { PacketNumber: 0x55, Version: 0x1234, } - hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS) + err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO - conn.dataToRead = b.Bytes() + conn.dataToRead <- b.Bytes() conn.dataReadFrom = udpAddr - ln, err := Listen(conn, nil, config) + ln, err := Listen(conn, testdata.GetTLSConfig(), config) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) @@ -466,9 +470,32 @@ var _ = Describe("Server", func() { Consistently(done).ShouldNot(BeClosed()) }) + It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() { + version := protocol.VersionNumber(99) + Expect(version.UsesTLS()).To(BeFalse()) + config.Versions = []protocol.VersionNumber{version} + b := &bytes.Buffer{} + hdr := wire.Header{ + Type: protocol.PacketTypeInitial, + IsLongHeader: true, + ConnectionID: 0x1337, + PacketNumber: 0x55, + Version: protocol.VersionTLS, + } + err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO + conn.dataToRead <- b.Bytes() + conn.dataReadFrom = udpAddr + ln, err := Listen(conn, testdata.GetTLSConfig(), config) + defer ln.Close() + Expect(err).ToNot(HaveOccurred()) + Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero()) + }) + It("sends a PublicReset for new connections that don't have the VersionFlag set", func() { conn.dataReadFrom = udpAddr - conn.dataToRead = []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01} + conn.dataToRead <- []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01} ln, err := Listen(conn, nil, config) Expect(err).ToNot(HaveOccurred()) go func() { diff --git a/server_tls.go b/server_tls.go new file mode 100644 index 00000000..860f3e42 --- /dev/null +++ b/server_tls.go @@ -0,0 +1,179 @@ +package quic + +import ( + "crypto/tls" + "fmt" + "net" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type nullAEAD struct { + aead crypto.AEAD +} + +var _ quicAEAD = &nullAEAD{} + +func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { + data, err := n.aead.Open(dst, src, packetNumber, associatedData) + return data, protocol.EncryptionUnencrypted, err +} + +type serverTLS struct { + conn net.PacketConn + config *Config + supportedVersions []protocol.VersionNumber + mintConf *mint.Config + cookieProtector mint.CookieProtector + params *handshake.TransportParameters + newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) + + sessionChan chan<- packetHandler +} + +func newServerTLS( + conn net.PacketConn, + config *Config, + cookieHandler *handshake.CookieHandler, + tlsConf *tls.Config, +) (*serverTLS, <-chan packetHandler, error) { + mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer) + if err != nil { + return nil, nil, err + } + mconf.RequireCookie = true + cs, err := mint.NewDefaultCookieProtector() + if err != nil { + return nil, nil, err + } + mconf.CookieProtector = cs + mconf.CookieHandler = cookieHandler + + sessionChan := make(chan packetHandler) + s := &serverTLS{ + conn: conn, + config: config, + supportedVersions: config.Versions, + mintConf: mconf, + sessionChan: sessionChan, + params: &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + MaxStreams: protocol.MaxIncomingStreams, + IdleTimeout: config.IdleTimeout, + }, + } + s.newMintConn = s.newMintConnImpl + return s, sessionChan, nil +} + +func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) { + utils.Debugf("Received a Packet. Handling it statelessly.") + sess, err := s.handleInitialImpl(remoteAddr, hdr, data) + if err != nil { + utils.Errorf("Error occured handling initial packet: %s", err) + return + } + if sess == nil { // a stateless reset was done + return + } + s.sessionChan <- sess +} + +// will be set to s.newMintConn by the constructor +func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) { + conn := mint.Server(bc, s.mintConf) + extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v) + if err := conn.SetExtensionHandler(extHandler); err != nil { + return nil, nil, err + } + tls := newMintController(bc, s.mintConf, protocol.PerspectiveServer) + tls.SetExtensionHandler(extHandler) + return tls, extHandler.GetPeerParams(), nil +} + +func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) { + // TODO: check length requirement + // check version, if not matching send VNP + if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) { + utils.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) + _, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.supportedVersions), remoteAddr) + return nil, err + } + + // unpack packet and check stream frame contents + version := hdr.Version + aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, version) + if err != nil { + return nil, err + } + frame, err := unpackInitialPacket(aead, hdr, data, version) + if err != nil { + utils.Debugf("Error unpacking initial packet: %s", err) + return nil, nil + } + bc := handshake.NewCryptoStreamConn(remoteAddr) + bc.AddDataForReading(frame.Data) + tls, paramsChan, err := s.newMintConn(bc, hdr.Version) + if err != nil { + return nil, err + } + alert := tls.Handshake() + if alert == mint.AlertStatelessRetry { + // the HelloRetryRequest was written to the bufferConn + // Take that data and write send a Retry packet + replyHdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + ConnectionID: hdr.ConnectionID, // echo the client's connection ID + PacketNumber: hdr.PacketNumber, // echo the client's packet number + Version: version, + } + f := &wire.StreamFrame{ + StreamID: version.CryptoStreamID(), + Data: bc.GetDataForWriting(), + } + data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer) + if err != nil { + return nil, err + } + _, err = s.conn.WriteTo(data, remoteAddr) + return nil, err + } + if alert != mint.AlertNoAlert { + return nil, alert + } + if tls.State() != mint.StateServerNegotiated { + return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State()) + } + if alert := tls.Handshake(); alert != mint.AlertNoAlert { + return nil, alert + } + if tls.State() != mint.StateServerWaitFlight2 { + return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State()) + } + params := <-paramsChan + sess, err := newTLSServerSession( + &conn{pconn: s.conn, currentAddr: remoteAddr}, + hdr.ConnectionID, // TODO: we can use a server-chosen connection ID here + protocol.PacketNumber(1), // TODO: use a random packet number here + s.config, + tls, + bc, + aead, + ¶ms, + version, + ) + if err != nil { + return nil, err + } + cs := sess.getCryptoStream() + cs.SetReadOffset(frame.DataLen()) + bc.SetStream(cs) + return sess, nil +} diff --git a/server_tls_test.go b/server_tls_test.go new file mode 100644 index 00000000..5fa338e8 --- /dev/null +++ b/server_tls_test.go @@ -0,0 +1,116 @@ +package quic + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/mocks" + "github.com/lucas-clemente/quic-go/internal/mocks/handshake" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/testdata" + "github.com/lucas-clemente/quic-go/internal/wire" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Stateless TLS handling", func() { + var ( + conn *mockPacketConn + server *serverTLS + sessionChan <-chan packetHandler + mintTLS *mockhandshake.MockMintTLS + extHandler *mocks.MockTLSExtensionHandler + mintReply io.Writer + ) + + BeforeEach(func() { + mintTLS = mockhandshake.NewMockMintTLS(mockCtrl) + extHandler = mocks.NewMockTLSExtensionHandler(mockCtrl) + conn = newMockPacketConn() + config := &Config{ + Versions: []protocol.VersionNumber{protocol.VersionTLS}, + } + var err error + server, sessionChan, err = newServerTLS(conn, config, nil, testdata.GetTLSConfig()) + Expect(err).ToNot(HaveOccurred()) + server.newMintConn = func(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) { + mintReply = bc + return mintTLS, extHandler.GetPeerParams(), nil + } + }) + + getPacket := func(f wire.Frame) (*wire.Header, []byte) { + hdrBuf := &bytes.Buffer{} + hdr := &wire.Header{ + IsLongHeader: true, + PacketNumber: 1, + Version: protocol.VersionTLS, + } + err := hdr.Write(hdrBuf, protocol.PerspectiveClient, protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + hdr.Raw = hdrBuf.Bytes() + aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, 0, protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + buf := &bytes.Buffer{} + err = f.Write(buf, protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + return hdr, aead.Seal(nil, buf.Bytes(), 1, hdr.Raw) + } + + It("sends a version negotiation packet if it doesn't support the version", func() { + server.HandleInitial(nil, &wire.Header{Version: 0x1337}, nil) + Expect(conn.dataWritten.Len()).ToNot(BeZero()) + hdr, err := wire.ParseHeaderSentByServer(bytes.NewReader(conn.dataWritten.Bytes()), protocol.VersionUnknown) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsVersionNegotiation).To(BeTrue()) + Expect(sessionChan).ToNot(Receive()) + }) + + It("ignores packets with invalid contents", func() { + hdr, data := getPacket(&wire.StreamFrame{StreamID: 10, Offset: 11, Data: []byte("foobar")}) + server.HandleInitial(nil, hdr, data) + Expect(conn.dataWritten.Len()).To(BeZero()) + Expect(sessionChan).ToNot(Receive()) + }) + + It("replies with a Retry packet, if a Cookie is required", func() { + extHandler.EXPECT().GetPeerParams() + mintTLS.EXPECT().Handshake().Return(mint.AlertStatelessRetry).Do(func() { + mintReply.Write([]byte("Retry with this Cookie")) + }) + hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")}) + server.HandleInitial(nil, hdr, data) + Expect(conn.dataWritten.Len()).ToNot(BeZero()) + hdr, err := wire.ParseHeaderSentByServer(bytes.NewReader(conn.dataWritten.Bytes()), protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) + Expect(sessionChan).ToNot(Receive()) + }) + + It("replies with a Handshake packet and creates a session, if no Cookie is required", func() { + mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert).Do(func() { + mintReply.Write([]byte("Server Hello")) + }) + mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert) + mintTLS.EXPECT().State().Return(mint.StateServerNegotiated) + mintTLS.EXPECT().State().Return(mint.StateServerWaitFlight2) + paramsChan := make(chan handshake.TransportParameters, 1) + paramsChan <- handshake.TransportParameters{} + extHandler.EXPECT().GetPeerParams().Return(paramsChan) + hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")}) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + server.HandleInitial(nil, hdr, data) + // the Handshake packet is written by the session + Expect(conn.dataWritten.Len()).To(BeZero()) + close(done) + }() + Eventually(sessionChan).Should(Receive()) + Eventually(done).Should(BeClosed()) + }) +}) diff --git a/session.go b/session.go index 418d38ad..439edaa9 100644 --- a/session.go +++ b/session.go @@ -11,6 +11,7 @@ import ( "github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/congestion" + "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -60,7 +61,7 @@ type session struct { conn connection streamsMap *streamsMap - cryptoStream streamI + cryptoStream cryptoStream rttStats *congestion.RTTStats @@ -126,21 +127,48 @@ func newSession( conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, - sCfg *handshake.ServerConfig, + scfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, ) (packetHandler, error) { + paramsChan := make(chan handshake.TransportParameters) + aeadChanged := make(chan protocol.EncryptionLevel, 2) s := &session{ conn: conn, connectionID: connectionID, perspective: protocol.PerspectiveServer, version: v, config: config, + aeadChanged: aeadChanged, + paramsChan: paramsChan, } - return s, s.setup(sCfg, "", tlsConf, v, nil) + s.preSetup() + transportParams := &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + MaxStreams: protocol.MaxIncomingStreams, + IdleTimeout: s.config.IdleTimeout, + } + cs, err := newCryptoSetup( + s.cryptoStream, + s.connectionID, + s.conn.RemoteAddr(), + s.version, + scfg, + transportParams, + s.config.Versions, + s.config.AcceptCookie, + paramsChan, + aeadChanged, + ) + if err != nil { + return nil, err + } + s.cryptoSetup = cs + return s, s.postSetup(1) } -// declare this as a variable, such that we can it mock it in the tests +// declare this as a variable, so that we can it mock it in the tests var newClientSession = func( conn connection, hostname string, @@ -151,27 +179,130 @@ var newClientSession = func( initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton ) (packetHandler, error) { + paramsChan := make(chan handshake.TransportParameters) + aeadChanged := make(chan protocol.EncryptionLevel, 2) s := &session{ conn: conn, connectionID: connectionID, perspective: protocol.PerspectiveClient, version: v, config: config, + aeadChanged: aeadChanged, + paramsChan: paramsChan, } - return s, s.setup(nil, hostname, tlsConf, initialVersion, negotiatedVersions) + s.preSetup() + transportParams := &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + MaxStreams: protocol.MaxIncomingStreams, + IdleTimeout: s.config.IdleTimeout, + OmitConnectionID: s.config.RequestConnectionIDOmission, + } + cs, err := newCryptoSetupClient( + s.cryptoStream, + hostname, + s.connectionID, + s.version, + tlsConf, + transportParams, + paramsChan, + aeadChanged, + initialVersion, + negotiatedVersions, + ) + if err != nil { + return nil, err + } + s.cryptoSetup = cs + return s, s.postSetup(1) } -func (s *session) setup( - scfg *handshake.ServerConfig, - hostname string, - tlsConf *tls.Config, - initialVersion protocol.VersionNumber, - negotiatedVersions []protocol.VersionNumber, -) error { +func newTLSServerSession( + conn connection, + connectionID protocol.ConnectionID, + initialPacketNumber protocol.PacketNumber, + config *Config, + tls handshake.MintTLS, + cryptoStreamConn *handshake.CryptoStreamConn, + nullAEAD crypto.AEAD, + peerParams *handshake.TransportParameters, + v protocol.VersionNumber, +) (packetHandler, error) { aeadChanged := make(chan protocol.EncryptionLevel, 2) - paramsChan := make(chan handshake.TransportParameters) - s.aeadChanged = aeadChanged - s.paramsChan = paramsChan + s := &session{ + conn: conn, + config: config, + connectionID: connectionID, + perspective: protocol.PerspectiveServer, + version: v, + aeadChanged: aeadChanged, + } + s.preSetup() + s.cryptoSetup = handshake.NewCryptoSetupTLSServer( + tls, + cryptoStreamConn, + nullAEAD, + aeadChanged, + v, + ) + if err := s.postSetup(initialPacketNumber); err != nil { + return nil, err + } + s.peerParams = peerParams + s.processTransportParameters(peerParams) + s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} + return s, nil +} + +// declare this as a variable, such that we can it mock it in the tests +var newTLSClientSession = func( + conn connection, + hostname string, + v protocol.VersionNumber, + connectionID protocol.ConnectionID, + config *Config, + tls handshake.MintTLS, + paramsChan <-chan handshake.TransportParameters, + initialPacketNumber protocol.PacketNumber, +) (packetHandler, error) { + aeadChanged := make(chan protocol.EncryptionLevel, 2) + s := &session{ + conn: conn, + config: config, + connectionID: connectionID, + perspective: protocol.PerspectiveClient, + version: v, + aeadChanged: aeadChanged, + paramsChan: paramsChan, + } + s.preSetup() + tls.SetCryptoStream(s.cryptoStream) + cs, err := handshake.NewCryptoSetupTLSClient( + s.cryptoStream, + s.connectionID, + hostname, + aeadChanged, + tls, + v, + ) + if err != nil { + return nil, err + } + s.cryptoSetup = cs + return s, s.postSetup(initialPacketNumber) +} + +func (s *session) preSetup() { + s.rttStats = &congestion.RTTStats{} + s.connFlowController = flowcontrol.NewConnectionFlowController( + protocol.ReceiveConnectionFlowControlWindow, + protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), + s.rttStats, + ) + s.cryptoStream = s.newStream(s.version.CryptoStreamID()).(cryptoStream) +} + +func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { s.handshakeChan = make(chan handshakeEvent, 3) s.handshakeCompleteChan = make(chan error, 1) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) @@ -185,91 +316,14 @@ func (s *session) setup( s.lastNetworkActivityTime = now s.sessionCreationTime = now - s.rttStats = &congestion.RTTStats{} - transportParams := &handshake.TransportParameters{ - StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, - ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, - MaxStreams: protocol.MaxIncomingStreams, - IdleTimeout: s.config.IdleTimeout, - } s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) - s.connFlowController = flowcontrol.NewConnectionFlowController( - protocol.ReceiveConnectionFlowControlWindow, - protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), - s.rttStats, - ) + s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version) - s.cryptoStream = s.newStream(s.version.CryptoStreamID()) s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController) - var err error - if s.perspective == protocol.PerspectiveServer { - verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool { - return s.config.AcceptCookie(clientAddr, cookie) - } - if s.version.UsesTLS() { - s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer( - s.cryptoStream, - s.connectionID, - tlsConf, - s.conn.RemoteAddr(), - transportParams, - paramsChan, - aeadChanged, - verifySourceAddr, - s.config.Versions, - s.version, - ) - } else { - s.cryptoSetup, err = newCryptoSetup( - s.cryptoStream, - s.connectionID, - s.conn.RemoteAddr(), - s.version, - scfg, - transportParams, - s.config.Versions, - verifySourceAddr, - paramsChan, - aeadChanged, - ) - } - } else { - transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission - if s.version.UsesTLS() { - s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient( - s.cryptoStream, - s.connectionID, - hostname, - tlsConf, - transportParams, - paramsChan, - aeadChanged, - initialVersion, - s.config.Versions, - s.version, - ) - } else { - s.cryptoSetup, err = newCryptoSetupClient( - s.cryptoStream, - hostname, - s.connectionID, - s.version, - tlsConf, - transportParams, - paramsChan, - aeadChanged, - initialVersion, - negotiatedVersions, - ) - } - } - if err != nil { - return err - } - s.packer = newPacketPacker(s.connectionID, + initialPacketNumber, s.cryptoSetup, s.streamFramer, s.perspective, @@ -604,7 +658,7 @@ func (s *session) handleCloseError(closeErr closeError) error { s.cryptoStream.Cancel(quicErr) s.streamsMap.CloseWithError(quicErr) - if closeErr.err == errCloseSessionForNewVersion { + if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry { return nil } @@ -893,6 +947,10 @@ func (s *session) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan } +func (s *session) getCryptoStream() cryptoStream { + return s.cryptoStream +} + func (s *session) GetVersion() protocol.VersionNumber { return s.version } diff --git a/session_test.go b/session_test.go index 3fff01fe..ba731ac6 100644 --- a/session_test.go +++ b/session_test.go @@ -468,9 +468,6 @@ var _ = Describe("Session", func() { }) It("handles CONNECTION_CLOSE frames", func() { - cryptoStream := mocks.NewMockStreamI(mockCtrl) - cryptoStream.EXPECT().Cancel(gomock.Any()) - sess.cryptoStream = cryptoStream done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -771,10 +768,15 @@ var _ = Describe("Session", func() { }) Context("sending packets", func() { + BeforeEach(func() { + sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends + }) + It("sends ACK frames", func() { packetNumber := protocol.PacketNumber(0x035e) - sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) - err := sess.sendPacket() + err := sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) + Expect(err).ToNot(HaveOccurred()) + err = sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e})))) @@ -858,6 +860,7 @@ var _ = Describe("Session", func() { BeforeEach(func() { // a StopWaitingFrame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet sess.packer.packetNumberGenerator.next = 0x1337 + 10 + sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends sph = newMockSentPacketHandler().(*mockSentPacketHandler) sess.sentPacketHandler = sph sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} @@ -981,6 +984,7 @@ var _ = Describe("Session", func() { }) It("retransmits RTO packets", func() { + sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends sess.sentPacketHandler.SetHandshakeComplete() n := protocol.PacketNumber(10) sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} @@ -1008,6 +1012,7 @@ var _ = Describe("Session", func() { Context("scheduling sending", func() { BeforeEach(func() { + sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends sess.processTransportParameters(&handshake.TransportParameters{ StreamFlowControlWindow: protocol.MaxByteCount, ConnectionFlowControlWindow: protocol.MaxByteCount, @@ -1291,6 +1296,7 @@ var _ = Describe("Session", func() { sess.handshakeComplete = true sess.config.KeepAlive = true sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2) + sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends go sess.run() defer sess.Close(nil) var data []byte @@ -1551,7 +1557,10 @@ var _ = Describe("Client Session", func() { }) It("passes the diversification nonce to the cryptoSetup", func() { - go sess.run() + go func() { + defer GinkgoRecover() + sess.run() + }() hdr.PacketNumber = 5 hdr.DiversificationNonce = []byte("foobar") err := sess.handlePacketImpl(&receivedPacket{header: hdr}) diff --git a/stream.go b/stream.go index 806e7fc9..3fe9a8d1 100644 --- a/stream.go +++ b/stream.go @@ -32,6 +32,11 @@ type streamI interface { IsFlowControlBlocked() bool } +type cryptoStream interface { + streamI + SetReadOffset(protocol.ByteCount) +} + // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // // Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. @@ -481,3 +486,11 @@ func (s *stream) IsFlowControlBlocked() bool { func (s *stream) GetWindowUpdate() protocol.ByteCount { return s.flowController.GetWindowUpdate() } + +// SetReadOffset sets the read offset. +// It is only needed for the crypto stream. +// It must not be called concurrently with any other stream methods, especially Read and Write. +func (s *stream) SetReadOffset(offset protocol.ByteCount) { + s.readOffset = offset + s.frameQueue.readPosition = offset +} diff --git a/stream_test.go b/stream_test.go index 691f910a..0fec4f02 100644 --- a/stream_test.go +++ b/stream_test.go @@ -266,6 +266,12 @@ var _ = Describe("Stream", func() { Expect(onDataCalled).To(BeTrue()) }) + It("sets the read offset", func() { + str.SetReadOffset(0x42) + Expect(str.readOffset).To(Equal(protocol.ByteCount(0x42))) + Expect(str.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42))) + }) + Context("deadlines", func() { It("the deadline error has the right net.Error properties", func() { Expect(errDeadline.Temporary()).To(BeTrue()) diff --git a/streams_map.go b/streams_map.go index 8b71903a..e162205c 100644 --- a/streams_map.go +++ b/streams_map.go @@ -325,4 +325,5 @@ func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) { m.mutex.Lock() defer m.mutex.Unlock() m.maxOutgoingStreams = limit + m.openStreamOrErrCond.Broadcast() }