expose the tls.ConnectionState

This commit is contained in:
Marten Seemann 2019-03-25 11:49:51 +01:00
parent 3f4b6d1df8
commit 09574a6653
10 changed files with 116 additions and 21 deletions

View file

@ -1,6 +1,7 @@
run: run:
skip-files: skip-files:
- h2quic/response_writer_closenotifier.go - h2quic/response_writer_closenotifier.go
- internal/handshake/unsafe_test.go
linters-settings: linters-settings:
misspell: misspell:

View file

@ -82,7 +82,7 @@ func (s *mockSession) RemoteAddr() net.Addr {
func (s *mockSession) Context() context.Context { func (s *mockSession) Context() context.Context {
return s.ctx return s.ctx
} }
func (s *mockSession) ConnectionState() quic.ConnectionState { panic("not implemented") } func (s *mockSession) ConnectionState() tls.ConnectionState { panic("not implemented") }
func (s *mockSession) AcceptUniStream() (quic.ReceiveStream, error) { panic("not implemented") } func (s *mockSession) AcceptUniStream() (quic.ReceiveStream, error) { panic("not implemented") }
func (s *mockSession) OpenUniStream() (quic.SendStream, error) { panic("not implemented") } func (s *mockSession) OpenUniStream() (quic.SendStream, error) { panic("not implemented") }
func (s *mockSession) OpenUniStreamSync() (quic.SendStream, error) { panic("not implemented") } func (s *mockSession) OpenUniStreamSync() (quic.SendStream, error) { panic("not implemented") }

View file

@ -2,11 +2,11 @@ package quic
import ( import (
"context" "context"
"crypto/tls"
"io" "io"
"net" "net"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
@ -22,9 +22,6 @@ type Cookie struct {
SentTime time.Time SentTime time.Time
} }
// ConnectionState records basic details about the QUIC connection.
type ConnectionState = handshake.ConnectionState
// An ErrorCode is an application-defined error code. // An ErrorCode is an application-defined error code.
type ErrorCode = protocol.ApplicationErrorCode type ErrorCode = protocol.ApplicationErrorCode
@ -164,7 +161,7 @@ type Session interface {
Context() context.Context Context() context.Context
// ConnectionState returns basic details about the QUIC connection. // ConnectionState returns basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon. // Warning: This API should not be considered stable and might change soon.
ConnectionState() ConnectionState ConnectionState() tls.ConnectionState
} }
// Config contains all configuration data needed for a QUIC server or client. // Config contains all configuration data needed for a QUIC server or client.

View file

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"unsafe"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/qerr"
@ -541,13 +542,14 @@ func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error)
} }
} }
func (h *cryptoSetup) ConnectionState() ConnectionState { func (h *cryptoSetup) ConnectionState() tls.ConnectionState {
connState := h.conn.ConnectionState() cs := h.conn.ConnectionState()
return ConnectionState{ // h.conn is a qtls.Conn, which returns a qtls.ConnectionState.
HandshakeComplete: connState.HandshakeComplete, // qtls.ConnectionState is identical to the tls.ConnectionState.
ServerName: connState.ServerName, // It contains an unexported field which is used ExportKeyingMaterial().
PeerCertificates: connState.PeerCertificates, // The only way to return a tls.ConnectionState is to use unsafe.
} // In unsafe.go we check that the two objects are actually identical.
return *(*tls.ConnectionState)(unsafe.Pointer(&cs))
} }
func (h *cryptoSetup) tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config { func (h *cryptoSetup) tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {

View file

@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"crypto/tls"
"crypto/x509" "crypto/x509"
"io" "io"
@ -35,7 +36,7 @@ type CryptoSetup interface {
ChangeConnectionID(protocol.ConnectionID) error ChangeConnectionID(protocol.ConnectionID) error
HandleMessage([]byte, protocol.EncryptionLevel) bool HandleMessage([]byte, protocol.EncryptionLevel) bool
ConnectionState() ConnectionState ConnectionState() tls.ConnectionState
GetSealer() (protocol.EncryptionLevel, Sealer) GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)

View file

@ -0,0 +1,33 @@
package handshake
// This package uses unsafe to convert between qtls.ConnectionState and tls.ConnectionState.
// We check in init() that this conversion actually is safe.
import (
"crypto/tls"
"reflect"
"github.com/marten-seemann/qtls"
)
func init() {
if !structsEqual(&tls.ConnectionState{}, &qtls.ConnectionState{}) {
panic("qtls.ConnectionState not compatible with tls.ConnectionState")
}
}
func structsEqual(a, b interface{}) bool {
sa := reflect.ValueOf(a).Elem()
sb := reflect.ValueOf(b).Elem()
if sa.NumField() != sb.NumField() {
return false
}
for i := 0; i < sa.NumField(); i++ {
fa := sa.Type().Field(i)
fb := sb.Type().Field(i)
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
return false
}
}
return true
}

View file

@ -0,0 +1,60 @@
package handshake
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type target struct {
Name string
Version string
callback func(label string, length int) error
}
type renamedField struct {
NewName string
Version string
callback func(label string, length int) error
}
type renamedPrivateField struct {
Name string
Version string
cb func(label string, length int) error
}
type additionalField struct {
Name string
Version string
callback func(label string, length int) error
secret []byte
}
type interchangedFields struct {
Version string
Name string
callback func(label string, length int) error
}
type renamedCallbackFunctionParams struct { // should be equivalent
Name string
Version string
callback func(newLabel string, length int) error
}
var _ = Describe("Unsafe checks", func() {
It("detects if an unsafe conversion is safe", func() {
Expect(structsEqual(&target{}, &target{})).To(BeTrue())
Expect(structsEqual(&target{}, &renamedField{})).To(BeFalse())
Expect(structsEqual(&target{}, &renamedPrivateField{})).To(BeFalse())
Expect(structsEqual(&target{}, &additionalField{})).To(BeFalse())
Expect(structsEqual(&target{}, &interchangedFields{})).To(BeFalse())
Expect(structsEqual(&target{}, &renamedCallbackFunctionParams{})).To(BeTrue())
})
})

View file

@ -5,6 +5,7 @@
package mocks package mocks
import ( import (
tls "crypto/tls"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@ -64,10 +65,10 @@ func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call {
} }
// ConnectionState mocks base method // ConnectionState mocks base method
func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState { func (m *MockCryptoSetup) ConnectionState() tls.ConnectionState {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ConnectionState") ret := m.ctrl.Call(m, "ConnectionState")
ret0, _ := ret[0].(handshake.ConnectionState) ret0, _ := ret[0].(tls.ConnectionState)
return ret0 return ret0
} }

View file

@ -6,11 +6,11 @@ package quic
import ( import (
context "context" context "context"
tls "crypto/tls"
net "net" net "net"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
protocol "github.com/lucas-clemente/quic-go/internal/protocol" protocol "github.com/lucas-clemente/quic-go/internal/protocol"
) )
@ -96,10 +96,10 @@ func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *g
} }
// ConnectionState mocks base method // ConnectionState mocks base method
func (m *MockQuicSession) ConnectionState() handshake.ConnectionState { func (m *MockQuicSession) ConnectionState() tls.ConnectionState {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ConnectionState") ret := m.ctrl.Call(m, "ConnectionState")
ret0, _ := ret[0].(handshake.ConnectionState) ret0, _ := ret[0].(tls.ConnectionState)
return ret0 return ret0
} }

View file

@ -50,7 +50,7 @@ type cryptoStreamHandler interface {
RunHandshake() error RunHandshake() error
ChangeConnectionID(protocol.ConnectionID) error ChangeConnectionID(protocol.ConnectionID) error
io.Closer io.Closer
ConnectionState() handshake.ConnectionState ConnectionState() tls.ConnectionState
} }
type receivedPacket struct { type receivedPacket struct {
@ -437,7 +437,7 @@ func (s *session) Context() context.Context {
return s.ctx return s.ctx
} }
func (s *session) ConnectionState() ConnectionState { func (s *session) ConnectionState() tls.ConnectionState {
return s.cryptoStreamHandler.ConnectionState() return s.cryptoStreamHandler.ConnectionState()
} }