From 09574a6653da4d868e22e2d6eabdbcfa9ee6284d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 25 Mar 2019 11:49:51 +0100 Subject: [PATCH] expose the tls.ConnectionState --- .golangci.yml | 1 + h2quic/server_test.go | 2 +- interface.go | 7 +--- internal/handshake/crypto_setup.go | 16 ++++---- internal/handshake/interface.go | 3 +- internal/handshake/unsafe.go | 33 ++++++++++++++++ internal/handshake/unsafe_test.go | 60 ++++++++++++++++++++++++++++++ internal/mocks/crypto_setup.go | 5 ++- mock_quic_session_test.go | 6 +-- session.go | 4 +- 10 files changed, 116 insertions(+), 21 deletions(-) create mode 100644 internal/handshake/unsafe.go create mode 100644 internal/handshake/unsafe_test.go diff --git a/.golangci.yml b/.golangci.yml index 6d90debc..eb9de2f3 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,6 +1,7 @@ run: skip-files: - h2quic/response_writer_closenotifier.go + - internal/handshake/unsafe_test.go linters-settings: misspell: diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 0de87bc1..01659734 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -82,7 +82,7 @@ func (s *mockSession) RemoteAddr() net.Addr { func (s *mockSession) Context() context.Context { return s.ctx } -func (s *mockSession) ConnectionState() quic.ConnectionState { panic("not implemented") } +func (s *mockSession) ConnectionState() tls.ConnectionState { panic("not implemented") } func (s *mockSession) AcceptUniStream() (quic.ReceiveStream, error) { panic("not implemented") } func (s *mockSession) OpenUniStream() (quic.SendStream, error) { panic("not implemented") } func (s *mockSession) OpenUniStreamSync() (quic.SendStream, error) { panic("not implemented") } diff --git a/interface.go b/interface.go index 6402d076..95842073 100644 --- a/interface.go +++ b/interface.go @@ -2,11 +2,11 @@ package quic import ( "context" + "crypto/tls" "io" "net" "time" - "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" ) @@ -22,9 +22,6 @@ type Cookie struct { SentTime time.Time } -// ConnectionState records basic details about the QUIC connection. -type ConnectionState = handshake.ConnectionState - // An ErrorCode is an application-defined error code. type ErrorCode = protocol.ApplicationErrorCode @@ -164,7 +161,7 @@ type Session interface { Context() context.Context // ConnectionState returns basic details about the QUIC connection. // Warning: This API should not be considered stable and might change soon. - ConnectionState() ConnectionState + ConnectionState() tls.ConnectionState } // Config contains all configuration data needed for a QUIC server or client. diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 6dbaa562..348d1c72 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "unsafe" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" @@ -541,13 +542,14 @@ func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) } } -func (h *cryptoSetup) ConnectionState() ConnectionState { - connState := h.conn.ConnectionState() - return ConnectionState{ - HandshakeComplete: connState.HandshakeComplete, - ServerName: connState.ServerName, - PeerCertificates: connState.PeerCertificates, - } +func (h *cryptoSetup) ConnectionState() tls.ConnectionState { + cs := h.conn.ConnectionState() + // h.conn is a qtls.Conn, which returns a qtls.ConnectionState. + // qtls.ConnectionState is identical to the tls.ConnectionState. + // It contains an unexported field which is used ExportKeyingMaterial(). + // The only way to return a tls.ConnectionState is to use unsafe. + // In unsafe.go we check that the two objects are actually identical. + return *(*tls.ConnectionState)(unsafe.Pointer(&cs)) } func (h *cryptoSetup) tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config { diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 2c6d32b4..22522513 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -1,6 +1,7 @@ package handshake import ( + "crypto/tls" "crypto/x509" "io" @@ -35,7 +36,7 @@ type CryptoSetup interface { ChangeConnectionID(protocol.ConnectionID) error HandleMessage([]byte, protocol.EncryptionLevel) bool - ConnectionState() ConnectionState + ConnectionState() tls.ConnectionState GetSealer() (protocol.EncryptionLevel, Sealer) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) diff --git a/internal/handshake/unsafe.go b/internal/handshake/unsafe.go new file mode 100644 index 00000000..df9c0db8 --- /dev/null +++ b/internal/handshake/unsafe.go @@ -0,0 +1,33 @@ +package handshake + +// This package uses unsafe to convert between qtls.ConnectionState and tls.ConnectionState. +// We check in init() that this conversion actually is safe. + +import ( + "crypto/tls" + "reflect" + + "github.com/marten-seemann/qtls" +) + +func init() { + if !structsEqual(&tls.ConnectionState{}, &qtls.ConnectionState{}) { + panic("qtls.ConnectionState not compatible with tls.ConnectionState") + } +} + +func structsEqual(a, b interface{}) bool { + sa := reflect.ValueOf(a).Elem() + sb := reflect.ValueOf(b).Elem() + if sa.NumField() != sb.NumField() { + return false + } + for i := 0; i < sa.NumField(); i++ { + fa := sa.Type().Field(i) + fb := sb.Type().Field(i) + if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { + return false + } + } + return true +} diff --git a/internal/handshake/unsafe_test.go b/internal/handshake/unsafe_test.go new file mode 100644 index 00000000..360a07bb --- /dev/null +++ b/internal/handshake/unsafe_test.go @@ -0,0 +1,60 @@ +package handshake + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type target struct { + Name string + Version string + + callback func(label string, length int) error +} + +type renamedField struct { + NewName string + Version string + + callback func(label string, length int) error +} + +type renamedPrivateField struct { + Name string + Version string + + cb func(label string, length int) error +} + +type additionalField struct { + Name string + Version string + + callback func(label string, length int) error + secret []byte +} + +type interchangedFields struct { + Version string + Name string + + callback func(label string, length int) error +} + +type renamedCallbackFunctionParams struct { // should be equivalent + Name string + Version string + + callback func(newLabel string, length int) error +} + +var _ = Describe("Unsafe checks", func() { + It("detects if an unsafe conversion is safe", func() { + Expect(structsEqual(&target{}, &target{})).To(BeTrue()) + Expect(structsEqual(&target{}, &renamedField{})).To(BeFalse()) + Expect(structsEqual(&target{}, &renamedPrivateField{})).To(BeFalse()) + Expect(structsEqual(&target{}, &additionalField{})).To(BeFalse()) + Expect(structsEqual(&target{}, &interchangedFields{})).To(BeFalse()) + Expect(structsEqual(&target{}, &renamedCallbackFunctionParams{})).To(BeTrue()) + }) +}) diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index cc5bd66c..01e9bea7 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -5,6 +5,7 @@ package mocks import ( + tls "crypto/tls" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -64,10 +65,10 @@ func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call { } // ConnectionState mocks base method -func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState { +func (m *MockCryptoSetup) ConnectionState() tls.ConnectionState { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(handshake.ConnectionState) + ret0, _ := ret[0].(tls.ConnectionState) return ret0 } diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index 90f1e6d9..cb09527a 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -6,11 +6,11 @@ package quic import ( context "context" + tls "crypto/tls" net "net" reflect "reflect" gomock "github.com/golang/mock/gomock" - handshake "github.com/lucas-clemente/quic-go/internal/handshake" protocol "github.com/lucas-clemente/quic-go/internal/protocol" ) @@ -96,10 +96,10 @@ func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *g } // ConnectionState mocks base method -func (m *MockQuicSession) ConnectionState() handshake.ConnectionState { +func (m *MockQuicSession) ConnectionState() tls.ConnectionState { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(handshake.ConnectionState) + ret0, _ := ret[0].(tls.ConnectionState) return ret0 } diff --git a/session.go b/session.go index 82b85769..94d92ab3 100644 --- a/session.go +++ b/session.go @@ -50,7 +50,7 @@ type cryptoStreamHandler interface { RunHandshake() error ChangeConnectionID(protocol.ConnectionID) error io.Closer - ConnectionState() handshake.ConnectionState + ConnectionState() tls.ConnectionState } type receivedPacket struct { @@ -437,7 +437,7 @@ func (s *session) Context() context.Context { return s.ctx } -func (s *session) ConnectionState() ConnectionState { +func (s *session) ConnectionState() tls.ConnectionState { return s.cryptoStreamHandler.ConnectionState() }