diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 8a156ad..6f2271b 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -4,6 +4,7 @@ import ( std_bufio "bufio" "context" "encoding/base64" + "io" "net" "net/http" "strings" @@ -28,7 +29,6 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read if err != nil { return E.Cause(err, "read http request") } - if authenticator != nil { var ( username string @@ -72,11 +72,15 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read } if request.Method == "CONNECT" { - portStr := request.URL.Port() - if portStr == "" { - portStr = "80" + destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port()) + if destination.Port == 0 { + switch request.URL.Scheme { + case "https", "wss": + destination.Port = 443 + default: + destination.Port = 80 + } } - destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), portStr) _, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n"))) if err != nil { return E.Cause(err, "write http response") @@ -96,74 +100,119 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read requestConn = conn } return handler.NewConnection(ctx, requestConn, metadata) - } - - keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" - request.RequestURI = "" - - removeHopByHopHeaders(request.Header) - removeExtraHTTPHostPort(request) - - if hostStr := request.Header.Get("Host"); hostStr != "" { - if hostStr != request.URL.Host { - request.Host = hostStr + } else if strings.ToLower(request.Header.Get("Connection")) == "upgrade" { + destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port()) + if destination.Port == 0 { + switch request.URL.Scheme { + case "https", "wss": + destination.Port = 443 + default: + destination.Port = 80 + } + } + metadata.Protocol = "http" + metadata.Destination = destination + serverConn, clientConn := pipe.Pipe() + go func() { + err := handler.NewConnection(ctx, clientConn, metadata) + if err != nil { + common.Close(serverConn, clientConn) + } + }() + err = request.Write(serverConn) + if err != nil { + return E.Cause(err, "http: write upgrade request") + } + if reader.Buffered() > 0 { + _, err = io.CopyN(serverConn, reader, int64(reader.Buffered())) + if err != nil { + return err + } + } + return bufio.CopyConn(ctx, conn, serverConn) + } else { + err = handleHTTPConnection(ctx, handler, conn, request, metadata) + if err != nil { + return err } } + } +} - if request.URL.Scheme == "" || request.URL.Host == "" { - return responseWith(request, http.StatusBadRequest).Write(conn) - } +func handleHTTPConnection( + ctx context.Context, + //nolint:staticcheck + handler N.TCPConnectionHandler, + conn net.Conn, + request *http.Request, + metadata M.Metadata, +) error { + keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" + request.RequestURI = "" - var innerErr atomic.TypedValue[error] - httpClient := &http.Client{ - Transport: &http.Transport{ - DisableCompression: true, - DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - metadata.Destination = M.ParseSocksaddr(address) - metadata.Protocol = "http" - input, output := pipe.Pipe() - go func() { - hErr := handler.NewConnection(ctx, output, metadata) - if hErr != nil { - innerErr.Store(hErr) - common.Close(input, output) - } - }() - return input, nil - }, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - requestCtx, cancel := context.WithCancel(ctx) - response, err := httpClient.Do(request.WithContext(requestCtx)) - if err != nil { - cancel() - return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn)) - } + removeHopByHopHeaders(request.Header) + removeExtraHTTPHostPort(request) - removeHopByHopHeaders(response.Header) - - if keepAlive { - response.Header.Set("Proxy-Connection", "keep-alive") - response.Header.Set("Connection", "keep-alive") - response.Header.Set("Keep-Alive", "timeout=4") - } - - response.Close = !keepAlive - - err = response.Write(conn) - if err != nil { - cancel() - return E.Errors(innerErr.Load(), err) - } - - cancel() - if !keepAlive { - return conn.Close() + if hostStr := request.Header.Get("Host"); hostStr != "" { + if hostStr != request.URL.Host { + request.Host = hostStr } } + + if request.URL.Scheme == "" || request.URL.Host == "" { + return responseWith(request, http.StatusBadRequest).Write(conn) + } + + var innerErr atomic.TypedValue[error] + httpClient := &http.Client{ + Transport: &http.Transport{ + DisableCompression: true, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + metadata.Destination = M.ParseSocksaddr(address) + metadata.Protocol = "http" + input, output := pipe.Pipe() + go func() { + hErr := handler.NewConnection(ctx, output, metadata) + if hErr != nil { + innerErr.Store(hErr) + common.Close(input, output) + } + }() + return input, nil + }, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + requestCtx, cancel := context.WithCancel(ctx) + response, err := httpClient.Do(request.WithContext(requestCtx)) + if err != nil { + cancel() + return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn)) + } + + removeHopByHopHeaders(response.Header) + + if keepAlive { + response.Header.Set("Proxy-Connection", "keep-alive") + response.Header.Set("Connection", "keep-alive") + response.Header.Set("Keep-Alive", "timeout=4") + } + + response.Close = !keepAlive + + err = response.Write(conn) + if err != nil { + cancel() + return E.Errors(innerErr.Load(), err) + } + + cancel() + if !keepAlive { + return conn.Close() + } + return nil } func removeHopByHopHeaders(header http.Header) {