pass a conn to qtls that returns the remote address

This commit is contained in:
Marten Seemann 2019-03-25 16:59:35 +01:00
parent b2723d6d13
commit da4b3e3176
4 changed files with 37 additions and 2 deletions

View file

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"unsafe" "unsafe"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
@ -112,6 +113,7 @@ func NewCryptoSetupClient(
handshakeStream io.Writer, handshakeStream io.Writer,
oneRTTStream io.Writer, oneRTTStream io.Writer,
connID protocol.ConnectionID, connID protocol.ConnectionID,
remoteAddr net.Addr,
tp *TransportParameters, tp *TransportParameters,
handleParams func([]byte), handleParams func([]byte),
tlsConf *tls.Config, tlsConf *tls.Config,
@ -131,7 +133,7 @@ func NewCryptoSetupClient(
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cs.conn = qtls.Client(nil, cs.tlsConf) cs.conn = qtls.Client(newConn(remoteAddr), cs.tlsConf)
return cs, clientHelloWritten, nil return cs, clientHelloWritten, nil
} }
@ -141,6 +143,7 @@ func NewCryptoSetupServer(
handshakeStream io.Writer, handshakeStream io.Writer,
oneRTTStream io.Writer, oneRTTStream io.Writer,
connID protocol.ConnectionID, connID protocol.ConnectionID,
remoteAddr net.Addr,
tp *TransportParameters, tp *TransportParameters,
handleParams func([]byte), handleParams func([]byte),
tlsConf *tls.Config, tlsConf *tls.Config,
@ -160,7 +163,7 @@ func NewCryptoSetupServer(
if err != nil { if err != nil {
return nil, err return nil, err
} }
cs.conn = qtls.Server(nil, cs.tlsConf) cs.conn = qtls.Server(newConn(remoteAddr), cs.tlsConf)
return cs, nil return cs, nil
} }

View file

@ -84,6 +84,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&bytes.Buffer{}, &bytes.Buffer{},
ioutil.Discard, ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil,
&TransportParameters{}, &TransportParameters{},
func([]byte) {}, func([]byte) {},
tlsConf, tlsConf,
@ -111,6 +112,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sHandshakeStream, sHandshakeStream,
ioutil.Discard, ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil,
&TransportParameters{}, &TransportParameters{},
func([]byte) {}, func([]byte) {},
testdata.GetTLSConfig(), testdata.GetTLSConfig(),
@ -144,6 +146,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sHandshakeStream, sHandshakeStream,
ioutil.Discard, ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil,
&TransportParameters{}, &TransportParameters{},
func([]byte) {}, func([]byte) {},
testdata.GetTLSConfig(), testdata.GetTLSConfig(),
@ -171,6 +174,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sHandshakeStream, sHandshakeStream,
ioutil.Discard, ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil,
&TransportParameters{}, &TransportParameters{},
func([]byte) {}, func([]byte) {},
testdata.GetTLSConfig(), testdata.GetTLSConfig(),
@ -249,6 +253,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cHandshakeStream, cHandshakeStream,
ioutil.Discard, ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil,
&TransportParameters{}, &TransportParameters{},
func([]byte) {}, func([]byte) {},
clientConf, clientConf,
@ -263,6 +268,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sHandshakeStream, sHandshakeStream,
ioutil.Discard, ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil,
&TransportParameters{StatelessResetToken: &token}, &TransportParameters{StatelessResetToken: &token},
func([]byte) {}, func([]byte) {},
serverConf, serverConf,
@ -304,6 +310,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cHandshakeStream, cHandshakeStream,
ioutil.Discard, ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil,
&TransportParameters{}, &TransportParameters{},
func([]byte) {}, func([]byte) {},
&tls.Config{InsecureSkipVerify: true}, &tls.Config{InsecureSkipVerify: true},
@ -340,6 +347,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cHandshakeStream, cHandshakeStream,
ioutil.Discard, ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil,
cTransportParameters, cTransportParameters,
func(p []byte) { sTransportParametersRcvd = p }, func(p []byte) { sTransportParametersRcvd = p },
clientConf, clientConf,
@ -358,6 +366,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sHandshakeStream, sHandshakeStream,
ioutil.Discard, ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil,
sTransportParameters, sTransportParameters,
func(p []byte) { cTransportParametersRcvd = p }, func(p []byte) { cTransportParametersRcvd = p },
testdata.GetTLSConfig(), testdata.GetTLSConfig(),

View file

@ -2,11 +2,32 @@ package handshake
import ( import (
"crypto/tls" "crypto/tls"
"net"
"time"
"unsafe" "unsafe"
"github.com/marten-seemann/qtls" "github.com/marten-seemann/qtls"
) )
type conn struct {
remoteAddr net.Addr
}
func newConn(remote net.Addr) net.Conn {
return &conn{remoteAddr: remote}
}
var _ net.Conn = &conn{}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return nil }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
type clientSessionCache struct { type clientSessionCache struct {
tls.ClientSessionCache tls.ClientSessionCache
} }

View file

@ -196,6 +196,7 @@ var newSession = func(
handshakeStream, handshakeStream,
oneRTTStream, oneRTTStream,
clientDestConnID, clientDestConnID,
conn.RemoteAddr(),
params, params,
s.processTransportParameters, s.processTransportParameters,
tlsConf, tlsConf,
@ -263,6 +264,7 @@ var newClientSession = func(
handshakeStream, handshakeStream,
oneRTTStream, oneRTTStream,
s.destConnID, s.destConnID,
conn.RemoteAddr(),
params, params,
s.processTransportParameters, s.processTransportParameters,
tlsConf, tlsConf,