From f0c647004fe536aced0b57c1a57d08de036b9a1e Mon Sep 17 00:00:00 2001 From: Toby Date: Sat, 11 Apr 2020 15:47:53 -0700 Subject: [PATCH] forwarder: better pipe error handling --- internal/forwarder/client.go | 17 ++++++++++++----- internal/forwarder/server.go | 17 ++++++++++++----- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/internal/forwarder/client.go b/internal/forwarder/client.go index 7813023..0c1847f 100644 --- a/internal/forwarder/client.go +++ b/internal/forwarder/client.go @@ -187,14 +187,21 @@ func (c *QUICClient) handleConn(conn net.Conn) { return } defer stream.Close() - // From TCP to QUIC + // Pipes + errChan := make(chan error, 2) go func() { - _ = utils.Pipe(conn, stream, &c.outboundBytes) + // TCP to QUIC + errChan <- utils.Pipe(conn, stream, &c.outboundBytes) _ = conn.Close() _ = stream.Close() }() - // From QUIC to TCP - err = utils.Pipe(stream, conn, &c.inboundBytes) - // Closed + go func() { + // QUIC to TCP + errChan <- utils.Pipe(stream, conn, &c.inboundBytes) + _ = conn.Close() + _ = stream.Close() + }() + // We only need the first error + err = <-errChan c.onTCPConnectionClosed(conn.RemoteAddr(), err) } diff --git a/internal/forwarder/server.go b/internal/forwarder/server.go index 9d7049c..853d082 100644 --- a/internal/forwarder/server.go +++ b/internal/forwarder/server.go @@ -160,14 +160,21 @@ func (s *QUICServer) handleStream(addr net.Addr, name string, stream quic.Stream return } defer tcpConn.Close() - // From TCP to QUIC + // Pipes + errChan := make(chan error, 2) go func() { - _ = utils.Pipe(tcpConn, stream, &s.outboundBytes) + // TCP to QUIC + errChan <- utils.Pipe(tcpConn, stream, &s.outboundBytes) _ = tcpConn.Close() _ = stream.Close() }() - // From QUIC to TCP - err = utils.Pipe(stream, tcpConn, &s.inboundBytes) - // Closed + go func() { + // QUIC to TCP + errChan <- utils.Pipe(stream, tcpConn, &s.inboundBytes) + _ = tcpConn.Close() + _ = stream.Close() + }() + // We only need the first error + err = <-errChan s.onClientStreamClosed(addr, name, int(stream.StreamID()), err) }