diff --git a/client.go b/client.go index 35459a5..86b307b 100644 --- a/client.go +++ b/client.go @@ -111,7 +111,7 @@ func (c *Client) openStream(ctx context.Context) (net.Conn, error) { if err != nil { continue } - stream, err = session.Open() + stream, err = session.OpenContext(ctx) if err != nil { continue } @@ -166,6 +166,8 @@ func (c *Client) offer(ctx context.Context) (abstractSession, error) { } func (c *Client) offerNew(ctx context.Context) (abstractSession, error) { + ctx, cancel := context.WithTimeout(ctx, TCPTimeout) + defer cancel() conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination) if err != nil { return nil, err @@ -190,7 +192,7 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) { return nil, err } if c.brutal.Enabled { - err = c.brutalExchange(conn, session) + err = c.brutalExchange(ctx, conn, session) if err != nil { conn.Close() session.Close() @@ -201,8 +203,8 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) { return session, nil } -func (c *Client) brutalExchange(sessionConn net.Conn, session abstractSession) error { - stream, err := session.Open() +func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn, session abstractSession) error { + stream, err := session.OpenContext(ctx) if err != nil { return err } diff --git a/h2mux.go b/h2mux.go index a67ef70..cb827a9 100644 --- a/h2mux.go +++ b/h2mux.go @@ -64,7 +64,7 @@ func (s *h2MuxServerSession) ServeHTTP(writer http.ResponseWriter, request *http } } -func (s *h2MuxServerSession) Open() (net.Conn, error) { +func (s *h2MuxServerSession) OpenContext(ctx context.Context) (net.Conn, error) { return nil, os.ErrInvalid } @@ -197,13 +197,14 @@ func (s *h2MuxClientSession) MarkDead(conn *http2.ClientConn) { s.Close() } -func (s *h2MuxClientSession) Open() (net.Conn, error) { +func (s *h2MuxClientSession) OpenContext(ctx context.Context) (net.Conn, error) { pipeInReader, pipeInWriter := io.Pipe() request := &http.Request{ Method: http.MethodConnect, Body: pipeInReader, URL: &url.URL{Scheme: "https", Host: "localhost"}, } + request = request.WithContext(ctx) conn := newLateHTTPConn(pipeInWriter) go func() { response, err := s.transport.RoundTrip(request) diff --git a/session.go b/session.go index 2dc37b3..524f573 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package mux import ( + "context" "io" "net" "reflect" @@ -12,7 +13,7 @@ import ( ) type abstractSession interface { - Open() (net.Conn, error) + OpenContext(ctx context.Context) (net.Conn, error) Accept() (net.Conn, error) NumStreams() int Close() error @@ -80,7 +81,7 @@ type smuxSession struct { *smux.Session } -func (s *smuxSession) Open() (net.Conn, error) { +func (s *smuxSession) OpenContext(context.Context) (net.Conn, error) { return s.OpenStream() } @@ -96,6 +97,10 @@ type yamuxSession struct { *yamux.Session } +func (y *yamuxSession) OpenContext(context.Context) (net.Conn, error) { + return y.OpenStream() +} + func (y *yamuxSession) CanTakeNewRequest() bool { return true }