http: Fix proxying websocket

This commit is contained in:
世界 2024-11-13 13:49:56 +08:00
parent 72db784fc7
commit ad36d3be6d
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -4,6 +4,7 @@ import (
std_bufio "bufio" std_bufio "bufio"
"context" "context"
"encoding/base64" "encoding/base64"
"io"
"net" "net"
"net/http" "net/http"
"strings" "strings"
@ -28,7 +29,6 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
if err != nil { if err != nil {
return E.Cause(err, "read http request") return E.Cause(err, "read http request")
} }
if authenticator != nil { if authenticator != nil {
var ( var (
username string username string
@ -72,11 +72,15 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
} }
if request.Method == "CONNECT" { if request.Method == "CONNECT" {
portStr := request.URL.Port() destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port())
if portStr == "" { if destination.Port == 0 {
portStr = "80" switch request.URL.Scheme {
case "https", "wss":
destination.Port = 443
default:
destination.Port = 80
}
} }
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), portStr)
_, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n"))) _, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n")))
if err != nil { if err != nil {
return E.Cause(err, "write http response") return E.Cause(err, "write http response")
@ -96,74 +100,119 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
requestConn = conn requestConn = conn
} }
return handler.NewConnection(ctx, requestConn, metadata) return handler.NewConnection(ctx, requestConn, metadata)
} } else if strings.ToLower(request.Header.Get("Connection")) == "upgrade" {
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port())
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" if destination.Port == 0 {
request.RequestURI = "" switch request.URL.Scheme {
case "https", "wss":
removeHopByHopHeaders(request.Header) destination.Port = 443
removeExtraHTTPHostPort(request) default:
destination.Port = 80
if hostStr := request.Header.Get("Host"); hostStr != "" { }
if hostStr != request.URL.Host { }
request.Host = hostStr metadata.Protocol = "http"
metadata.Destination = destination
serverConn, clientConn := pipe.Pipe()
go func() {
err := handler.NewConnection(ctx, clientConn, metadata)
if err != nil {
common.Close(serverConn, clientConn)
}
}()
err = request.Write(serverConn)
if err != nil {
return E.Cause(err, "http: write upgrade request")
}
if reader.Buffered() > 0 {
_, err = io.CopyN(serverConn, reader, int64(reader.Buffered()))
if err != nil {
return err
}
}
return bufio.CopyConn(ctx, conn, serverConn)
} else {
err = handleHTTPConnection(ctx, handler, conn, request, metadata)
if err != nil {
return err
} }
} }
}
}
if request.URL.Scheme == "" || request.URL.Host == "" { func handleHTTPConnection(
return responseWith(request, http.StatusBadRequest).Write(conn) ctx context.Context,
} //nolint:staticcheck
handler N.TCPConnectionHandler,
conn net.Conn,
request *http.Request,
metadata M.Metadata,
) error {
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
request.RequestURI = ""
var innerErr atomic.TypedValue[error] removeHopByHopHeaders(request.Header)
httpClient := &http.Client{ removeExtraHTTPHostPort(request)
Transport: &http.Transport{
DisableCompression: true,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
metadata.Destination = M.ParseSocksaddr(address)
metadata.Protocol = "http"
input, output := pipe.Pipe()
go func() {
hErr := handler.NewConnection(ctx, output, metadata)
if hErr != nil {
innerErr.Store(hErr)
common.Close(input, output)
}
}()
return input, nil
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
requestCtx, cancel := context.WithCancel(ctx)
response, err := httpClient.Do(request.WithContext(requestCtx))
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
}
removeHopByHopHeaders(response.Header) if hostStr := request.Header.Get("Host"); hostStr != "" {
if hostStr != request.URL.Host {
if keepAlive { request.Host = hostStr
response.Header.Set("Proxy-Connection", "keep-alive")
response.Header.Set("Connection", "keep-alive")
response.Header.Set("Keep-Alive", "timeout=4")
}
response.Close = !keepAlive
err = response.Write(conn)
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err)
}
cancel()
if !keepAlive {
return conn.Close()
} }
} }
if request.URL.Scheme == "" || request.URL.Host == "" {
return responseWith(request, http.StatusBadRequest).Write(conn)
}
var innerErr atomic.TypedValue[error]
httpClient := &http.Client{
Transport: &http.Transport{
DisableCompression: true,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
metadata.Destination = M.ParseSocksaddr(address)
metadata.Protocol = "http"
input, output := pipe.Pipe()
go func() {
hErr := handler.NewConnection(ctx, output, metadata)
if hErr != nil {
innerErr.Store(hErr)
common.Close(input, output)
}
}()
return input, nil
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
requestCtx, cancel := context.WithCancel(ctx)
response, err := httpClient.Do(request.WithContext(requestCtx))
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
}
removeHopByHopHeaders(response.Header)
if keepAlive {
response.Header.Set("Proxy-Connection", "keep-alive")
response.Header.Set("Connection", "keep-alive")
response.Header.Set("Keep-Alive", "timeout=4")
}
response.Close = !keepAlive
err = response.Write(conn)
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err)
}
cancel()
if !keepAlive {
return conn.Close()
}
return nil
} }
func removeHopByHopHeaders(header http.Header) { func removeHopByHopHeaders(header http.Header) {