diff --git a/common/tls/listener.go b/common/tls/listener.go index 5fa97ed..872e0f2 100644 --- a/common/tls/listener.go +++ b/common/tls/listener.go @@ -1,7 +1,9 @@ package tls import ( + "context" "net" + "sync" ) type Listener struct { @@ -10,10 +12,10 @@ type Listener struct { } func NewListener(inner net.Listener, config ServerConfig) net.Listener { - l := new(Listener) - l.Listener = inner - l.config = config - return l + return &Listener{ + Listener: inner, + config: config, + } } func (l *Listener) Accept() (net.Conn, error) { @@ -21,5 +23,69 @@ func (l *Listener) Accept() (net.Conn, error) { if err != nil { return nil, err } - return l.config.Server(conn) + return NewLazyConn(conn, l.config), nil +} + +type LazyConn struct { + net.Conn + tlsConfig ServerConfig + access sync.Mutex + needHandshake bool +} + +func NewLazyConn(conn net.Conn, config ServerConfig) *LazyConn { + return &LazyConn{ + Conn: conn, + tlsConfig: config, + needHandshake: true, + } +} + +func (c *LazyConn) HandshakeContext(ctx context.Context) error { + if !c.needHandshake { + return nil + } + c.access.Lock() + defer c.access.Unlock() + if c.needHandshake { + tlsConn, err := ServerHandshake(ctx, c.Conn, c.tlsConfig) + if err != nil { + return err + } + c.Conn = tlsConn + c.needHandshake = false + } + return nil +} + +func (c *LazyConn) Read(p []byte) (n int, err error) { + err = c.HandshakeContext(context.Background()) + if err != nil { + return + } + return c.Conn.Read(p) +} + +func (c *LazyConn) Write(p []byte) (n int, err error) { + err = c.HandshakeContext(context.Background()) + if err != nil { + return + } + return c.Conn.Write(p) +} + +func (c *LazyConn) NeedHandshake() bool { + return c.needHandshake +} + +func (c *LazyConn) ReaderReplaceable() bool { + return !c.needHandshake +} + +func (c *LazyConn) WriterReplaceable() bool { + return !c.needHandshake +} + +func (c *LazyConn) Upstream() any { + return c.Conn }