diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 1838b49..7c95f4d 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -21,29 +21,14 @@ import ( type Handler = N.TCPConnectionHandler -func HandleConnection(ctx context.Context, conn net.Conn, authenticator auth.Authenticator, handler Handler, metadata M.Metadata) error { - reader := std_bufio.NewReader(conn) - request, err := http.ReadRequest(reader) - if err != nil { - return E.Cause(err, "read http request") - } - if reader.Buffered() > 0 { - _buffer := buf.StackNewSize(reader.Buffered()) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - _, err = buffer.ReadFullFrom(reader, reader.Buffered()) - if err != nil { - return err - } - conn = bufio.NewCachedConn(conn, buffer) - } - return HandleRequest(ctx, request, conn, authenticator, handler, metadata) -} - -func HandleRequest(ctx context.Context, request *http.Request, conn net.Conn, authenticator auth.Authenticator, handler Handler, metadata M.Metadata) error { +func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator auth.Authenticator, handler Handler, metadata M.Metadata) error { var httpClient *http.Client for { + request, err := ReadRequest(reader) + if err != nil { + return E.Cause(err, "read http request") + } + if authenticator != nil { var authOk bool authorization := request.Header.Get("Proxy-Authorization") @@ -53,26 +38,41 @@ func HandleRequest(ctx context.Context, request *http.Request, conn net.Conn, au authOk = authenticator.Verify(userPswdArr[0], userPswdArr[1]) } if !authOk { - err := responseWith(request, http.StatusProxyAuthRequired).Write(conn) + err = responseWith(request, http.StatusProxyAuthRequired).Write(conn) if err != nil { return err } } } + var requestConn net.Conn + if reader.Buffered() > 0 { + _buffer := buf.StackNewSize(reader.Buffered()) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + _, err = buffer.ReadFullFrom(reader, reader.Buffered()) + if err != nil { + return err + } + requestConn = bufio.NewCachedConn(conn, buffer) + } else { + requestConn = conn + } + if request.Method == "CONNECT" { portStr := request.URL.Port() if portStr == "" { portStr = "80" } destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), portStr) - _, err := conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n"))) + _, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n"))) if err != nil { return E.Cause(err, "write http response") } metadata.Protocol = "http" metadata.Destination = destination - return handler.NewConnection(ctx, conn, metadata) + return handler.NewConnection(ctx, requestConn, metadata) } keepAlive := strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"