Fix missing handshake timeout

This commit is contained in:
世界 2023-12-19 20:00:00 +08:00
parent 1cbd1ab6a3
commit bb6a56560a
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 16 additions and 8 deletions

View file

@ -113,7 +113,7 @@ func (c *Client) openStream(ctx context.Context) (net.Conn, error) {
if err != nil { if err != nil {
continue continue
} }
stream, err = session.Open() stream, err = session.OpenContext(ctx)
if err != nil { if err != nil {
continue continue
} }
@ -168,6 +168,8 @@ func (c *Client) offer(ctx context.Context) (abstractSession, error) {
} }
func (c *Client) offerNew(ctx context.Context) (abstractSession, error) { func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
ctx, cancel := context.WithTimeout(ctx, TCPTimeout)
defer cancel()
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination) conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination)
if err != nil { if err != nil {
return nil, err return nil, err
@ -192,7 +194,7 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
return nil, err return nil, err
} }
if c.brutal.Enabled { if c.brutal.Enabled {
err = c.brutalExchange(conn, session) err = c.brutalExchange(ctx, conn, session)
if err != nil { if err != nil {
conn.Close() conn.Close()
session.Close() session.Close()
@ -203,8 +205,8 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
return session, nil return session, nil
} }
func (c *Client) brutalExchange(sessionConn net.Conn, session abstractSession) error { func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn, session abstractSession) error {
stream, err := session.Open() stream, err := session.OpenContext(ctx)
if err != nil { if err != nil {
return err return err
} }

View file

@ -64,7 +64,7 @@ func (s *h2MuxServerSession) ServeHTTP(writer http.ResponseWriter, request *http
} }
} }
func (s *h2MuxServerSession) Open() (net.Conn, error) { func (s *h2MuxServerSession) OpenContext(ctx context.Context) (net.Conn, error) {
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }
@ -197,13 +197,14 @@ func (s *h2MuxClientSession) MarkDead(conn *http2.ClientConn) {
s.Close() s.Close()
} }
func (s *h2MuxClientSession) Open() (net.Conn, error) { func (s *h2MuxClientSession) OpenContext(ctx context.Context) (net.Conn, error) {
pipeInReader, pipeInWriter := io.Pipe() pipeInReader, pipeInWriter := io.Pipe()
request := &http.Request{ request := &http.Request{
Method: http.MethodConnect, Method: http.MethodConnect,
Body: pipeInReader, Body: pipeInReader,
URL: &url.URL{Scheme: "https", Host: "localhost"}, URL: &url.URL{Scheme: "https", Host: "localhost"},
} }
request = request.WithContext(ctx)
conn := newLateHTTPConn(pipeInWriter) conn := newLateHTTPConn(pipeInWriter)
go func() { go func() {
response, err := s.transport.RoundTrip(request) response, err := s.transport.RoundTrip(request)

View file

@ -1,6 +1,7 @@
package mux package mux
import ( import (
"context"
"io" "io"
"net" "net"
"reflect" "reflect"
@ -12,7 +13,7 @@ import (
) )
type abstractSession interface { type abstractSession interface {
Open() (net.Conn, error) OpenContext(ctx context.Context) (net.Conn, error)
Accept() (net.Conn, error) Accept() (net.Conn, error)
NumStreams() int NumStreams() int
Close() error Close() error
@ -80,7 +81,7 @@ type smuxSession struct {
*smux.Session *smux.Session
} }
func (s *smuxSession) Open() (net.Conn, error) { func (s *smuxSession) OpenContext(context.Context) (net.Conn, error) {
return s.OpenStream() return s.OpenStream()
} }
@ -96,6 +97,10 @@ type yamuxSession struct {
*yamux.Session *yamux.Session
} }
func (y *yamuxSession) OpenContext(context.Context) (net.Conn, error) {
return y.OpenStream()
}
func (y *yamuxSession) CanTakeNewRequest() bool { func (y *yamuxSession) CanTakeNewRequest() bool {
return true return true
} }